//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
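// (For example, arith.constant 0 : i64 and arith.constant 0.0 : f32 both
// qualify as typed zeros.)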
static bool isZeroValue(Value val) {}

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {}
static bool isSparseTensor(OpOperand *op) {}

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {}

// Helper to detect a sampling operation.
static bool isSampling(GenericOp op) {}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {}

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {}

/// Populates the given sizes array from the type (for static sizes) and
/// from the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {}

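/// Constructs the ranked tensor type of the buffer used during the
/// computation, presumably a temporary (unordered) COO type when
/// `needTmpCOO` is set.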
static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {}

/// Collects the dynamic dimension sizes for `tp`, with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic
/// dimension sizes in `dynSizes`.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {}

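/// Rewrites the given foreach operation whose source is a sparse constant,
/// presumably by unrolling the loop body over the elements of `attr`.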
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {}

/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// TODO: move this to the tensor dialect instead.
///
/// Fold `tensor.concat` and `tensor.extract_slice`
///
/// %concat = tensor.concat dim(2) %t0, %t1
///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
///
/// Becomes
///
/// %extracted0, %extracted1 = %t0, %t1
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {};

/// Rewriting rule that fuses sparse_tensor.convert into its producer.
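///
/// For example (a sketch with a hypothetical encoding #SV), the pair
///
///   %d = linalg.generic ... -> tensor<8xf32>
///   %s = sparse_tensor.convert %d : tensor<8xf32> to tensor<8xf32, #SV>
///
/// is rewritten so that the producing kernel directly yields the sparse
/// result, eliminating the intermediate dense tensor and the conversion.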
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {};

/// Rewriting rule that replaces a kernel directly yielding a zero value
/// with the kernel's initial (empty) tensor allocation itself.
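///
/// For example (a sketch with a hypothetical encoding #CSR), a kernel
///
///   %0 = linalg.generic ... outs(%init : tensor<8x8xf32, #CSR>) {
///          ...
///          linalg.yield %cst_zero : f32
///        } -> tensor<8x8xf32, #CSR>
///
/// where %init materializes an empty sparse tensor is replaced by %init
/// itself, since yielding only zeros adds nothing to the sparse output.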
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {};

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {};

// Fuse a tensor.cast into its producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting,
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations so that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
/// %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
/// %sel = binary %sp1, %sp2:
///         both  (%l, %r) {yield select %cond, %l, %r}
///         left  (%l)     {yield select %cond, %l,  0}
///         right (%r)     {yield select %cond,  0, %r}
///
/// TODO: We require the tensor used for extracting conditions to be dense
/// in order to sparsify the code. To support a sparse condition tensor, we
/// need a ternary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {};

/// Rewrites, into a semi-ring, a sparse reduction that would not sparsify
/// directly, since doing so would only iterate over the stored elements,
/// ignoring the implicit zeros. Applies to all prod/and/min/max (note
/// that reductions like add/sub/or/xor can be sparsified directly, since
/// the implicit zeros do not contribute to the final result). Note that
/// prod/and are still included since, even though they are often
/// nullified in sparse data, they may still occur in special situations
/// where e.g. some rows of a sparse matrix are fully dense. For min/max,
/// including the implicit zeros is a much more common situation.
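///
/// For example, a min-reduction over the 1-D sparse tensor [5, 0, 3],
/// where the 0 is an implicit zero, must yield 0; reducing over the
/// stored values {5, 3} alone would incorrectly yield 3.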
///
/// TODO: this essentially "densifies" the operation; we want to implement
///       this much more efficiently by performing the reduction over the
///       stored values only, and feeding in the zero once if there were
///       *any* implicit zeros as well; but for now, at least we provide
///       the functionality.
///
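/// A sketch of the semi-ring form (adapted from the sparse_tensor.reduce
/// documentation; the types and the identity constant %cf1 are
/// illustrative): the inner reduction becomes a custom reduction with an
/// explicit identity, e.g.
///
///   %r = sparse_tensor.reduce %x, %a, %cf1 : f64 {
///          ^bb0(%arg0: f64, %arg1: f64):
///            %1 = arith.mulf %arg0, %arg1 : f64
///            sparse_tensor.yield %1 : f64
///        }
///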
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {};

/// Sparse rewriting rule for the print operator. This operation is mainly
/// used for debugging and testing. As such, it lowers to the vector.print
/// operation, which requires only very lightweight runtime support.
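///
/// For example (a sketch; the exact layout is illustrative), printing a
/// CSR matrix emits one vector.print per storage buffer, i.e. for the
/// positions and coordinates of the compressed level and for the values.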
struct PrintRewriter : public OpRewritePattern<PrintOp> {};

/// Sparse rewriting rule for the sparse-to-sparse tensor.reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operators.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {};

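/// Sparse rewriting rule for the tensor.dim operator on a sparse tensor.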
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {};

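/// Sparse rewriting rule for the concatenate operator.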
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {};

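/// Sparse rewriting rule for the convert operator (direct approach).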
struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {};

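/// Sparse rewriting rule for the crd_translate operator.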
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {};

/// Sparse rewriting rule for the foreach operator.
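///
/// For example (a sketch with a hypothetical encoding #SV), the operator
///
///   sparse_tensor.foreach in %sv : tensor<8xf32, #SV> do {
///     ^bb0(%i: index, %v: f32):
///       ...
///   }
///
/// is lowered to scf loops that visit only the stored elements of %sv.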
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {}
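
// Example usage (a sketch; assumes the standard greedy rewrite driver
// declared in mlir/Transforms/GreedyPatternRewriteDriver.h, invoked from
// within a pass):
//
//   RewritePatternSet patterns(op->getContext());
//   populatePreSparsificationRewriting(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();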