llvm/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H

#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
#include "mlir/Dialect/Transform/Utils/RaggedArray.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"

namespace mlir {
namespace transform {

class TransformOpInterface;
class TransformResults;
class TransformRewriter;
class TransformState;

Param;
MappedValue;

namespace detail {
/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
/// to either the list of operations associated with its operand or the root of
/// the payload IR, depending on what is available in the context.
LogicalResult
mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
                                             Operation *op, Region &region);

/// Verification hook for PossibleTopLevelTransformOpTrait.
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);

/// Populates `effects` with side effects implied by
/// PossibleTopLevelTransformOpTrait for the given operation. The operation may
/// have an optional `root` operand, indicating it is not in fact top-level. It
/// is also expected to have a single-block body.
// NOTE(review): `root` is presumably a null Value when the optional operand is
// absent (i.e. the op is top-level) — confirm against the implementation.
void getPotentialTopLevelEffects(
    Operation *operation, Value root, Block &body,
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Verification hook for TransformOpInterface.
LogicalResult verifyTransformOpInterface(Operation *op);

/// Appends the entities associated with the given transform values in `state`
/// to the pre-existing list of mappings. The array of mappings must have as
/// many elements as values (they are co-indexed). If `flatten` is set, multiple
/// entities may be associated with each transform value, and this always
/// succeeds. Otherwise, checks that each value has exactly one associated
/// mapping and returns failure if that is not the case.
LogicalResult appendValueMappings(
    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
    ValueRange values, const transform::TransformState &state,
    bool flatten = true);

/// Populates `mappings` with mapped values associated with the given transform
/// IR values in the given `state`.
void prepareValueMappings(
    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
    ValueRange values, const transform::TransformState &state);

/// Populates `results` with payload associations that match exactly those of
/// the operands to `block`'s terminator.
void forwardTerminatorOperands(Block *block, transform::TransformState &state,
                               transform::TransformResults &results);

/// Make a dummy transform state for testing purposes. This MUST NOT be used
/// outside of test cases.
TransformState makeTransformStateForTesting(Region *region,
                                            Operation *payloadRoot);

/// Returns all operands of the given op that are handles and are consumed by
/// it.
SmallVector<OpOperand *>
getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
} // namespace detail
} // namespace transform
} // namespace mlir

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"

namespace mlir {
namespace transform {

/// Options controlling how the TransformState applies transform operations.
class TransformOptions {};

/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
/// will be executed following the internal logic of the operation. It must
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
/// This function internally keeps track of the transformation state.
LogicalResult applyTransforms(
    Operation *payloadRoot, TransformOpInterface transform,
    const RaggedArray<MappedValue> &extraMapping = {};

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
/// which transformations apply are referred to as payload IR. Transform IR
/// operates on values that can be associated either with a list of payload IR
/// operations (such values are referred to as handles) or with a list of
/// parameters represented as attributes. The state thus contains the mapping
/// between values defined in the transform IR ops and either payload IR ops or
/// parameters. For payload ops, the mapping is many-to-many and the reverse
/// mapping is also stored. The "expensive-checks" option can be passed to the
/// constructor at transformation execution time to verify that transform IR
/// values used as operands by a transform IR operation are not associated with
/// dangling pointers to payload IR operations that are known to have been
/// erased by previous transformations through the same or a different
/// transform IR value.
///
/// A reference to this class is passed as an argument to "apply" methods of the
/// transform op interface. Thus the "apply" method can call either
/// `state.getPayloadOps(getSomeOperand())` to obtain the list of operations
/// or `state.getParams(getSomeOperand())` to obtain the list of parameters
/// associated with its operand. The method is expected to populate the
/// `TransformResults` class instance in order to update the mapping. The
/// `applyTransform` method takes care of propagating the state of
/// `TransformResults` into the instance of this class.
///
/// When applying transform IR operations with regions, the client is expected
/// to create a `RegionScope` RAII object to create a new "stack frame" for
/// values defined inside the region. The mappings from and to these values will
/// be automatically dropped when the object goes out of scope, typically at the
/// end of the `apply` function of the parent operation. If a region contains
/// blocks with arguments, the client can map those arguments to payload IR ops
/// using `mapBlockArguments`.
// NOTE(review): the class body appears elided in this view. Members referenced
// elsewhere in this header (`RegionScope`, `Extension`, `make_region_scope`)
// are not declared here.
class TransformState {};

/// Local mapping between values defined by a specific op implementing the
/// TransformOpInterface and the payload IR ops (or parameters) they correspond
/// to. Populated by the op's "apply" method and then propagated into the
/// TransformState.
class TransformResults {};

/// Creates a RAII object the lifetime of which corresponds to the new mapping
/// for transform IR values defined in the given region. Values defined in
/// surrounding regions remain accessible.
// NOTE(review): this is an out-of-line member definition in a header. It needs
// inline linkage (on the in-class declaration or here) to avoid
// multiple-definition errors; confirm. The body is also empty in this view,
// but the function must return a RegionScope, so confirm it was not truncated.
TransformState::RegionScope TransformState::make_region_scope(Region &region) {}

/// A configuration object for customizing the behavior of a `TrackingListener`.
struct TrackingListenerConfig {};

/// A listener that updates a TransformState based on IR modifications. This
/// listener can be used during a greedy pattern rewrite to keep the transform
/// state up-to-date. It is both a rewriter listener (to observe IR changes) and
/// a TransformState extension (to access and update the state's mappings).
class TrackingListener : public RewriterBase::Listener,
                         public TransformState::Extension {};

/// A specialized listener that keeps track of cases in which no replacement
/// payload could be found. The error state of this listener must be checked
/// before the end of its lifetime.
class ErrorCheckingTrackingListener : public TrackingListener {};

/// A special rewriter to be used in transform op implementations. It provides
/// additional helper functions to update the transform state, etc.
// TODO: Helper functions will be added in a subsequent change.
class TransformRewriter : public RewriterBase {};

/// This trait is supposed to be attached to Transform dialect operations that
/// can be standalone top-level transforms. Such operations typically contain
/// other Transform dialect operations that can be executed following some
/// control flow logic specific to the current operation. The operations with
/// this trait are expected to have at least one single-block region with at
/// least one argument of type implementing TransformHandleTypeInterface. The
/// operations are also expected to be valid without operands, in which case
/// they are considered top-level, and with one or more operands, in which case
/// they are considered nested. Top-level operations have the block argument of
/// the entry block in the Transform IR correspond to the root operation of
/// Payload IR. Nested operations have the block argument of the entry block in
/// the Transform IR correspond to a list of Payload IR operations mapped to the
/// first operand of the Transform IR operation. The operation must implement
/// TransformOpInterface.
template <typename OpTy>
class PossibleTopLevelTransformOpTrait
    : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {};

// Forward declaration; the class is defined further below in this header.
class ApplyToEachResultList;

/// Trait implementing the TransformOpInterface for operations applying a
/// transformation to a single operation handle and producing an arbitrary
/// number of handles and parameter values.
/// The op must implement a method with the following signature:
///   - DiagnosedSilenceableFailure applyToOne(OpTy,
///       ApplyToEachResultList &results, TransformState &state)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
/// that the transformation is applied to (and NOT the class of the transform IR
/// op).
/// The `applyToOne` method takes an empty `results` list that it fills with
/// zero, one or multiple entries depending on the number of results expected
/// by the transform op.
/// The number of results must match the number of results of the transform op.
/// `applyToOne` is allowed to fill the `results` with all null elements to
/// signify that the transformation did not apply to the payload IR operations.
/// Such null elements are filtered out from results before return.
///
/// The transform op having this trait is expected to have a single operand.
// NOTE(review): `apply` and `detail::applyTransformToEach` below take a
// `TransformRewriter &`. Confirm whether `applyToOne` also receives it as a
// leading parameter and update the signature above accordingly.
template <typename OpTy>
class TransformEachOpTrait
    : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {};

/// Side effect resource corresponding to the mapping between Transform IR
/// values and Payload IR operations. An Allocate effect from this resource
/// means creating a new mapping entry; it is always accompanied by a Write
/// effect. A Read effect from this resource means accessing the mapping. A Free
/// effect on this resource indicates the removal of the mapping entry,
/// typically after a transformation that modifies the Payload IR operations
/// associated with one of the Transform IR operation's operands. It is always
/// accompanied by a Read effect. Read-after-Free and double-Free are not
/// allowed (they would be problematic with "regular" memory effects too) as
/// they indicate an attempt to access Payload IR operations that have been
/// modified, potentially erased, by the previous transformations.
// TODO: consider custom effects if these are not enabling generic passes such
// as CSE/DCE to work.
struct TransformMappingResource
    : public SideEffects::Resource::Base<TransformMappingResource> {};

/// Side effect resource corresponding to the Payload IR itself. Only Read and
/// Write effects are expected on this resource, with Write always accompanied
/// by a Read (short of fully replacing the top-level Payload IR operation, one
/// cannot modify the Payload IR without reading it first). This is intended
/// to disallow reordering of Transform IR operations that mutate the Payload IR
/// while still allowing the reordering of those that only read it.
struct PayloadIRResource
    : public SideEffects::Resource::Base<PayloadIRResource> {};

/// Populates `effects` with the memory effects (on TransformMappingResource)
/// describing what the operation does with the given handle values:
///   - consumes = Read + Free,
///   - produces = Allocate + Write,
///   - onlyReads = Read.
void consumesHandle(MutableArrayRef<OpOperand> handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(ResultRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(MutableArrayRef<BlockArgument> handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op consumes the given handle.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);

/// Populates `effects` with the memory effects indicating access to the payload
/// IR resource: `modifiesPayload` for reads and writes, `onlyReadsPayload` for
/// reads only.
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op modifies the payload.
bool doesModifyPayload(transform::TransformOpInterface transform);
/// Checks whether the transform op reads the payload.
bool doesReadPayload(transform::TransformOpInterface transform);

/// Populates `consumedArguments` with positions of `block` arguments that are
/// consumed by the operations in the `block`.
void getConsumedBlockArguments(
    Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);

/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
/// their operands and produce new results (functional-style transforms).
template <typename OpTy>
class FunctionalStyleTransformOpTrait
    : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {};

/// Trait implementing the MemoryEffectOpInterface for operations that use their
/// operands without consuming them and without modifying the Payload IR,
/// potentially producing new handles.
template <typename OpTy>
class NavigationTransformOpTrait
    : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {};

namespace detail {
/// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
/// Kept out of the template so it can live in a single translation unit.
void getParamProducerTransformOpTraitEffects(
    Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
/// Non-template implementation of ParamProducerTransformOpTrait::verify().
LogicalResult verifyParamProducerTransformOpTrait(Operation *op);
} // namespace detail

/// Trait implementing the MemoryEffectOpInterface for operations that produce
/// transform dialect parameters. It marks the op results as produced by the
/// op, all operands as only read by the op and, if at least one of the
/// operands is a handle to payload ops, the entire payload as potentially
/// read. The op must only produce parameter-typed results.
// NOTE(review): earlier wording said results "of TransformHandleTypeInterface"
// are marked produced, which contradicts the parameter-only requirement.
// Confirm which results `detail::getParamProducerTransformOpTraitEffects`
// actually marks.
template <typename OpTy>
class ParamProducerTransformOpTrait
    : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {};

/// `TrackingListener` failures are reported only for ops that have this trait.
/// The purpose of this trait is to give users more time to update their custom
/// transform ops to use the provided `TransformRewriter` for all IR
/// modifications. This trait will eventually be removed, and failures will then
/// be reported for all transform ops.
template <typename OpTy>
class ReportTrackingListenerFailuresOpTrait
    : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};

/// A single result of applying a transform op with `ApplyEachOpTrait` to a
/// single payload operation.
ApplyToEachResult;

/// A list of results of applying a transform op with `TransformEachOpTrait` to
/// a single payload operation, co-indexed with the results of the transform op.
class ApplyToEachResultList {};

namespace detail {

/// Check that the contents of `partialResult` match the number, kind (payload
/// op or parameter) and nullity (either all or none) requirements of
/// `transformOp`. Report errors and return failure otherwise.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
                              const ApplyToEachResultList &partialResult);

/// "Transpose" the results produced by individual applications, arranging them
/// per result value of the transform op, and populate `transformResults` with
/// that. The number, kind and nullity of per-application results are assumed to
/// have been verified (see `checkApplyToOne`).
void setApplyToOneResults(Operation *transformOp,
                          TransformResults &transformResults,
                          ArrayRef<ApplyToEachResultList> results);

/// Applies a one-to-one or a one-to-many transform to each of the given
/// targets. Puts the results of transforms, if any, in `results` in the same
/// order, one `ApplyToEachResultList` per application. Fails if any of the
/// applications fails. Individual transforms must be callable with a signature
/// of the form:
///   - DiagnosedSilenceableFailure(OpTy,
///       ApplyToEachResultList &results, state)
/// where OpTy is either
///   - Operation *, in which case the transform is always applied;
///   - a concrete Op class, in which case a check is performed whether
///   `targets` contains operations of the same class and a silenceable failure
///   is reported if it does not.
// NOTE(review): this function also takes `rewriter`, so the per-target
// callable presumably receives it too. Confirm the exact callable signature.
// The template body is empty in this view, but a value must be returned, so
// confirm it was not truncated.
template <typename TransformOpTy, typename Range>
DiagnosedSilenceableFailure applyTransformToEach(
    TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
    SmallVectorImpl<ApplyToEachResultList> &results, TransformState &state) {}

/// Reports an error and returns failure if `targets` contains an ancestor
/// operation before its descendant (or a copy of itself). Implementation detail
/// for expensive checks during `TransformEachOpTrait::apply`.
LogicalResult checkNestedConsumption(Location loc,
                                     ArrayRef<Operation *> targets);

} // namespace detail
} // namespace transform
} // namespace mlir

/// TransformOpInterface `apply` implementation for ops with
/// TransformEachOpTrait: applies the op's `applyToOne` in turn to each payload
/// operation associated with the op's single operand and records the
/// per-application results in `transformResults`.
// NOTE(review): the body is empty in this view, but this non-void template
// must return a DiagnosedSilenceableFailure. Confirm it was not truncated.
template <typename OpTy>
mlir::DiagnosedSilenceableFailure
mlir::transform::TransformEachOpTrait<OpTy>::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {}

/// Verification hook for TransformEachOpTrait.
// NOTE(review): presumably this enforces the single-operand requirement stated
// in the trait documentation; confirm. The body is empty in this view, but
// this non-void template must return a LogicalResult.
template <typename OpTy>
llvm::LogicalResult
mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {}

#endif // MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H