//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Rewrite patterns that fold neighbouring ops (`pad`, `extract_slice`, and
// linalg transposes) into tensor `pack` / `unpack` ops, which already carry
// padding, slicing, and transpose semantics. Also contains patterns that
// simplify `pack` / `unpack` into reshape ops (see the individual pattern
// comments below).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Presumably returns true iff every entry of `ofrs` is a constant integer
/// equal to `value` (inferred from the name; confirm against the body).
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) { … }

/// Returns the number of dimension sizes in `shape` that are either dynamic
/// or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) { … }

/// Returns success() if only one dimension of the non-packed domain has a
/// size greater than 1 and packing happens only on that dimension.
/// Note: this method should only be used by the pack/unpack to reshape
/// conversion. It assumes that a non-unit inner tile size must be used by the
/// non-unit dimension.
/// NOTE(review): `rewriter` and `op` are presumably used only to report
/// match-failure reasons — confirm against the body.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) { … }

/// If `linalgOp` represents a transpose, returns the permutation vector of
/// that transpose. Otherwise, returns failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) { … }

/// Packing a one-dimensional tensor can be expressed as an expand_shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { … };

/// Counterpart of SimplifyPackToExpandShape for `unpack`: presumably rewrites
/// an `unpack` that only touches one non-unit dimension as a collapse_shape op
/// (inferred from the name; confirm against the body).
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> { … };

/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> { … };

/// Fold an `unpack` -> `extract_slice` into the `unpack`, since `unpack`
/// already has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> { … };

/// Applies `permutation` to `inVec` and stores the result in `resVec`.
/// `inVec` may be empty; in that case the result maps one-to-one to
/// `permutation` (i.e. `inVec` is treated as the identity).
/// `rank` bounds the permutation: the permutation dims in permutation[:rank]
/// can't be greater than `rank`. If one is, the function returns false.
/// For example, permutation {1, 0, 3, 2} with rank 2 is allowed since the
/// values in permutation[:rank] don't exceed the rank, whereas permutation
/// {1, 3, 0, 2} is not allowed since `3` exceeds the rank in the given range.
/// NOTE(review): whether a value exactly equal to `rank` is accepted is not
/// stated by this comment — confirm against the body.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) { … }

/// Fold 'pack' -> 'transpose' into 'pack', since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> { … };

/// Fold 'transpose' -> 'pack' into 'pack', since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> { … };

/// Fold 'unpack' -> 'transpose' into 'unpack', since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> { … };

/// Fold 'transpose' -> 'unpack' into 'unpack', since 'unpack' already has
/// transpose semantics.
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> { … };

} // namespace

/// Adds to `patterns` the rewrites that fold neighbouring ops into
/// `pack`/`unpack` (presumably the Fold* patterns above; confirm against the
/// body).
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { … }

/// Adds to `patterns` the rewrites that simplify `pack`/`unpack` ops
/// (presumably SimplifyPackToExpandShape and SimplifyUnPackToCollapseShape;
/// confirm against the body).
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { … }

} // namespace tensor
} // namespace mlir