//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
// Expand the TableGen-generated definitions for this pass; the
// tosa::impl::TosaInferShapesBase used by the pass struct below is expected
// to come from this include.
#define GEN_PASS_DEF_TOSAINFERSHAPES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

// NOTE(review): the two directives below appear whitespace-mangled; they are
// presumably `using namespace mlir;` and `using namespace mlir::tosa;`.
// As written they are not valid declarations -- restore the spaces.
usingnamespacemlir;
usingnamespacemlir::tosa;

namespace {

// Returns whether `user` is a use that can accept a refined (inferred) type.
// A use counts as replaceable when the user is a TosaOp, or an op that
// implements a type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a cast
// back to the original type after inference, so that such users keep seeing
// the type they were built with.
bool canBeRefined(Operation *user) { … }

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is repeated until a fixed point is reached, after which the
// speculative types are rolled back.
//
// This class encapsulates the state needed either to perform that rollback or
// to commit the final changes.
class TypeModificationState { … };

// Forward declaration: the if/while handlers below recurse into nested regions
// through this function.
// NOTE(review): `Region ®ion` looks like a mis-encoded `Region &region`
// (the `&reg` sequence rendered as the registered-trademark sign); the sibling
// parameters use the `Type &name` form. Restore the ampersand.
void propagateShapesInRegion(Region ®ion, TypeModificationState &state);

// Presumably propagates refined shapes into the regions of a TOSA "if"
// operation, recording type changes in `state` -- TODO confirm against the body.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { … }

// Presumably handles a TOSA while loop. Per the TypeModificationState comment,
// the body region's types are speculatively refined until a fixed point is
// reached and then rolled back -- TODO confirm against the body.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { … }

// Propagates shapes through the operations of `region`, recording type
// modifications in `state`.
// NOTE(review): same mis-encoded `Region ®ion` parameter as the forward
// declaration above; should read `Region &region`.
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { … }

/// Pass that performs shape propagation across TOSA operations, including
/// propagation into the regions of if/while operations.
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesBase<TosaInferShapes> { … };

} // namespace

/// Returns a new instance of the TOSA shape-inference pass.
std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() { … }