//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" usingnamespacemlir; usingnamespacemlir::bufferization; usingnamespacemlir::scf; namespace mlir { namespace scf { namespace { /// Helper function for loop bufferization. Cast the given buffer to the given /// memref type. static Value castBuffer(OpBuilder &b, Value buffer, Type type) { … } /// Helper function for loop bufferization. Return "true" if the given value /// is guaranteed to not alias with an external tensor apart from values in /// `exceptions`. A value is external if it is defined outside of the given /// region or if it is an entry block argument of the region. static bool doesNotAliasExternalValue(Value value, Region *region, ValueRange exceptions, const OneShotAnalysisState &state) { … } /// Bufferization of scf.condition. struct ConditionOpInterface : public BufferizableOpInterface::ExternalModel<ConditionOpInterface, scf::ConditionOp> { … }; /// Return the unique scf.yield op. If there are multiple or no scf.yield ops, /// return an empty op. static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { … } /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not /// fully implemented at the moment. struct ExecuteRegionOpInterface : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< ExecuteRegionOpInterface, scf::ExecuteRegionOp> { … }; /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. struct IfOpInterface : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { … }; /// Bufferization of scf.index_switch. Replace with a new scf.index_switch that /// yields memrefs. struct IndexSwitchOpInterface : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface, scf::IndexSwitchOp> { … }; /// Helper function for loop bufferization. Return the indices of all values /// that have a tensor type. static DenseSet<int64_t> getTensorIndices(ValueRange values) { … } /// Helper function for loop bufferization. Return the indices of all /// bbArg/yielded value pairs who's buffer relation is "Equivalent". DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, ValueRange yieldedValues, const AnalysisState &state) { … } /// Helper function for loop bufferization. Return the bufferized values of the /// given OpOperands. If an operand is not a tensor, return the original value. static FailureOr<SmallVector<Value>> getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, const BufferizationOptions &options) { … } /// Helper function for loop bufferization. Given a list of bbArgs of the new /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into /// ToTensorOps, so that the block body can be moved over to the new op. static SmallVector<Value> getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, const DenseSet<int64_t> &tensorIndices) { … } /// Compute the bufferized type of a loop iter_arg. This type must be equal to /// the bufferized type of the corresponding init_arg and the bufferized type /// of the corresponding yielded value. /// /// This function uses bufferization::getBufferType to compute the bufferized /// type of the init_arg and of the yielded value. (The computation of the /// bufferized yielded value type usually requires computing the bufferized type /// of the iter_arg again; the implementation of getBufferType traces back the /// use-def chain of the given value and computes a buffer type along the way.) /// If both buffer types are equal, no casts are needed the computed buffer type /// can be used directly. Otherwise, the buffer types can only differ in their /// layout map and a cast must be inserted. static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType( Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue, const BufferizationOptions &options, SmallVector<Value> &invocationStack) { … } /// Return `true` if the given loop may have 0 iterations. bool mayHaveZeroIterations(scf::ForOp forOp) { … } /// Bufferization of scf.for. Replace with a new scf.for that operates on /// memrefs. struct ForOpInterface : public BufferizableOpInterface::ExternalModel<ForOpInterface, scf::ForOp> { … }; /// Bufferization of scf.while. Replace with a new scf.while that operates on /// memrefs. struct WhileOpInterface : public BufferizableOpInterface::ExternalModel<WhileOpInterface, scf::WhileOp> { … }; /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so /// this is for analysis only. struct YieldOpInterface : public BufferizableOpInterface::ExternalModel<YieldOpInterface, scf::YieldOp> { … }; /// Return `true` if the given loop may have 0 iterations. bool mayHaveZeroIterations(scf::ForallOp forallOp) { … } /// Bufferization of ForallOp. This also bufferizes the terminator of the /// region. There are op interfaces for the terminators (InParallelOp /// and ParallelInsertSliceOp), but these are only used during analysis. Not /// for bufferization. struct ForallOpInterface : public BufferizableOpInterface::ExternalModel<ForallOpInterface, ForallOp> { … }; /// Nothing to do for InParallelOp. struct InParallelOpInterface : public BufferizableOpInterface::ExternalModel<InParallelOpInterface, InParallelOp> { … }; } // namespace } // namespace scf } // namespace mlir void mlir::scf::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { … }