//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a function-level pass performing intra-procedural // shape inference, propagating array shapes through the operations of each // function. // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "toy/ShapeInferenceInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <memory> #define DEBUG_TYPE … usingnamespacemlir; usingnamespacetoy; /// Include the auto-generated definitions for the shape inference interfaces. #include "toy/ShapeInferenceOpInterfaces.cpp.inc" namespace { /// The ShapeInferencePass is a pass that performs intra-procedural /// shape inference. /// /// Algorithm: /// /// 1) Build a worklist containing all the operations that return a /// dynamically shaped tensor: these are the operations that need shape /// inference. /// 2) Iterate on the worklist: /// a) find an operation to process: the next ready operation in the /// worklist has all of its arguments non-generic, /// b) if no operation is found, break out of the loop, /// c) remove the operation from the worklist, /// d) infer the shape of its output from the argument types. /// 3) If the worklist is empty, the algorithm succeeded. 
/// struct ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> { … }; } // namespace /// Create a Shape Inference pass. std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() { … }