#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
usingnamespacemlir;
namespace {
/// Pass type built on the `PassWrapper` helper, which receives this struct
/// itself as its first template argument, and anchored on `OperationPass<>`
/// (no explicit operation type is supplied). `runOnOperation` is defined out
/// of line later in this file.
/// NOTE(review): the member list is elided in this view — presumably it holds
/// the pass options that choose which tensor test transformation runs;
/// confirm against the full definition.
struct TestTensorTransforms
: public PassWrapper<TestTensorTransforms, OperationPass<>> { … };
} // namespace
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// collects the reassociative-reshape folding patterns and applies them to
/// `rootOp` — confirm against the implementation.
static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies patterns that bubble `expand_shape` ops upward within `rootOp` —
/// confirm against the implementation.
static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies patterns that fold neighbouring ops into pack/unpack ops under
/// `rootOp` — confirm against the implementation.
static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies patterns that fold `extract_slice` of constants under `rootOp`.
/// Whether a control/filter callback is involved cannot be told from here —
/// confirm against the implementation.
static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies patterns that merge consecutive insert_slice / extract_slice ops
/// under `rootOp` — confirm against the implementation.
static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies patterns that drop redundant rank-expanding insert_slice ops under
/// `rootOp` — confirm against the implementation.
static void
applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { … }
/// File-local helper that takes the root operation to work on and returns
/// nothing, so callers cannot observe success or failure from it.
/// NOTE(review): body elided in this view; going by the name it presumably
/// applies pack/unpack simplification patterns under `rootOp` — confirm
/// against the implementation.
static void applySimplifyPackUnpackPatterns(Operation *rootOp) { … }
namespace {
/// Rewrite pattern rooted on `tensor::ExtractSliceOp`; serves as the common
/// base of the two pattern structs declared right after it.
/// NOTE(review): members elided in this view — judging by the name, it
/// presumably matches an extract_slice whose source is a collapse_shape and
/// leaves the emission strategy to the derived classes; confirm.
struct RewriteExtractSliceFromCollapseShapeBase
: public OpRewritePattern<tensor::ExtractSliceOp> { … };
/// Derived pattern built on RewriteExtractSliceFromCollapseShapeBase.
/// NOTE(review): body elided; the name suggests the rewrite is emitted with
/// `scf.for` loops — confirm.
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
: public RewriteExtractSliceFromCollapseShapeBase { … };
/// Derived pattern built on RewriteExtractSliceFromCollapseShapeBase.
/// NOTE(review): body elided; the name suggests the rewrite is emitted with an
/// SCF "foreach"-style parallel construct — confirm which op is used.
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
: public RewriteExtractSliceFromCollapseShapeBase { … };
} // namespace
/// File-local helper that takes the root operation plus a `useForeach` flag
/// and reports its outcome as a `LogicalResult`.
/// NOTE(review): body elided in this view; `useForeach` presumably selects
/// the ...UsingScfForeach pattern over the ...UsingScfFor one declared above,
/// and the result presumably reflects whether pattern application succeeded —
/// confirm against the implementation.
static LogicalResult
applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
bool useForeach) { … }
namespace {
/// Listener subclass of `transform::TrackingListener`.
/// NOTE(review): body elided in this view — which constructors or members it
/// exposes, and whether it overrides any hooks, cannot be told from here;
/// presumably it exists only so the test below can call into the base class.
class DummyTrackingListener : public transform::TrackingListener { … };
} // namespace
/// File-local test helper that takes the root operation and reports its
/// outcome as a `LogicalResult`.
/// NOTE(review): body elided in this view; going by the name it presumably
/// exercises how a tracking listener resolves replacement ops under `rootOp`
/// — confirm against the implementation.
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { … }
/// Out-of-line definition of the pass entry point for TestTensorTransforms.
/// NOTE(review): body elided in this view; it presumably dispatches to the
/// static `apply*` / `test*` helpers defined above based on pass options and
/// signals pass failure when a LogicalResult-returning helper fails —
/// confirm against the implementation.
void TestTensorTransforms::runOnOperation() { … }
namespace mlir {
namespace test {
/// Externally visible entry point declared in `mlir::test`; takes no
/// arguments and returns nothing.
/// NOTE(review): body elided in this view; presumably registers the
/// TestTensorTransforms pass with the global pass registry — confirm.
void registerTestTensorTransforms() { … }
} // namespace test
} // namespace mlir