llvm/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

//===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfoVariant.h"
#include "llvm/ADT/SetVector.h"
#include <optional>

#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"

namespace mlir {
class OpBuilder;
namespace func {
class FuncOp;
}

namespace bufferization {

class AnalysisState;
class BufferizableOpInterface;

/// Specifies a fine-grained relationship between buffers to enable more
/// analysis.
enum class BufferRelation {};

/// A possibly aliasing OpOperand. If `isDefinite` is `true`, the OpOperand is
/// guaranteed to alias at runtime.
struct AliasingOpOperand {};

/// A possibly aliasing Value. If `isDefinite` is `true`, the Value is
/// guaranteed to alias at runtime.
struct AliasingValue {};

template <typename T>
class AliasList {};

/// A list of possible aliasing OpOperands. This list models the runtime
/// aliasing relationship for a Value.
using AliasingOpOperandList = AliasList<AliasingOpOperand>;

/// A list of possible aliasing Values. This list models the runtime aliasing
/// relationship for an OpOperand.
using AliasingValueList = AliasList<AliasingValue>;

class OpFilter {};

/// Options for BufferizableOpInterface-based bufferization.
struct BufferizationOptions {};

/// Traversal parameters for `findValueInReverseUseDefChain`.
struct TraversalConfig {};

/// AnalysisState provides a variety of helper functions for dealing with
/// tensor values.
class AnalysisState {};

/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                             const BufferizationOptions &options,
                             bool copy = true);
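
// Example (illustrative sketch, not part of this header; `rewriter`, `loc`,
// `tensorValue`, and `options` are assumed to be in scope in a bufferization
// pattern):
//
//   // Allocate a new tensor initialized with a copy of `tensorValue`.
//   FailureOr<Value> copied = allocateTensorForShapedValue(
//       rewriter, loc, tensorValue, options, /*copy=*/true);
//   if (failed(copied))
//     return failure();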

/// Look up the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
                           const BufferizationOptions &options);
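
// Example (illustrative sketch; a typical `bufferize` implementation is
// assumed, with `rewriter`, `op`, and `options` in scope):
//
//   FailureOr<Value> buffer = getBuffer(rewriter, op->getOperand(0), options);
//   if (failed(buffer))
//     return failure();
//   // `*buffer` has a memref type and can be used to build memref ops.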

/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
///
/// Note: It should be sufficient to call `getBuffer()->getType()` in most
/// cases. However, when a buffer type should be predicted without modifying any
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
                                        const BufferizationOptions &options);

/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
/// without `invocationStack`) can be used from `getBufferType` implementations
/// of the `BufferizableOpInterface`.
///
/// Note: It should be sufficient to call `getBuffer()->getType()` in most
/// cases. However, when a buffer type should be predicted without modifying any
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
                                        const BufferizationOptions &options,
                                        SmallVector<Value> &invocationStack);
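
// Example (illustrative sketch; `tensorValue` and `options` are assumed to be
// in scope): predict the bufferized type without creating any ops.
//
//   FailureOr<BaseMemRefType> memrefType = getBufferType(tensorValue, options);
//   if (failed(memrefType))
//     return failure();
//
// From within a `getBufferType` interface implementation, use the overload
// above and forward the `invocationStack` instead.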

/// Return "true" if the given op has tensor semantics and should be bufferized.
/// If the op is bufferizable, the BufferizableOpInterface is queried.
/// Otherwise, an op has tensor semantics if it has tensor operands, tensor
/// results, and/or tensor block arguments.
bool hasTensorSemantics(Operation *op);
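
// Example (illustrative sketch): skip ops that have nothing to bufferize.
//
//   if (!hasTensorSemantics(op))
//     return;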

/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
                                   ValueRange values);
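
// Example (illustrative sketch): the typical last step of a `bufferize`
// implementation, where `newResults` is assumed to hold the memref-typed
// replacement values computed for `op`.
//
//   replaceOpWithBufferizedValues(rewriter, op, newResults);
//   return success();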

/// Replace an op with a new op. The new op must have the same number of
/// results as the replaced op. The new op may not return any tensor values.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
                                  Args &&...args) {}
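
// Example (illustrative sketch; `memref::LoadOp` is used as an arbitrary
// memref-level replacement op, with `buffer` and `indices` assumed to be in
// scope):
//
//   replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, buffer,
//                                                indices);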

/// Return a MemRefType to which the type of the given value can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
///
/// Unless a layout map was specified, `options.unknownTypeConverterFn`
/// determines what kind of layout map will be used. For best composability
/// (without copies), the fully dynamic layout map is used by default.
///
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
                             MemRefLayoutAttrInterface layout = {},
                             Attribute memorySpace = nullptr);
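
// Example (illustrative sketch; `tensorValue` and `options` are assumed to be
// in scope): derive a memref type with the default (unspecified) layout, which
// is chosen by `options.unknownTypeConverterFn`.
//
//   BaseMemRefType memrefType = getMemRefType(tensorValue, options);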

/// Return a MemRef type with fully dynamic layout. If the given tensor type
/// is unranked, return an unranked MemRef type.
BaseMemRefType
getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                    Attribute memorySpace = nullptr);

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                      Attribute memorySpace = nullptr);
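
// Example (illustrative sketch; `tensorType` is assumed to be a ranked
// `tensor<4x8xf32>`):
//
//   // E.g. memref<4x8xf32, strided<[?, ?], offset: ?>>.
//   BaseMemRefType dynamic = getMemRefTypeWithFullyDynamicLayout(tensorType);
//   // E.g. memref<4x8xf32> (identity layout).
//   BaseMemRefType identity =
//       getMemRefTypeWithStaticIdentityLayout(tensorType);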

/// Return the owner of the given value. For a BlockArgument, that is the op
/// that owns the block. For an OpResult, that is the defining op.
Operation *getOwnerOfValue(Value value);

/// Assuming that the given region is repetitive, find the next enclosing
/// repetitive region.
Region *getNextEnclosingRepetitiveRegion(Region *region,
                                         const BufferizationOptions &options);

/// If `region` is a parallel region, return `region`. Otherwise, find the first
/// enclosing parallel region of `region`. If there is no such region, return
/// "nullptr".
///
/// Note: Whether a region is parallel or sequential is queried from the
/// `BufferizableOpInterface`.
Region *getParallelRegion(Region *region, const BufferizationOptions &options);

namespace detail {
/// This is the default implementation of
/// BufferizableOpInterface::getAliasingOpOperands. Should not be called from
/// other places.
AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
                                                   const AnalysisState &state);

/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
                     SmallVector<Value> &invocationStack);

/// This is the default implementation of
/// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called
/// from other places.
bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
                                          const AnalysisState &state);

/// This is the default implementation of
/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
/// places.
bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
                               unsigned index);

/// This is the default implementation of getAliasingOpOperands in case the
/// defining op does not implement the BufferizableOpInterface.
AliasingOpOperandList unknownGetAliasingOpOperands(Value value);

/// This is the default implementation of getAliasingValues in case the owner
/// op does not implement the BufferizableOpInterface.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);

/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics. Should not be called from
/// other places.
bool defaultHasTensorSemantics(Operation *op);
} // namespace detail

} // namespace bufferization
} // namespace mlir

MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)

//===----------------------------------------------------------------------===//
// Bufferization Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"

#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_