//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Utils/CodegenUtils.h" #include "Utils/IterationGraphSorter.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" usingnamespacemlir; usingnamespacemlir::sparse_tensor; namespace { //===----------------------------------------------------------------------===// // File Local Helper classes. //===----------------------------------------------------------------------===// // CRTP helper for implementing a rewriter that demaps all its inputs. template <typename SubClass, typename SourceOp> struct DemapInsRewriter : public OpRewritePattern<SourceOp> { … }; // Flattens an affine expression into a list of AffineDimExprs. struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { … }; // NOTE(review): the previous comment here was a verbatim copy of the one on // AffineDimCollector. Judging by its name, this visitor presumably decides // whether an affine expression is admissible -- confirm against the body. struct AffineExprAdmissibleVisitor : public AffineExprVisitor<AffineExprAdmissibleVisitor> { … }; // The first BitVector stores levels where inadmissible exprs are used. // The second BitVector stores the AffineDimExprs that are used by the // inadmissible expressions. InadmissInfo; } // namespace //===----------------------------------------------------------------------===// // File Local Helper methods. 
//===----------------------------------------------------------------------===// // Collects the inadmissible affine expression imposed on levels. static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) { … } // Builds the AffineMap to replace the idx in idxMap to lvl such that all tht // inadmissible affine expressions can be eliminated. // For example, we can rewrite // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3) // to // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3) // by composing inverse(idxMap), that is // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3) // -> ((l0 * 2 + l2) floordiv 2, // (l1 * 3 + l3) floordiv 3, // (l0 * 2 + l2) mod 2, // (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3) // // This function builds the inverse(idxMap) that replace every dimensions used // in `info` to levels, and updates the iterator type array `itTps` for the new // index variable introduced. // // Note that the returned affine map does not retain the order of the input // affine map. Instead, it always uses the first `info.inAdlvls.count()` for the // replaced levels, and remaining ones for unused dimensions. // For example, to handle // idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4) // which is a typical map for block_2to4. The function returns: // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1) // in which, (l0, l1) together replaces `d1`, yet they appear // before `d0` in the resulting affine map. // The index (loop) order can later be canonicalized by a topo sort. static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector<utils::IteratorType> &itTps) { … } // Translates the index map in the linalg::GenericOp from idx->dim map to // idx->lvl map. Returns failure if the index map can not be translated to an // admissible form. // Returns the translated index map array and the iterator type array. 
static std::optional<std::pair<ArrayAttr, ArrayAttr>> translateMap(linalg::GenericOp op, PatternRewriter &rewriter) { … } // Generates a "de"mapping reinterpretation of the map. static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) { … } // Generates a "re"mapping reinterpretation of the map. static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) { … } // NOTE(review): undocumented helper. From the signature it takes a builder, a // range of types and a range of values, and returns a vector of values; // presumably it remaps each value in `outs` to its type in `types` -- confirm. static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs) { … } namespace { //===----------------------------------------------------------------------===// // Rewriting rules for linalg generic ops. //===----------------------------------------------------------------------===// /// Sparse rewriting rule for the generic `linalg` operation. struct GenericOpReinterpretMap : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> { … }; // NOTE(review): undocumented pattern on linalg::GenericOp; presumably // schedules (orders) the loops of the op -- confirm against the body. struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> { … }; //===----------------------------------------------------------------------===// // Reinterpret Map Rewriters for operations other than linalg.generics //===----------------------------------------------------------------------===// template <typename AllocOp> struct TensorAllocDemapper : public OpRewritePattern<AllocOp> { … }; struct TensorInsertDemapper : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> { … }; struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> { … }; struct SparseDisassembleDemapper : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> { … }; struct ForeachOpDemapper : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> { … }; } // namespace // NOTE(review): public entry point; presumably adds the rewrite patterns // defined in this file to `patterns`, selected by `scope` -- confirm. void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope) { … }