// llvm/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp

//====----- OutlineShapeComputation.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <unordered_set>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
} // namespace mlir

// Debug category for this file's LLVM_DEBUG output; select it at run time with
// `-debug-only=shape-outline`. The macro must expand to a string literal.
#define DEBUG_TYPE "shape-outline"

usingnamespacemlir;

namespace {

// Returns the inputs of `cluster`, each listed once, in order of first use.
//
// A Value is an input of the cluster if it is an operand of an operation in the
// cluster and its defining operation is not in the cluster. A value with no
// defining operation (a block argument) therefore always counts as an input.
SmallVector<Value, 4>
getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
  // Membership set so "is the producer inside the cluster?" is a hash lookup.
  DenseSet<Operation *> members;
  members.insert(cluster.begin(), cluster.end());

  SmallVector<Value, 4> inputs;
  DenseSet<Value> seen;
  for (Operation *op : cluster) {
    for (Value operand : op->getOperands()) {
      // getDefiningOp() is null for block arguments; null is never a member.
      if (members.contains(operand.getDefiningOp()))
        continue;
      // Record each input only once, keeping first-use order so that the
      // result does not depend on hash-set iteration order.
      if (seen.insert(operand).second)
        inputs.push_back(operand);
    }
  }
  return inputs;
}

// Create a shape.func representing the shape computation for `shape`.
//
// Takes the builder `b`, the operations of the cluster, the `shape` value the
// cluster computes, the name `fnName` for the new function and the location
// `loc`. Returns the created shape::FuncOp paired with a list of values
// (presumably the inputs of the cluster -- TODO confirm against the callers).
//
// NOTE(review): the body is empty in this view. As written, this non-void
// function returns without a value, which is undefined behaviour.
std::pair<shape::FuncOp, SmallVector<Value>>
createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
                      Value shape, StringRef fnName, Location loc) {}

// The operations in the cluster might be unsorted, which could be inconvenient
// when creating shape.func op.
//
// Takes the per-shape clusters as unordered sets and returns, for each shape
// value, the cluster as a vector. Presumably the vector follows the order in
// which the operations appear in `funcOp` -- TODO confirm.
//
// NOTE(review): the body is empty in this view. As written, this non-void
// function returns without a value, which is undefined behaviour.
DenseMap<Value, SmallVector<Operation *, 8>>
getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
                   func::FuncOp funcOp) {}

// Presumably builds one shape function per shape.with_shape op in `allWithOps`
// from the matching entry of `clusters`, and records the result in
// `dynShape2ShapeFunc` / `shapeMappingAnalysis` -- TODO confirm; only the
// signature is visible here.
//
// NOTE(review): the body is empty in this view, so the function currently does
// nothing and leaves all of its arguments untouched.
void constructShapeFunc(
    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
    DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
    SymbolTable &symbolTable,
    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {}

// The pass class, derived from the generated OutlineShapeComputationBase.
//
// NOTE(review): the struct body is empty in this view, yet this file defines
// runOnOperation, constructClustersForEachShape, getClusterFromValue and
// calOnlyUsedByWithShapesRecursively out of line as members. Those members
// (and the `onlyUsedByWithShapes` state mentioned in a comment further down)
// must be declared here for the file to compile.
struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {};

// Rewrite pattern anchored on tensor.dim operations.
//
// NOTE(review): the class body is empty in this view, so it declares no
// constructor and no matchAndRewrite; what the pattern rewrites tensor.dim
// into cannot be determined from here.
class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {};

// Entry point of the pass.
//
// NOTE(review): the body is empty in this view, so running the pass is
// currently a no-op.
void OutlineShapeComputationPass::runOnOperation() {}

// Returns a map from a value to a list of operations. Presumably the key is the
// shape operand of each op in `allWithOps` and the mapped list is the cluster
// of operations in `funcOp` that computes it -- TODO confirm.
//
// NOTE(review): the body is empty in this view. As written, this non-void
// function returns without a value, which is undefined behaviour.
DenseMap<Value, SmallVector<Operation *, 8>>
OutlineShapeComputationPass::constructClustersForEachShape(
    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {}

// The output of a cluster is the `shape`, and the inputs are the outputs of
// operations that are not in `onlyUsedByWithShapes`.
//
// NOTE(review): the body is empty in this view, so `clusters` is currently
// left unmodified.
void OutlineShapeComputationPass::getClusterFromValue(
    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {}

// Returns whether `op` is a shape.with_shape, or all the users of `op`
// eventually point to the shape operand of shape.with_shape ops.
//
// `prevOutput` is presumably the result of the previously visited operation
// through which `op` was reached -- TODO confirm.
//
// NOTE(review): the body is empty in this view. As written, this non-void
// function returns without a value, which is undefined behaviour.
bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
    Operation *op, Value prevOutput) {}

} // namespace

// Factory for the outline-shape-computation pass: returns a newly allocated
// OutlineShapeComputationPass (the pass class defined in this file), owned by
// the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createOutlineShapeComputationPass() {
  return std::make_unique<OutlineShapeComputationPass>();
}