#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>
// Pull in the TableGen-generated pass definition for FoldTensorSubsetOps.
// Defining GEN_PASS_DEF_FOLDTENSORSUBSETOPS right before including
// Passes.h.inc requests the generated definition for this pass (MLIR's
// declarative pass convention). The include sits inside mlir::tensor because
// the generated base class is used below as tensor::impl::FoldTensorSubsetOpsBase.
namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

// The rest of this file refers to MLIR types (Value, LogicalResult,
// PatternRewriter, ...) without the mlir:: qualifier.
using namespace mlir;
/// Returns a tensor-typed operand of a `vector.transfer_read` op as a Value.
/// NOTE(review): by its name this is presumably the tensor the read loads
/// from — confirm against the body. It shares its name with the
/// tensor::InsertSliceOp overload, presumably so templated folding code can
/// query either op kind with one spelling.
static Value getTensorOperand(vector::TransferReadOp op) { … }
/// Returns a tensor-typed operand of a `tensor.insert_slice` op as a Value.
/// NOTE(review): insert_slice carries two tensor operands (the inserted value
/// and the destination); confirm against the body which one is returned.
/// Overloads the vector::TransferReadOp variant above.
static Value getTensorOperand(tensor::InsertSliceOp op) { … }
// Pattern classes live in an anonymous namespace so they have internal
// linkage. Their matching hooks are defined out of line later in this file.
namespace {
/// Rewrite pattern rooted at `vector.transfer_read`. It derives from
/// vector::MaskableOpRewritePattern, so its logic is implemented in
/// matchAndRewriteMaskableOp rather than matchAndRewrite.
/// NOTE(review): by its name it folds a `tensor.extract_slice` producer into
/// the read — confirm against the out-of-line definition.
class TransferReadOfExtractSliceOpFolder final
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> { … };
/// Rewrite pattern rooted at `tensor.insert_slice`; its matchAndRewrite is
/// defined out of line.
/// NOTE(review): by its name it folds a `vector.transfer_write` feeding the
/// insert_slice into a single transfer_write — confirm against the definition.
class InsertSliceOfTransferWriteOpFolder final
: public OpRewritePattern<tensor::InsertSliceOp> { … };
} // namespace
/// Legality check for folding a vector transfer op (`xferOp`) together with a
/// tensor slice op (`extractOrInsertSliceOp`). Returns success() when the fold
/// may proceed and failure() otherwise.
/// It is templated on both op types, presumably so the read/extract_slice and
/// write/insert_slice patterns can share one set of checks.
/// NOTE(review): the exact conditions, and whether `rewriter` is used to
/// report match-failure reasons, are defined in the body — confirm there.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
RewriterBase &rewriter, XferOp xferOp,
ExtractOrInsertOp extractOrInsertSliceOp) { … }
/// Matching hook for TransferReadOfExtractSliceOpFolder, invoked through the
/// vector::MaskableOpRewritePattern base class for each candidate `readOp`.
/// `maskOp` is the masking op associated with `readOp`.
/// NOTE(review): presumably null when the read is not masked — confirm against
/// MaskableOpRewritePattern.
/// Returns a Value on success (presumably the replacement for the read), or
/// failure() when the pattern does not apply.
FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
PatternRewriter &rewriter) const { … }
/// OpRewritePattern hook for InsertSliceOfTransferWriteOpFolder, applied to
/// each candidate `tensor.insert_slice`. Returns success() if the op was
/// rewritten and failure() if the pattern does not match.
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const { … }
/// Rewrite pattern templated on its root op type `OpTy`.
/// NOTE(review): by its name it folds an insert_slice-like op whose inserted
/// value is produced by another `tensor.insert_slice` — confirm against the
/// body, and confirm which OpTy instantiations are registered.
/// NOTE(review): unlike the other pattern classes here, this struct is declared
/// at global scope rather than inside an anonymous namespace, so it has
/// external linkage. Consider moving it into an anonymous namespace to rule out
/// ODR clashes with a same-named entity in another translation unit.
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> { … };
/// Public entry point (presumably declared in
/// mlir/Dialect/Tensor/Transforms/Transforms.h) that appends this file's
/// tensor subset-op folding patterns to `patterns`.
/// NOTE(review): confirm in the body exactly which patterns are added and
/// whether the vector-transfer patterns are included.
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { … }
/// Public entry point that appends to `patterns` the patterns folding tensor
/// subset ops into vector transfer ops.
/// NOTE(review): presumably registers TransferReadOfExtractSliceOpFolder and
/// InsertSliceOfTransferWriteOpFolder — confirm against the body.
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
RewritePatternSet &patterns) { … }
// The pass class has internal linkage. Callers reach it only through the
// createFoldTensorSubsetOpsPass factory.
namespace {
/// Pass wrapper around the tensor subset-op folding patterns. It derives
/// (CRTP) from the TableGen-generated tensor::impl::FoldTensorSubsetOpsBase
/// that the GEN_PASS_DEF include brings in; its runOnOperation is defined out
/// of line.
struct FoldTensorSubsetOpsPass final
: public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> { … };
} // namespace
/// Pass body, invoked by the pass manager for each operation the pass is
/// scheduled on.
/// NOTE(review): presumably collects patterns via
/// populateFoldTensorSubsetOpPatterns and applies them with the greedy rewrite
/// driver (GreedyPatternRewriteDriver.h is included) — confirm against the body.
void FoldTensorSubsetOpsPass::runOnOperation() { … }
/// Factory for the FoldTensorSubsetOps pass. Returns the pass as a
/// std::unique_ptr<Pass>, transferring ownership to the caller.
/// NOTE(review): presumably constructs a FoldTensorSubsetOpsPass — confirm.
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() { … }