#include <utility>
#include "Detail/DimLvlMapParser.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
mlir::sparse_tensor::Level &,
mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
mlir::sparse_tensor::Level);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
usingnamespacemlir;
usingnamespacemlir::sparse_tensor;
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) { … }
}
static constexpr bool acceptBitWidth(unsigned bitWidth) { … }
static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
std::optional<ArrayRef<int64_t>> dimShape) { … }
static constexpr Level kInvalidLevel = …;
static constexpr Level kInvalidFieldIndex = …;
static constexpr FieldIndex kDataFieldStartingIdx = …;
void StorageLayout::foreachField(
llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
LevelType)>
callback) const { … }
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
SparseTensorType stt,
llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
LevelType)>
callback) { … }
unsigned StorageLayout::getNumFields() const { … }
unsigned StorageLayout::getNumDataFields() const { … }
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
std::optional<Level> lvl) const { … }
std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { … }
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { … }
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { … }
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { … }
bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { … }
std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { … }
void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { … }
void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { … }
static ParseResult parseOptionalStaticSlice(int64_t &result,
AsmParser &parser) { … }
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { … }
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, int64_t size, int64_t stride) { … }
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { … }
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { … }
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
unsigned crdWidth) const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { … }
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { … }
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { … }
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { … }
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { … }
bool SparseTensorEncodingAttr::isAllDense() const { … }
bool SparseTensorEncodingAttr::isAllOrdered() const { … }
Type SparseTensorEncodingAttr::getCrdElemType() const { … }
Type SparseTensorEncodingAttr::getPosElemType() const { … }
MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
std::optional<ArrayRef<int64_t>> dimShape) const { … }
MemRefType SparseTensorEncodingAttr::getPosMemRefType(
std::optional<ArrayRef<int64_t>> dimShape) const { … }
bool SparseTensorEncodingAttr::isIdentity() const { … }
bool SparseTensorEncodingAttr::isPermutation() const { … }
Dimension SparseTensorEncodingAttr::getDimRank() const { … }
Level SparseTensorEncodingAttr::getLvlRank() const { … }
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { … }
bool SparseTensorEncodingAttr::isSlice() const { … }
SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { … }
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { … }
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { … }
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { … }
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { … }
SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
CrdTransDirectionKind dir) const { … }
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
ValueRange crds,
CrdTransDirectionKind dir) const { … }
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { … }
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { … }
void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
AsmPrinter &printer) const { … }
void SparseTensorEncodingAttr::printDimensions(
AffineMap &map, AsmPrinter &printer,
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { … }
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
ArrayRef<LevelType> lvlTypes) const { … }
LogicalResult SparseTensorEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
ArrayRef<SparseTensorDimSliceAttr> dimSlices) { … }
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const { … }
Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { … }
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { … }
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
bool isUnique) const { … }
RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { … }
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) { … }
AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
MLIRContext *context) { … }
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
MLIRContext *context) { … }
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { … }
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { … }
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { … }
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { … }
Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { … }
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { … }
StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { … }
StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *ctx,
SparseTensorEncodingAttr encoding) { … }
static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { … }
static LogicalResult isMatchingWidth(Value mem, unsigned width) { … }
static LogicalResult verifySparsifierGetterSetter(
StorageSpecifierKind mdKind, std::optional<Level> lvl,
TypedValue<StorageSpecifierType> md, Operation *op) { … }
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { … }
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
SparseTensorType stt,
RankedTensorType valTp,
TypeRange lvlTps) { … }
LogicalResult AssembleOp::verify() { … }
LogicalResult DisassembleOp::verify() { … }
LogicalResult ConvertOp::verify() { … }
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { … }
bool ConvertOp::needsExtraSort() { … }
LogicalResult CrdTranslateOp::verify() { … }
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
int64_t index) { … }
LogicalResult LvlOp::verify() { … }
std::optional<uint64_t> LvlOp::getConstantLvlIndex() { … }
Speculation::Speculatability LvlOp::getSpeculatability() { … }
OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { … }
void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
SparseTensorEncodingAttr dstEnc, Value source) { … }
LogicalResult ReinterpretMapOp::verify() { … }
OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { … }
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop,
RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ToPositionsOp::verify() { … }
LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ToCoordinatesOp::verify() { … }
LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ToCoordinatesBufferOp::verify() { … }
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ToValuesOp::verify() { … }
LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop,
RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ToSliceOffsetOp::verify() { … }
LogicalResult ToSliceStrideOp::verify() { … }
LogicalResult GetStorageSpecifierOp::verify() { … }
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { … }
OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { … }
LogicalResult SetStorageSpecifierOp::verify() { … }
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion,
const char *regionName,
TypeRange inputTypes, Type outputType) { … }
LogicalResult BinaryOp::verify() { … }
LogicalResult UnaryOp::verify() { … }
bool ConcatenateOp::needsExtraSort() { … }
LogicalResult ConcatenateOp::verify() { … }
void PushBackOp::build(OpBuilder &builder, OperationState &result,
Value curSize, Value inBuffer, Value value) { … }
LogicalResult PushBackOp::verify() { … }
LogicalResult CompressOp::verify() { … }
void ForeachOp::build(
OpBuilder &builder, OperationState &result, Value tensor,
ValueRange initArgs, AffineMapAttr order,
function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
bodyBuilder) { … }
LogicalResult ForeachOp::verify() { … }
OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { … }
LogicalResult ReorderCOOOp::verify() { … }
LogicalResult ReduceOp::verify() { … }
LogicalResult SelectOp::verify() { … }
LogicalResult SortOp::verify() { … }
IterSpaceType IteratorType::getIterSpaceType() const { … }
IteratorType IterSpaceType::getIteratorType() const { … }
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
Level &lvlHi) { … }
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
IntegerAttr &lvlHiAttr) { … }
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { … }
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
IntegerAttr lvlHi) { … }
static ParseResult parseOptionalDefinedList(
OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
unsigned maxCnt = std::numeric_limits<unsigned>::max(),
OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { … }
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
Block::BlockArgListType blocksArgs,
I64BitSet definedSet) { … }
static ParseResult
parseUsedCoordList(OpAsmParser &parser, OperationState &state,
SmallVectorImpl<OpAsmParser::Argument> &coords) { … }
static ParseResult
parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
SmallVectorImpl<OpAsmParser::Argument> &iterators,
SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { … }
static ParseResult
parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
SmallVectorImpl<Value> &spacesVals,
SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { … }
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) { … }
LogicalResult ExtractIterSpaceOp::verify() { … }
LogicalResult ExtractValOp::verify() { … }
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { … };
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) { … }
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
Value iterSpace, ValueRange initArgs) { … }
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
Value iterSpace, ValueRange initArgs,
I64BitSet crdUsedLvls) { … }
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { … }
static void printInitializationList(OpAsmPrinter &p,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") { … }
template <typename SparseLoopOp>
static LogicalResult verifySparseLoopOp(SparseLoopOp op) { … }
LogicalResult IterateOp::verify() { … }
LogicalResult CoIterateOp::verify() { … }
void IterateOp::print(OpAsmPrinter &p) { … }
LogicalResult IterateOp::verifyRegions() { … }
SmallVector<Region *> IterateOp::getLoopRegions() { … }
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { … }
Block::BlockArgListType IterateOp::getRegionIterArgs() { … }
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { … }
std::optional<ResultRange> IterateOp::getLoopResults() { … }
OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { … }
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> ®ions) { … }
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
ValueRange iterSpaces, ValueRange initArgs,
unsigned numCases) { … }
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { … }
void CoIterateOp::print(OpAsmPrinter &p) { … }
ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { … }
LogicalResult CoIterateOp::verifyRegions() { … }
SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) { … }
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) { … }
namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { … };
}
void SparseTensorDialect::initialize() { … }
#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"