//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
static bool isZeroValue(Value val) { … }

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) { … }
static bool isSparseTensor(OpOperand *op) { … }

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) { … }

// Helper to detect a sampling operation.
static bool isSampling(GenericOp op) { … }

// Helper to detect a chain of multiplications that does not involve x.
static bool isMulChain(Value val, Value x) { … }

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) { … }

// Helper to detect a direct yield of a zero value.
static bool isZeroYield(GenericOp op) { … }

/// Populates the given sizes array from the type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) { … }

static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) { … }

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes into `dynSizes`.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) { … }

static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) { … }

/// Populates the given sizes array for concatenation from the types (for
/// static sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) { … }
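// A minimal sketch of what the zero/sparsity helpers above typically look
// like; illustrative only, with hypothetical `Sketch`-suffixed names. The
// real (elided) implementations may differ.
static bool isZeroValueSketch(Value val) {
  // Matches both an integer zero and a +0.0/-0.0 floating-point constant,
  // using the standard matchers from mlir/IR/Matchers.h.
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}
static bool isSparseTensorSketch(Value v) {
  // A tensor value is sparse iff its type carries a sparse tensor encoding.
  return getSparseTensorEncoding(v.getType()) != nullptr;
}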
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// TODO: move this to the tensor dialect instead.
///
/// Folds `tensor.concat` and `tensor.extract_slice`:
///
/// %concat = tensor.concat dim(2) %t0, %t1
///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
///
/// becomes
///
/// %extracted0, %extracted1 = %t0, %t1
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> { … };

/// Rewriting rule that fuses sparse_tensor.convert into its producer.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> { … };

/// Rewriting rule that converts a direct yield of zero into the initial
/// allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> { … };

/// Rewriting rule that converts two kernels:
///
///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///   X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using the distributive law:
///
///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> { … };

// Fuses a tensor cast into the producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting,
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> { … };

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
///   %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
///   %sel = binary %sp1, %sp2:
///            both  (%l, %r) {yield select %cond, %l, %r}
///            left  (%l)     {yield select %cond, %l, 0}
///            right (%r)     {yield select %cond, 0, %r}
///
/// TODO: We require the tensor used for extracting conditions to be dense
/// to sparsify the code. To support a sparse condition tensor, we need a
/// ternary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> { … };
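// All pattern structs in this file share the standard OpRewritePattern shape;
// a minimal skeleton is sketched below (illustrative only: `ExamplePattern`
// and its match conditions are hypothetical, not one of the rules above).
struct ExamplePattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out cheaply when the pattern does not apply...
    if (op.getNumDpsInits() != 1)
      return failure();
    // ...otherwise build the replacement IR through `rewriter` and succeed.
    rewriter.modifyOpInPlace(op, [&]() { /* e.g., adjust attributes */ });
    return success();
  }
};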
/// Rewrites a sparse reduction that would not sparsify directly, since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
/// reductions (note that reductions like add/sub/or/xor can be sparsified
/// directly, since the implicit zeros do not contribute to the final
/// result). Note that prod/and are still included since, even though they
/// often are nullified in sparse data, they may still occur in special
/// situations in which, e.g., some rows of a sparse matrix are fully
/// dense. For min/max, including the implicit zeros is a much more
/// common situation.
///
/// TODO: this essentially "densifies" the operation; we want to implement
///       this much more efficiently by performing the reduction over the
///       stored values, feeding in the zero once if there were *any*
///       implicit zeros as well; but for now, at least we provide
///       the functionality
///
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> { … };

/// Sparse rewriting rule for the print operator. This operation is mainly
/// used for debugging and testing. As such, it lowers to the vector.print
/// operation, which requires only very lightweight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> { … };

/// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> { … };

/// Sparse rewriting rule for sparse-to-sparse reshape operators.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> { … };

/// Sparse rewriting rule for the sparse-to-dense and dense-to-sparse reshape
/// operators.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> { … };

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike { … };

struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> { … };

struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> { … };

struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> { … };

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> { … };

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> { … };

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> { … };

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> { … };

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) { … }

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) { … }

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) { … }
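// Illustrative client usage of the populate methods above (a sketch, not part
// of this file; `exampleApplyPatterns` is hypothetical). A pass typically
// collects the patterns and hands them to the greedy rewrite driver, which
// needs the additional include "mlir/Transforms/GreedyPatternRewriteDriver.h".
static LogicalResult exampleApplyPatterns(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  populatePreSparsificationRewriting(patterns);
  populateLowerForeachToSCFPatterns(patterns);
  // Applies the patterns to a fixed point, folding and eliminating dead
  // code along the way.
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}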