//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" usingnamespacemlir; usingnamespacemlir::tensor; namespace { /// Merges consecutive tensor.extract_slice ops into one. // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps. struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> { … }; /// Merges consecutive tensor.insert_slice ops into one. // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps. template <typename OpTy> struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> { … }; /// Drop redundant rank expansion of insert_slice ops that are directly followed /// by extract_slice. E.g.: /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32> /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1] /// : tensor<1x1x5x10xf32> to tensor<2x2xf32> struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice : public OpRewritePattern<ExtractSliceOp> { … }; /// Drop redundant rank expansion of insert_slice that directly follows /// extract_slice. /// /// This can be done when the insert_slice op purely expands ranks (adds unit /// dims) and the extract_slice drops corresponding unit dims. 
For example: /// /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] /// : tensor<2x8xf32> to tensor<8xf32> /// %inserted_slice = tensor.insert_slice %extracted_slice /// into %dest[0, 0] [1, 8] [1, 1] /// : tensor<8xf32> into tensor<1x8xf32> /// /// can be folded into: /// /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1] /// : tensor<2x8xf32> to tensor<1x8xf32> struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final : public OpRewritePattern<tensor::InsertSliceOp> { … }; } // namespace void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns( RewritePatternSet &patterns) { … } void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns( RewritePatternSet &patterns) { … }