llvm/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

usingnamespacemlir;
usingnamespacemlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP helper for implementing a rewriter that demaps all of its inputs.
//   SubClass - the concrete pattern that derives from this helper (the CRTP
//              parameter).
//   SourceOp - the operation type the pattern is anchored on; forwarded to the
//              OpRewritePattern base class.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {};

// Flattens an affine expression into a list of AffineDimExprs.
// Implemented as an AffineExprVisitor parameterized on itself (CRTP).
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {};

// AffineExprVisitor (CRTP on itself) over affine expressions.
// NOTE(review): judging by its name, this visitor presumably decides whether
// an affine expression is admissible rather than flattening it into
// AffineDimExprs -- confirm against the visiting logic.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExp that are used by the
// inadmissible expressions.
InadmissInfo;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
//   map      - the affine map whose result expressions are inspected.
//   isOutput - presumably tells whether `map` belongs to an output operand;
//              TODO confirm how this changes what counts as admissible.
// Returns the collected information as an InadmissInfo.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {}

// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
// replaced levels, and remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4. The function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which, (l0, l1) together replaces `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
//
// Parameters:
//   info   - the inadmissible levels/dimensions to eliminate.
//   idxMap - the index map whose dimensions are being replaced by levels.
//   itTps  - iterator type array, updated in place for the new index
//            variables.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {}

// Translates the index map in the linalg::GenericOp from idx->dim map to
// idx->lvl map. Returns failure if the index map can not be translated to an
// admissible form.
// Returns the translated index map array and the iterator type array, in that
// order, as the two members of the pair.
// NOTE(review): given the std::optional return type, failure is presumably
// reported as std::nullopt -- confirm.
//   op       - the generic op whose index maps are translated.
//   rewriter - the pattern rewriter available to the translation.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {}

// Generates a "de"mapping reinterpretation of the map.
//   builder - builder used to create the generated IR.
//   enc     - the sparse tensor encoding that carries the map.
//   val     - the value to reinterpret.
// Returns the resulting Value.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {}

// Generates a "re"mapping reinterpretation of the map.
//   builder - builder used to create the generated IR.
//   enc     - the sparse tensor encoding that carries the map.
//   val     - the value to reinterpret.
// Returns the resulting Value.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {}

// Produces a vector of values from the value range `outs` and the type range
// `types`.
// NOTE(review): judging by the name, this presumably applies a "re"mapping
// reinterpretation to each value in `outs` so that it matches the
// corresponding entry of `types` -- TODO confirm.
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
/// Built on the DemapInsRewriter CRTP helper, anchored on linalg::GenericOp.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {};

/// Rewrite pattern anchored on linalg::GenericOp.
/// NOTE(review): judging by its name, this presumably decides the loop
/// (iteration) order of the generic op -- confirm.
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

/// Rewrite pattern anchored on `AllocOp`, the tensor allocation operation type
/// supplied as the template argument.
/// NOTE(review): presumably demaps the allocated tensor -- confirm.
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {};

/// Reinterpret-map rewriter for tensor::InsertOp, built on the
/// DemapInsRewriter CRTP helper.
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {};

/// Reinterpret-map rewrite pattern anchored on AssembleOp.
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {};

/// Reinterpret-map rewriter for DisassembleOp, built on the DemapInsRewriter
/// CRTP helper.
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {};

/// Reinterpret-map rewriter for ForeachOp, built on the DemapInsRewriter CRTP
/// helper.
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {};

} // namespace

/// Public entry point of this file, defined in the `mlir` namespace.
///   patterns - the rewrite pattern set handed in by the caller.
///   scope    - the ReinterpretMapScope requested by the caller.
/// NOTE(review): presumably adds the reinterpret-map rewrite patterns defined
/// in this file to `patterns`, with `scope` selecting which ones -- confirm.
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {}