#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace mlir {
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
}
#define DEBUG_TYPE …
usingnamespacemlir;
namespace {
/// Takes a cluster (a list of operations, passed by const reference) and
/// returns a small vector of Values.
///
/// NOTE(review): the body is not visible in this view. Judging by the name,
/// the result is presumably the set of values consumed by operations in
/// `cluster` but defined outside of it (i.e. the cluster's external inputs) --
/// TODO confirm against the implementation, including whether the result is
/// de-duplicated and what ordering it has.
SmallVector<Value, 4>
getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) { … }
/// Produces a shape::FuncOp together with a vector of Values from a cluster.
///
/// Parameters:
///   b       - builder passed by reference.
///   cluster - the operations making up the cluster (read-only).
///   shape   - the Value associated with the cluster.
///   fnName  - a name string.
///   loc     - a source location.
///
/// Returns a pair of (shape::FuncOp, SmallVector<Value>).
///
/// NOTE(review): the body is not visible in this view. Presumably this builds
/// a new shape function named `fnName` at `loc` that computes `shape` from the
/// operations in `cluster`, and returns that function alongside the values
/// that must be supplied as its arguments -- TODO confirm, including whether
/// the cluster's operations are cloned or moved.
std::pair<shape::FuncOp, SmallVector<Value>>
createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
Value shape, StringRef fnName, Location loc) { … }
/// Converts a map from Value to an unordered set of operations into a map from
/// Value to a vector of operations.
///
/// Parameters:
///   clusters - per-Value operation sets (read-only).
///   funcOp   - a func::FuncOp handle.
///
/// NOTE(review): the body is not visible in this view. Presumably each
/// returned vector holds the same operations as the corresponding input set,
/// arranged in the order they appear when walking `funcOp` -- TODO confirm.
DenseMap<Value, SmallVector<Operation *, 8>>
getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
func::FuncOp funcOp) { … }
/// Free helper that returns nothing; any results are communicated through its
/// reference parameters or through IR changes.
///
/// Parameters:
///   allWithOps           - list of shape::WithOp operations (read-only).
///   context              - MLIRContext pointer.
///   clusters             - per-Value ordered operation lists; taken by
///                          non-const reference.
///   symbolTable          - symbol table; taken by non-const reference.
///   dynShape2ShapeFunc   - map from Value to shape::ShapeMappingValue; taken
///                          by non-const reference.
///   funcOp               - a func::FuncOp handle.
///   shapeMappingAnalysis - analysis object; taken by non-const reference.
///
/// NOTE(review): the body is not visible in this view. Presumably, for each
/// op in `allWithOps`, this creates (or reuses) an outlined shape function for
/// the op's shape from the matching entry in `clusters`, registers it in
/// `symbolTable`, and records the resulting mapping in `dynShape2ShapeFunc`
/// and/or `shapeMappingAnalysis` -- TODO confirm which of the reference
/// parameters are actually mutated.
void constructShapeFunc(
const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
SymbolTable &symbolTable,
DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) { … }
/// Pass class deriving from the generated base
/// impl::OutlineShapeComputationBase, which is parameterized on this struct.
///
/// The members runOnOperation, constructClustersForEachShape,
/// getClusterFromValue and calOnlyUsedByWithShapesRecursively are defined
/// out-of-line later in this file, so they are declared in this struct.
///
/// NOTE(review): the struct body is not visible in this view; any data
/// members or further declarations it holds are not documented here.
struct OutlineShapeComputationPass
: public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> { … };
/// Rewrite pattern over tensor::DimOp (derives from
/// OpRewritePattern<tensor::DimOp>).
///
/// NOTE(review): the class body is not visible in this view, so what matched
/// tensor.dim operations are rewritten into, and under which conditions the
/// pattern applies, must be checked against the implementation.
class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> { … };
/// Out-of-line definition of the pass's runOnOperation member; takes no
/// arguments and returns nothing.
///
/// NOTE(review): the body is not visible in this view. Presumably this is the
/// pass entry point and drives the other helpers in this file (cluster
/// construction, shape-function creation, tensor.dim rewriting) -- TODO
/// confirm the exact sequence and failure handling.
void OutlineShapeComputationPass::runOnOperation() { … }
/// Returns a map from Value to a vector of operations.
///
/// Parameters:
///   allWithOps - list of shape::WithOp operations (read-only).
///   funcOp     - a func::FuncOp handle.
///
/// NOTE(review): the body is not visible in this view. Presumably the keys are
/// the shape values referenced by the ops in `allWithOps` and each mapped
/// vector is the ordered list of operations needed to compute that shape
/// within `funcOp` -- TODO confirm.
DenseMap<Value, SmallVector<Operation *, 8>>
OutlineShapeComputationPass::constructClustersForEachShape(
const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) { … }
/// Returns nothing; `clusters` is taken by non-const reference, so any result
/// is delivered through it.
///
/// Parameters:
///   shape    - the Value of interest.
///   clusters - map from Value to a set of operations.
///
/// NOTE(review): the body is not visible in this view. Presumably this gathers
/// the operations that `shape` transitively depends on and stores them in
/// `clusters` under the key `shape` -- TODO confirm, including where the
/// traversal stops.
void OutlineShapeComputationPass::getClusterFromValue(
Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) { … }
/// Returns a boolean computed from an operation and a Value.
///
/// Parameters:
///   op         - the operation to inspect.
///   prevOutput - a Value supplied alongside `op`.
///
/// NOTE(review): the body is not visible in this view. Judging by the name,
/// the result presumably indicates whether `op` is (transitively) used only by
/// shape.with_shape operations, with `prevOutput` identifying the value
/// through which `op` was reached -- TODO confirm, and check whether results
/// are cached in pass state.
bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
Operation *op, Value prevOutput) { … }
}
/// Public factory in the mlir namespace; returns an owning pointer to an
/// OperationPass over ModuleOp.
///
/// NOTE(review): the body is not visible in this view. Presumably it
/// constructs an OutlineShapeComputationPass instance -- TODO confirm.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createOutlineShapeComputationPass() { … }