llvm/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAINFERSHAPES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::tosa;

namespace {

// Check whether a use by `user` is replaceable, i.e. whether `user` may
// directly consume a value whose type has been refined by inference. An op is
// considered refinable if it belongs to the TOSA dialect or implements a
// type-inference related interface (InferTypeOpInterface or
// InferShapedTypeOpInterface).
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
bool canBeRefined(Operation *user) {
  // Ops from unregistered dialects have no Dialect object; nothing is known
  // about them, so treat them conservatively as non-refinable.
  Dialect *dialect = user->getDialect();
  if (!dialect)
    return false;
  return dialect->getTypeID() == TypeID::get<TosaDialect>() ||
         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// rolled back.
//
// This class encapsulates the state information needed to perform the roll back
// process or to commit to the final changes.
class TypeModificationState {};

// Forward declaration: the region walker is defined after the if/while helpers
// below, which presumably recurse into their nested regions through it (their
// bodies are empty in this view — confirm).
void propagateShapesInRegion(Region &region, TypeModificationState &state);

// Propagate shapes into the region(s) of a TOSA "if" operation.
//
// `op` is taken as a generic Operation; presumably the function does nothing
// for ops that are not the TOSA if op (name-based assumption — confirm). Type
// changes are expected to be made through `state` so they can later be
// committed or rolled back together.
//
// NOTE(review): the body is empty as shown, so this is currently a no-op and
// no types inside if-regions are refined; confirm the implementation was not
// lost.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {}

// Propagate shapes through a tosa.while_loop operation.
//
// Intended behavior, as described at the TypeModificationState declaration:
// types inside the loop body are updated speculatively and the process is
// repeated until a fixed point is reached, which determines the loop's output
// type; the speculative changes are then rolled back through `state`.
//
// NOTE(review): the body is empty as shown, so this is currently a no-op and
// no while-loop types are refined; confirm the implementation was not lost.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {}

// Propagate inferred shapes forward across the operations of `region`,
// presumably recording every type change in `state` and dispatching to the
// if/while helpers above for ops with nested regions (name-based assumption —
// confirm).
//
// NOTE(review): the body is empty as shown, so no propagation happens; confirm
// the implementation was not lost.
void propagateShapesInRegion(Region &region, TypeModificationState &state) {}

/// Pass that performs shape propagation across TOSA operations, including
/// propagation into the regions of if/while operations.
///
/// NOTE(review): no runOnOperation() override is visible in this struct. It is
/// the entry point MLIR invokes for a pass, so confirm the body was not lost —
/// without it the struct presumably cannot be instantiated.
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {};
} // namespace

/// Create an instance of the TOSA shape-inference pass.
///
/// NOTE(review): the body is empty as shown. Flowing off the end of a function
/// whose return type is not void is undefined behavior in C++. Presumably this
/// should be `return std::make_unique<TosaInferShapes>();` — confirm, noting
/// that TosaInferShapes needs a runOnOperation() override for that to compile.
std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {}