//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // A pass that converts sparse tensor primitives into calls into a runtime // support library. Sparse tensor types are converted into opaque pointers // to the underlying sparse storage schemes. The use of opaque pointers // together with the runtime support library keeps the conversion relatively // simple, but at the expense of IR opacity, which obscures opportunities // for subsequent optimization of the IR. An alternative is provided by // the SparseTensorCodegen pass. // //===----------------------------------------------------------------------===// #include "Utils/CodegenUtils.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" usingnamespacemlir; usingnamespacemlir::sparse_tensor; namespace { //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// /// Maps each sparse tensor type to an opaque pointer. static std::optional<Type> convertSparseTensorTypes(Type type) { … } /// Generates a call to look up a level-size.
N.B., this only generates /// the raw function call, and therefore (intentionally) does not perform /// any dim<->lvl conversion or other logic. static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor, uint64_t lvl) { … } /// Generates a call to look up a dimension-size. N.B., this only generates /// the raw function call, and therefore (intentionally) does not perform /// any dim<->lvl conversion or other logic. static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor, uint64_t dim) { … } /// Looks up a level-size by returning a statically-computed constant /// (when possible), or by calling `genLvlSizeCall` (when dynamic). static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, Level lvl) { … } /// Looks up a dimension-size by returning a constant from the shape /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes /// of dense tensors). static Value createOrFoldDimCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, Dimension dim) { … } /// Populates `out` with the dimension-sizes of the given tensor. static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl<Value> &out) { … } /// Returns an array with the dimension-sizes of the given tensor. /// If the *tensor* parameter is null, the tensor type is assumed to have a /// static shape. static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor = Value()) { … } /// Generates an uninitialized buffer of the given size and type, /// but returns it as type `memref<? x $tp>` (rather than as type /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, /// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { … } /// Generates a temporary buffer for the level-types of the given encoding. static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, SparseTensorType stt) { … } /// Extracts the bare (aligned) pointer that points to the tensor. static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc, Value tensor) { … } /// Generates a temporary buffer holding the bare pointers of the given /// level tensors and values tensor (NOTE(review): presumably obtained via /// `extractBarePtrFromTensor`; confirm against the implementation). static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc, ValueRange lvlTensors, Value valTensor) { … } /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`: /// the "swiss army knife" method of the sparse runtime support library /// for materializing sparse tensors into the computation. This abstraction /// reduces the need for modifications when the API changes. class NewCallParams final { … }; /// Generates a call to obtain the values array. static Value genValuesCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value ptr) { … } /// Generates a call to obtain the positions array of level `l`. static Value genPositionsCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value ptr, Level l) { … } /// Generates a call to obtain the coordinates array of level `l`. static Value genCoordinatesCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value ptr, Level l) { … } /// Generates a call to obtain the coordinates array (AoS view) starting /// at level `l`. static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value ptr, Level l) { … } //===----------------------------------------------------------------------===// // Conversion rules. //===----------------------------------------------------------------------===// /// Sparse conversion rule for returns. class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { … }; /// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> { … }; /// Sparse conversion rule for trivial tensor casts. class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { … }; /// Sparse conversion rule for the reinterpret_map operator. class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { … }; /// Sparse conversion rule for the new operator. class SparseTensorNewConverter : public OpConversionPattern<NewOp> { … }; /// Sparse conversion rule for the bufferization.alloc_tensor operator. /// TODO(springerm): remove when bufferization.alloc_tensor is gone class SparseTensorAllocConverter : public OpConversionPattern<bufferization::AllocTensorOp> { … }; /// Sparse conversion rule for the tensor.empty operator. class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { … }; /// Sparse conversion rule for the reorder_coo operator. class SparseTensorReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { … }; /// Sparse conversion rule for the dealloc operator. class SparseTensorDeallocConverter : public OpConversionPattern<bufferization::DeallocTensorOp> { … }; /// Sparse conversion rule for position accesses. class SparseTensorToPositionsConverter : public OpConversionPattern<ToPositionsOp> { … }; /// Sparse conversion rule for coordinate accesses. class SparseTensorToCoordinatesConverter : public OpConversionPattern<ToCoordinatesOp> { … }; /// Sparse conversion rule for coordinate accesses (AoS style). class SparseToCoordinatesBufferConverter : public OpConversionPattern<ToCoordinatesBufferOp> { … }; /// Sparse conversion rule for value accesses. class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { … }; /// Sparse conversion rule for the number-of-entries operator. class SparseNumberOfEntriesConverter : public OpConversionPattern<NumberOfEntriesOp> { … }; /// Sparse conversion rule for tensor rematerialization. class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { … }; /// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<tensor::InsertOp> { … }; /// Sparse conversion rule for the expand operator. class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> { … }; /// Sparse conversion rule for the compress operator. class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> { … }; /// Sparse conversion rule for the sparse_tensor.assemble operator. class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> { … }; /// Sparse conversion rule for the sparse_tensor.disassemble operator. /// Note that the current implementation simply exposes the buffers to /// the external client. This assumes the client only reads the buffers /// (usually copying them into external data structures, such as numpy /// arrays). The semantics of the disassemble operation technically /// require that the copying is done here already using the out-levels /// and out-values clauses. class SparseTensorDisassembleConverter : public OpConversionPattern<DisassembleOp> { … }; /// Sparse conversion rule for the has_runtime_library operator. struct SparseHasRuntimeLibraryConverter : public OpConversionPattern<HasRuntimeLibraryOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Sparse tensor type conversion into opaque pointer. //===----------------------------------------------------------------------===// /// Constructs the type converter that maps sparse tensor types to opaque /// pointers. mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() { … } //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// /// Populates the given patterns list with the conversion rules that lower /// sparse tensor primitives into calls to the runtime support library. void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { … }