#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorDescriptor.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>
usingnamespacemlir;
usingnamespacemlir::sparse_tensor;
/// Produces the flattened form of `operands` in `flattened`. `flattened` is
/// taken by non-const reference, so it is presumably the output list.
/// NOTE(review): the body is elided in this view. Presumably each
/// sparse-tensor operand is expanded into its underlying storage values and
/// other operands pass through unchanged — TODO confirm.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) { … }
/// Builds IR at `loc` with `builder` that reads from buffer `mem` at index
/// `idx`, and returns the resulting Value.
/// NOTE(review): the body is elided in this view. Presumably this emits a
/// load of `mem[idx]`. Any cast of `idx` to the index type is not
/// confirmed — TODO confirm.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { … }
/// Builds IR at `loc` with `builder` that writes `val` into buffer `mem` at
/// index `idx`. Returns nothing.
/// NOTE(review): the body is elided in this view. Presumably this emits a
/// store of `val` to `mem[idx]` — TODO confirm.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) { … }
/// Creates and returns an `scf::ForOp` at `loc` whose upper bound is `upper`.
/// `fields` is a mutable array ref, so the callee may update it in place.
/// Presumably the fields become loop-carried values and are replaced with
/// the loop's region arguments — TODO confirm (body elided).
/// When `lower` is omitted, a null `Value()` is passed. Presumably the body
/// then uses a constant-zero lower bound — TODO confirm.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
MutableArrayRef<Value> fields,
Value lower = Value()) { … }
/// Appends `value` to the storage field of `desc` selected by `kind` and,
/// optionally, the level `lvl`.
/// NOTE(review): the body is elided in this view. Presumably:
///   - `repeat` (a null `Value()` by default) gives an optional repeat
///     count for the pushed value;
///   - `desc` is updated so it refers to the possibly-reallocated buffer.
/// TODO confirm both.
static void createPushback(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc,
SparseTensorFieldKind kind, std::optional<Level> lvl,
Value value, Value repeat = Value()) { … }
/// Sets up the storage scheme held by `desc`, starting at level `startLvl`.
/// NOTE(review): the body is elided in this view. Presumably this emits the
/// initial storage contents for levels `startLvl` and beyond (derived from
/// the function name) — TODO confirm.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc, Level startLvl) { … }
/// Returns a Value for a buffer of type `memRefType` with size `sz`, built
/// at `loc`.
/// NOTE(review): the body is elided in this view. Presumably `enableInit`
/// requests that the new buffer be initialized (e.g. filled with zeros)
/// rather than left uninitialized — TODO confirm.
static Value createAllocation(OpBuilder &builder, Location loc,
MemRefType memRefType, Value sz,
bool enableInit) { … }
/// Produces the dimension sizes of tensor type `stt` into the output vector
/// `dimSizesValues`.
/// NOTE(review): the body is elided in this view. Presumably static sizes
/// come from `stt` and dynamic sizes are taken in order from `dynSizes` —
/// TODO confirm.
static void createDimSizes(OpBuilder &builder, Location loc,
SparseTensorType stt, ValueRange dynSizes,
SmallVectorImpl<Value> &dimSizesValues) { … }
/// Creates the storage fields for a sparse tensor of type `stt` and writes
/// them to the output vector `fields`.
/// `lvlSizesValues` is a non-const reference, so it may be read, written,
/// or both — its direction is not visible here.
/// NOTE(review): the body is elided in this view. Presumably:
///   - `sizeHint` guides initial buffer capacities;
///   - `enableInit` requests initialized buffers.
/// TODO confirm both.
static void createAllocFields(OpBuilder &builder, Location loc,
SparseTensorType stt, bool enableInit,
Value sizeHint,
SmallVectorImpl<Value> &lvlSizesValues,
SmallVectorImpl<Value> &fields) { … }
/// Generates code for level `lvl` of `desc`, given the level coordinates
/// `lvlCoords` and the parent position `parentPos`, and returns a Value.
/// The fourth parameter is unnamed, so the body cannot refer to it.
/// Presumably it is kept only for signature compatibility with sibling
/// generators — TODO confirm.
/// NOTE(review): the body is elided in this view. Presumably this emits
/// insertion code for a compressed level and returns the position for the
/// next level — TODO confirm.
static Value genCompressed(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc, ValueRange lvlCoords,
Value , Value parentPos, Level lvl) { … }
/// Emits code at `loc` that concludes insertion into the storage described
/// by `desc`. It takes a read-only `SparseTensorDescriptor`, not the
/// mutable variant.
/// NOTE(review): the body is elided in this view. Presumably this finalizes
/// per-level storage after all insertions — TODO confirm.
static void genEndInsert(OpBuilder &builder, Location loc,
SparseTensorDescriptor desc) { … }
/// Returns a Value derived from buffer `mem` and size `sz`.
/// NOTE(review): the body is elided in this view. Presumably this returns a
/// view of the first `sz` elements of `mem` — TODO confirm.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
Value sz) { … }
/// Returns reassociation indices (one `ReassociationIndices` group per
/// entry) for shaped type `srcTp`, parameterized by `batchLvls`.
/// NOTE(review): the body is elided in this view. Presumably the groups
/// keep the leading `batchLvls` levels separate and collapse the remaining
/// levels into one — TODO confirm.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) { … }
// File-local conversion patterns and helpers. All class bodies are elided in
// this view; the per-class notes below describe only what the declarations
// show, plus hedged guesses marked NOTE(review).
namespace {
/// Insertion-code generator. It uses CRTP: it derives from
/// `FuncCallOrInlineGenerator` instantiated with itself.
/// NOTE(review): presumably it emits sparse insertion code either as a call
/// to a generated function or inline, per the base-class name — TODO
/// confirm.
class SparseInsertGenerator
: public FuncCallOrInlineGenerator<SparseInsertGenerator> { … };
/// Conversion pattern for `func::ReturnOp`.
/// NOTE(review): presumably it replaces returned sparse tensors with their
/// flattened storage values — TODO confirm.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { … };
/// Conversion pattern for `func::CallOp`.
/// NOTE(review): presumably it adapts sparse-tensor operands and results of
/// calls — TODO confirm.
class SparseCallConverter : public OpConversionPattern<func::CallOp> { … };
/// Conversion pattern for `LvlOp`.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { … };
/// Conversion pattern for `ReorderCOOOp`.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { … };
/// Conversion pattern template for op type `Op`. The `kind` template
/// argument selects a `StorageSpecifierKind`.
/// NOTE(review): presumably `kind` names the specifier field the getter
/// reads — TODO confirm.
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> { … };
/// Conversion pattern for `tensor::CastOp`.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { … };
/// Conversion pattern for `ReinterpretMapOp`.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { … };
/// Conversion pattern for `bufferization::AllocTensorOp`.
class SparseTensorAllocConverter
: public OpConversionPattern<bufferization::AllocTensorOp> { … };
/// Conversion pattern for `tensor::EmptyOp`.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { … };
/// Conversion pattern for `bufferization::DeallocTensorOp`.
class SparseTensorDeallocConverter
: public OpConversionPattern<bufferization::DeallocTensorOp> { … };
/// Conversion pattern for `LoadOp`.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { … };
/// Conversion pattern for `ExpandOp`.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> { … };
/// Conversion pattern for `CompressOp`.
class SparseCompressConverter : public OpConversionPattern<CompressOp> { … };
/// Conversion pattern for `tensor::InsertOp`.
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { … };
/// Conversion pattern for `ToPositionsOp`.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> { … };
/// Conversion pattern for `ToCoordinatesOp`.
class SparseToCoordinatesConverter
: public OpConversionPattern<ToCoordinatesOp> { … };
/// Conversion pattern for `ToCoordinatesBufferOp`.
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> { … };
/// Conversion pattern for `ToValuesOp`.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { … };
/// Conversion pattern for `ConvertOp`.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> { … };
/// Conversion pattern for `tensor::ExtractSliceOp`.
class SparseExtractSliceConverter
: public OpConversionPattern<tensor::ExtractSliceOp> { … };
/// Conversion pattern for `NumberOfEntriesOp`.
class SparseNumberOfEntriesConverter
: public OpConversionPattern<NumberOfEntriesOp> { … };
/// Conversion pattern for `AssembleOp`.
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { … };
/// Conversion pattern for `DisassembleOp`.
struct SparseDisassembleOpConverter
: public OpConversionPattern<DisassembleOp> { … };
/// Conversion pattern for `NewOp`.
struct SparseNewConverter : public OpConversionPattern<NewOp> { … };
/// Conversion pattern for `HasRuntimeLibraryOp`.
struct SparseHasRuntimeLibraryConverter
: public OpConversionPattern<HasRuntimeLibraryOp> { … };
} // namespace
/// Public entry point that registers this file's sparse-tensor codegen
/// patterns into `patterns`, using `typeConverter`.
/// NOTE(review): the body is elided in this view. Presumably it adds the
/// converters from the anonymous namespace in this file, with:
///   - `createSparseDeallocs` forwarded to the dealloc pattern;
///   - `enableBufferInitialization` forwarded to the allocation-related
///     patterns.
/// TODO confirm both.
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) { … }