//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff the affine expression is invariant. Sets the output
/// parameter `isCurrentLoop` when the expression becomes invariant exactly
/// at the current loop `curr`.
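///
/// A hedged usage sketch (the caller code below is illustrative, not part
/// of this file):
///
///   bool isCurrentLoop = false;
///   if (isInvariantAffine(a, curr, isCurrentLoop) && isCurrentLoop) {
///     // `a` is invariant and became so exactly at loop `curr`, i.e., the
///     // right moment to hoist a load that uses it (cf. genInvariants).
///   }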
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
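///
/// For example (illustrative):
///
///   (d0, d1) -> (d0, d1)  on (compressed, compressed) : accepted
///   (d0, d1) -> (d0 + d1) on (dense)                  : accepted (dense level)
///   (d0, d1) -> (d0 + d1) on (compressed)             : rejected (compound
///                                                       expression on a
///                                                       sparse level)
///   (d0, d1) -> (d0, d0)  on any level types          : rejected (index used
///                                                       more than once)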
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {}

/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in the
/// current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two way mapping in
/// merger between (tensor, level) pairs and their dependent index variable set:
/// A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// The method rejects cases (returns false) when:
///  1. the same index is used more than once, e.g., A[i+j][i];
///  2. multiplication appears in a non-trivial index expression;
///  3. a constant operand appears in a non-trivial index expression.
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {}

/// Gets the total number of compound affine expressions in the indexing map
/// returned by `getMatchingIndexingMap` for the given tensor. For example,
/// given the input:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
///
/// the method returns 1 (because the first level is compressed and its
/// corresponding indexing expression is the compound `d0 + d1`).
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {}

/// Returns true iff the output has non-trivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine
/// expressions. The dependent-index-reduction-based approach currently
/// supports only affine addition in index expressions.
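///
/// For example (illustrative), under the index-reduction-based approach:
///
///   A[i+j] on a compressed level : admissible (affine addition)
///   A[2*i] on a compressed level : inadmissible (multiplication; see
///                                  findDepIdxSet)
///   A[i+2] on a compressed level : inadmissible (constant operand)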
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {}

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {}

/// Generates a conditional insertion: inserts value `v` into the sparse
/// output `sparseOut` at coordinates `ivs` only when `cond` holds, yielding
/// the updated tensor value.
static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
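///
/// For example (an illustrative sketch, types and details elided), in the
/// cloned `present` region of a sparse_tensor.unary op:
///
///   %0 = sparse_tensor.unary %a : f64 to f64
///     present = {
///       ^bb0(%x: f64):
///         %i = linalg.index 0 : index  // relinked to the loop induction value
///         ...
///         sparse_tensor.yield %r : f64
///     }
///     absent = {}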
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {}

/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {}

/// Generates an expanded access pattern in the innermost dimension.
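///
/// A hedged sketch of the emitted pattern (operands and types elided):
///
///   %values, %filled, %added, %count = sparse_tensor.expand %tensor ...
///   ... innermost loop updates %values, %filled, %added, %count ...
///   %r = sparse_tensor.compress %values, %filled, %added, %count
///          into %tensor[...] ...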
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {}

/// Determines whether a loop should be parallelized. Any implicit loop in
/// the Linalg operation that is marked "parallel" is a candidate. Whether it
/// is actually converted to a parallel operation depends on the requested
/// strategy.
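///
/// For example (illustrative), under
/// SparseParallelizationStrategy::kDenseOuterLoop only an outermost loop over
/// a dense level is parallelized, while kAnyStorageAnyLoop also admits inner
/// loops and loops over sparse storage.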
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {}

/// Returns whether the current loop being generated should be parallelized
/// (if possible) according to the configuration.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {}

/// Emits a loop to co-iterate over the list of tensor levels. The generated
/// loop is a for-loop when the list contains at most one sparse level, and a
/// while-loop otherwise.
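///
/// For example (illustrative), in x(i) = a(i) * b(i) with a dense and b
/// sparse, the conjunction contains a single sparse level, so an scf.for
/// over b's stored coordinates suffices; in x(i) = a(i) + b(i) with both
/// operands sparse, an scf.while must advance two sparse iterators in
/// lockstep.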
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 unsigned numCases, bool tryParallel,
                                 bool needsUniv) {}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          unsigned numCases, bool needsUniv,
                          ArrayRef<TensorLevel> tidLvls) {}

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Visits all tensor levels, together with their affine index expressions
/// (if any), that must be processed for lattice point `li` at loop `curr`,
/// invoking `callback` on each. Returns true if the lattice point can be
/// iterated by a simple for-loop.
static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {}

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {}

/// Generates dense affine addresses for the encoding.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {}

/// Addresses for constant affine expressions can be generated before any
/// loops, starting from the first level, since they do not depend on any
/// loop indices. E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the
/// addresses for the first two levels can be determined before the loops.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {}

/// Returns true if the lattice bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {}

/// Starts a single loop in the current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, unsigned numCases,
                                              bool needsUniv) {}

/// Ends a single loop in the current sequence. Returns the new value for
/// `needsUniv`.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {}

/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                       unsigned at) {}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
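///
/// For example (illustrative), for x(i) = a(i) + b(i) with both a and b
/// sparse, the iteration lattice for loop i contains the points
/// { a /\ b, a, b }: co-iteration covers the intersection a /\ b, and the
/// remaining points handle the regions where only one operand has stored
/// entries.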
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {}

//===----------------------------------------------------------------------===//
// Sparsifier rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
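///
/// A hedged usage sketch (driver code is illustrative, not from this file):
///
///   RewritePatternSet patterns(op->getContext());
///   populateSparsificationPatterns(patterns, SparsificationOptions());
///   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
///     signalPassFailure();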
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {}