llvm/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations, following custom print/parsing methods are referenced
// by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

usingnamespacemlir;
usingnamespacemlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
// well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel =;
static constexpr Level kInvalidFieldIndex =;
static constexpr FieldIndex kDataFieldStartingIdx =;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {}

unsigned StorageLayout::getNumFields() const {}

unsigned StorageLayout::getNumDataFields() const {}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {}

bool SparseTensorEncodingAttr::isAllDense() const {}

bool SparseTensorEncodingAttr::isAllOrdered() const {}

Type SparseTensorEncodingAttr::getCrdElemType() const {}

Type SparseTensorEncodingAttr::getPosElemType() const {}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {}

bool SparseTensorEncodingAttr::isIdentity() const {}

bool SparseTensorEncodingAttr::isPermutation() const {}

Dimension SparseTensorEncodingAttr::getDimRank() const {}

Level SparseTensorEncodingAttr::getLvlRank() const {}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {}

bool SparseTensorEncodingAttr::isSlice() const {}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {}

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {}

Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {}

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {}

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {}

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {}

/// We normalized sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {}

StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {}

LogicalResult AssembleOp::verify() {}

LogicalResult DisassembleOp::verify() {}

LogicalResult ConvertOp::verify() {}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {}

bool ConvertOp::needsExtraSort() {}

LogicalResult CrdTranslateOp::verify() {}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {}

LogicalResult LvlOp::verify() {}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {}

Speculation::Speculatability LvlOp::getSpeculatability() {}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {}

LogicalResult ReinterpretMapOp::verify() {}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {}

template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ToPositionsOp::verify() {}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ToCoordinatesOp::verify() {}

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ToCoordinatesBufferOp::verify() {}

LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ToValuesOp::verify() {}

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ToSliceOffsetOp::verify() {}

LogicalResult ToSliceStrideOp::verify() {}

LogicalResult GetStorageSpecifierOp::verify() {}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {}

LogicalResult SetStorageSpecifierOp::verify() {}

template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {}

LogicalResult BinaryOp::verify() {}

LogicalResult UnaryOp::verify() {}

bool ConcatenateOp::needsExtraSort() {}

LogicalResult ConcatenateOp::verify() {}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {}

LogicalResult PushBackOp::verify() {}

LogicalResult CompressOp::verify() {}

void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {}

LogicalResult ForeachOp::verify() {}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {}

LogicalResult ReorderCOOOp::verify() {}

LogicalResult ReduceOp::verify() {}

LogicalResult SelectOp::verify() {}

LogicalResult SortOp::verify() {}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {}

IteratorType IterSpaceType::getIteratorType() const {}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                   Level &lvlHi) {}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {}

/// Parses a list of `optional` defined list in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
static ParseResult parseOptionalDefinedList(
    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {}

static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
                                     Block::BlockArgListType blocksArgs,
                                     I64BitSet definedSet) {}

static ParseResult
parseUsedCoordList(OpAsmParser &parser, OperationState &state,
                   SmallVectorImpl<OpAsmParser::Argument> &coords) {}

static ParseResult
parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {}

static ParseResult
parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
                         SmallVectorImpl<Value> &spacesVals,
                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {}

LogicalResult ExtractIterSpaceOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {}

LogicalResult ExtractIterSpaceOp::verify() {}

LogicalResult ExtractValOp::verify() {}

struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {};

void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {}

void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs) {}

void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {}

ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {}

/// Prints the initialization list in the form of
///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values.
static void printInitializationList(OpAsmPrinter &p,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {}

template <typename SparseLoopOp>
static LogicalResult verifySparseLoopOp(SparseLoopOp op) {}

LogicalResult IterateOp::verify() {}
LogicalResult CoIterateOp::verify() {}

void IterateOp::print(OpAsmPrinter &p) {}

LogicalResult IterateOp::verifyRegions() {}

/// OpInterfaces' methods implemented by IterateOp.
SmallVector<Region *> IterateOp::getLoopRegions() {}

MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {}

Block::BlockArgListType IterateOp::getRegionIterArgs() {}

std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {}

std::optional<ResultRange> IterateOp::getLoopResults() {}

OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {}

void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {}

void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
                        ValueRange iterSpaces, ValueRange initArgs,
                        unsigned numCases) {}

ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {}

void CoIterateOp::print(OpAsmPrinter &p) {}

ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {}

LogicalResult CoIterateOp::verifyRegions() {}

SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {}

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {};
} // namespace

void SparseTensorDialect::initialize() {}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"