//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` // implementations for FuncOp, CallOp and ReturnOp. // // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. // This function analyzes the given module and determines the order of analysis // and bufferization: Functions that are called are processed before their // respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. // // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs // for each tensor return value (if any). // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is // read/written. // // Module Bufferization implements the following calling convention. // // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always // be written to in-place. // * If a tensor operand of a CallOp is read after the CallOp, the operand of // the CallOp must bufferize out-of-place. // // Example: The tensor.insert op bufferizes in-place because it is allowed to // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize // out-of-place because `%t0` is modified by the callee but read by the // tensor.extract op. The analysis of CallOps decides whether an OpOperand must // bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. // ``` // func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> { // %f = ... : f32 // %0 = tensor.insert %f into %t1[...] 
: tensor<?xf32> // return %0 : tensor<?xf32> // } // // func @caller() -> () { // %t0 = ... : tensor<?xf32> // %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>) // %2 = tensor.extract %t0[...] : tensor<?xf32> // } // ``` // // Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot // analyze the function body. In such a case, the CallOp analysis conservatively // assumes that each tensor OpOperand is both read and written. // // TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked // as "not reading" and/or "not writing". #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" usingnamespacemlir; usingnamespacemlir::bufferization; usingnamespacemlir::bufferization::func_ext; /// A mapping of FuncOps to their callers. FuncCallerMap; /// Get or create FuncAnalysisState. static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { … } /// Return the unique ReturnOp that terminates `funcOp`. /// Return nullptr if there is no such unique ReturnOp. static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { … } namespace { /// Annotate IR with the results of the analysis. For testing purposes only. static void annotateEquivalentReturnBbArg(OpOperand &returnVal, BlockArgument bbArg) { … } /// Store function BlockArguments that are equivalent to/aliasing a returned /// value in FuncAnalysisState. 
static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { … } static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, bool isWritten) { … } /// Determine which FuncOp bbArgs are read and which are written. When run on a /// function with unknown ops, we conservatively assume that such ops bufferize /// to a read + write. static LogicalResult funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { … } } // namespace /// Remove bufferization attributes on FuncOp arguments. static void removeBufferizationAttributes(BlockArgument bbArg) { … } /// Return the func::FuncOp called by `callOp`. static func::FuncOp getCalledFunction(func::CallOp callOp) { … } /// Gather equivalence info of CallOps. /// Note: This only adds new equivalence info if the called function was already /// analyzed. // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { … } /// Return "true" if the given function signature has tensor semantics. static bool hasTensorSignature(func::FuncOp funcOp) { … } /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by /// callee-caller order (i.e. each callee precedes its callers). /// Store the map of FuncOp to all its callers in `callerMap`. /// Return `failure()` if a cycle of calls is detected or if we are unable to /// retrieve the called FuncOp from any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, FuncCallerMap &callerMap) { … } /// Fold return values that are memref casts and update function return types. /// /// During FuncOp bufferization, the exact type of the returned memrefs (if any) /// is not known yet. 
Therefore, the bufferization uses memref types with the /// most generic layout map as function return types. After bufferizing the /// entire function body, a more concise memref type can potentially be used for /// the return type of the function. static void foldMemRefCasts(func::FuncOp funcOp) { … } LogicalResult mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { … } void mlir::bufferization::removeBufferizationAttributesInModule( ModuleOp moduleOp) { … } LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics) { … } LogicalResult mlir::bufferization::runOneShotModuleBufferize( ModuleOp moduleOp, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics) { … }