// llvm/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp

//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

usingnamespacemlir;
usingnamespacemlir::tensor;

namespace {
/// Merges consecutive tensor.extract_slice ops into one.
///
/// Declared as a rewrite pattern deriving from
/// OpRewritePattern<ExtractSliceOp>.
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
// NOTE(review): this definition declares no members (no inherited
// constructors, no matchAndRewrite override) -- confirm the rewrite logic is
// intended to live here.
struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {};

/// Merges consecutive tensor.insert_slice ops into one.
///
/// \tparam OpTy the op type forwarded unchanged to the OpRewritePattern<OpTy>
/// base class.
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
// NOTE(review): this definition declares no members (no inherited
// constructors, no matchAndRewrite override) -- confirm the rewrite logic is
// intended to live here.
template <typename OpTy>
struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {};

/// Drop the redundant rank expansion of insert_slice ops that are directly
/// followed by an extract_slice op. E.g.:
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
// NOTE(review): this definition declares no members (no inherited
// constructors, no matchAndRewrite override) -- confirm the rewrite logic is
// intended to live here.
struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
    : public OpRewritePattern<ExtractSliceOp> {};

/// Drop the redundant rank expansion of an insert_slice op that directly
/// follows an extract_slice op.
///
/// This can be done when the insert_slice op purely expands ranks (adds unit
/// dims) and the extract_slice op drops the corresponding unit dims. For
/// example:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<8xf32>
/// %inserted_slice = tensor.insert_slice %extracted_slice
///     into %dest[0, 0] [1, 8] [1, 1]
///     : tensor<8xf32> into tensor<1x8xf32>
///
/// can be folded into:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<1x8xf32>
// NOTE(review): this definition declares no members (no inherited
// constructors, no matchAndRewrite override) -- confirm the rewrite logic is
// intended to live here.
struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
    : public OpRewritePattern<tensor::InsertSliceOp> {};
} // namespace
} // namespace

/// Definition of the mlir::tensor entry point that receives the pattern set
/// `patterns` by reference.
// NOTE(review): the body is empty, so nothing is added to `patterns` here.
// Presumably this should register MergeConsecutiveExtractSlice and the
// MergeConsecutiveInsertSlice instantiations -- confirm.
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
    RewritePatternSet &patterns) {}

/// Definition of the mlir::tensor entry point that receives the pattern set
/// `patterns` by reference.
// NOTE(review): the body is empty, so nothing is added to `patterns` here.
// Presumably this should register the DropRedundantRankExpansionOn* patterns
// defined above -- confirm.
void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
    RewritePatternSet &patterns) {}