llvm/mlir/include/mlir/IR/PDLPatternMatch.h.inc

//===- PDLPatternMatch.h - PDLPatternMatcher classes ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_PDLPATTERNMATCH_H
#define MLIR_IR_PDLPATTERNMATCH_H

#include "mlir/Config/mlir-config.h"

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

#include <functional>

namespace mlir {
//===----------------------------------------------------------------------===//
// PDL Patterns
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// PDLValue

/// Storage type of byte-code interpreter values. Values of this type are the
/// opaque arguments handed to native constraint functions and native rewrite
/// functions.
class PDLValue {};

/// Stream-insertion operator for a PDLValue, writing it to `os`.
inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {}

/// Stream-insertion operator for a PDLValue::Kind, writing it to `os`.
inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {}

//===----------------------------------------------------------------------===//
// PDLResultList

/// This class represents a list of PDL results, returned by a native rewrite
/// method. It provides the mechanism with which to pass PDLValues back to the
/// PDL bytecode.
class PDLResultList {};

//===----------------------------------------------------------------------===//
// PDLPatternConfig

/// An individual configuration for a pattern, which can be accessed by native
/// functions via the PDLPatternConfigSet. This allows for injecting additional
/// configuration into PDL patterns that is specific to certain compilation
/// flows.
class PDLPatternConfig {};

/// This class provides a base class for users implementing a type of pattern
/// configuration. It derives from PDLPatternConfig, so every configuration
/// built on it can be stored in a PDLPatternConfigSet.
template <typename T>
class PDLPatternConfigBase : public PDLPatternConfig {};

/// This class contains a set of configurations for a specific pattern.
/// Configurations are uniqued by TypeID, meaning that only one configuration of
/// each type is allowed.
class PDLPatternConfigSet {};

//===----------------------------------------------------------------------===//
// PDLPatternModule

/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
PDLConstraintFunction;

/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful. Returns failure if an
/// invariant of the rewrite was broken (certain rewriters may recover from
/// partial pattern application).
PDLRewriteFunction;

namespace detail {
namespace pdl_function_builder {
/// A utility variable that always resolves to false. This is useful for static
/// asserts that are always false, but only should fire in certain templated
/// constructs. For example, if a templated function should never be called, the
/// function could be defined as:
///
/// template <typename T>
/// void foo() {
///  static_assert(always_false<T>, "This function should never be called");
/// }
///
always_false;

//===----------------------------------------------------------------------===//
// PDL Function Builder: Type Processing
//===----------------------------------------------------------------------===//

/// This struct provides a convenient way to determine how to process a given
/// type as either a PDL parameter, or a result value. This allows for
/// supporting complex types in constraint and rewrite functions, without
/// requiring the user to hand-write the necessary glue code themselves.
/// Specializations of this class should implement the following methods to
/// enable support as a PDL argument or result type:
///
///   static LogicalResult verifyAsArg(
///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///     size_t argIdx);
///
///     * This method verifies that the given PDLValue is valid for use as a
///       value of `T`.
///
///   static T processAsArg(PDLValue pdlValue);
///
///     * This method processes the given PDLValue as a value of `T`.
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const T &value);
///
///     * This method processes the given value of `T` as the result of a
///       function invocation. The method should package the value into an
///       appropriate form and append it to the given result list.
///
/// If the type `T` is based on a higher order value, consider using
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
/// the implementation.
///
/// The primary template is only declared, never defined, so a type without a
/// specialization is incomplete here and cannot be used as a PDL argument or
/// result. The `Enable` parameter exists so that partial specializations can be
/// selected with `std::enable_if_t`.
template <typename T, typename Enable = void>
struct ProcessPDLValue;

/// This struct provides a simplified model for processing types that are based
/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
/// allows for building the necessary processing functions on top of the base
/// value instead of a PDLValue. Derived users should implement the following
/// (these take the place of the PDLValue-based hooks described on
/// ProcessPDLValue):
///
///   static LogicalResult verifyAsArg(
///     function_ref<LogicalResult(const Twine &)> errorFn,
///     const BaseT &baseValue, size_t argIdx);
///
///     * This method verifies that the given base value is valid for use as a
///       value of `T`.
///
///   static T processAsArg(BaseT baseValue);
///
///     * This method processes the given base value as a value of `T`.
///
template <typename T, typename BaseT>
struct ProcessPDLValueBasedOn {};

/// This struct provides a simplified model for processing types that have
/// "builtin" PDLValue support:
///   * Attribute, Operation *, Type, TypeRange, Value, ValueRange
template <typename T>
struct ProcessBuiltinPDLValue {};

/// This struct provides a simplified model for processing types that inherit
/// from builtin PDLValue types. For example, derived attributes like
/// IntegerAttr, derived types like IntegerType, derived operations like
/// ModuleOp, Interfaces, etc. It builds on ProcessPDLValueBasedOn, with `BaseT`
/// being the builtin type that `T` derives from.
template <typename T, typename BaseT>
struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {};

//===----------------------------------------------------------------------===//
// Attribute

template <>
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};

/// Handling for derived attribute classes (e.g. IntegerAttr), layered on top of
/// the builtin Attribute handling. The explicit `Attribute` specialization
/// above is more specialized and therefore still wins for `Attribute` itself.
template <typename T>
struct ProcessPDLValue<T,
                       std::enable_if_t<std::is_base_of<Attribute, T>::value>>
    : public ProcessDerivedPDLValue<T, Attribute> {};

/// Handling for various Attribute value types.
template <>
struct ProcessPDLValue<StringRef>
    : public ProcessPDLValueBasedOn<StringRef, StringAttr> {};
template <>
struct ProcessPDLValue<std::string>
    : public ProcessPDLValueBasedOn<std::string, StringAttr> {};

//===----------------------------------------------------------------------===//
// Operation

template <>
struct ProcessPDLValue<Operation *>
    : public ProcessBuiltinPDLValue<Operation *> {};

/// Handling for concrete operation classes (e.g. ModuleOp). These derive from
/// OpState and are processed on top of the builtin `Operation *` handling.
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
    : public ProcessDerivedPDLValue<T, Operation *> {};

//===----------------------------------------------------------------------===//
// Type

template <>
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};

/// Handling for derived type classes (e.g. IntegerType), layered on top of the
/// builtin Type handling.
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
    : public ProcessDerivedPDLValue<T, Type> {};

//===----------------------------------------------------------------------===//
// TypeRange

template <>
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
template <>
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {};
template <>
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {};
/// Specialization for `SmallVector<Type, N>` of any inline size `N`. Like the
/// ValueTypeRange specializations above, it is defined directly rather than
/// through one of the processing base classes.
template <unsigned N>
struct ProcessPDLValue<SmallVector<Type, N>> {};

//===----------------------------------------------------------------------===//
// Value

/// A single Value is handled through the builtin PDLValue support.
template <>
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};

//===----------------------------------------------------------------------===//
// ValueRange

template <>
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {};
template <>
struct ProcessPDLValue<OperandRange> {};
template <>
struct ProcessPDLValue<ResultRange> {};
/// Specialization for `SmallVector<Value, N>` of any inline size `N`. Like the
/// OperandRange/ResultRange specializations above, it is defined directly
/// rather than through one of the processing base classes.
template <unsigned N>
struct ProcessPDLValue<SmallVector<Value, N>> {};

//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
//===----------------------------------------------------------------------===//

/// Validate the given PDLValues match the constraints defined by the argument
/// types of the given function. In the case of failure, a match failure
/// diagnostic is emitted.
/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
/// does not currently preserve Constraint application ordering.
///
/// The `I...` index pack presumably enumerates the argument positions of
/// `PDLFnT` to check against `values` — confirm against the callers.
template <typename PDLFnT, std::size_t... I>
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
                           std::index_sequence<I...>) {}

/// Assert that the given PDLValues match the constraints defined by the
/// arguments of the given function. In the case of failure, a fatal error
/// is emitted. Unlike verifyAsArgs, nothing is returned to the caller.
template <typename PDLFnT, std::size_t... I>
void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
                std::index_sequence<I...>) {}

//===----------------------------------------------------------------------===//
// PDL Function Builder: Results Handling
//===----------------------------------------------------------------------===//

/// Store a single result within the result list. This is the generic fallback;
/// the overloads below are more specialized and take precedence for pairs,
/// tuples, LogicalResult and FailureOr<T>.
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
                                    PDLResultList &results, T &&value) {}

/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
static LogicalResult processResults(PatternRewriter &rewriter,
                                    PDLResultList &results,
                                    std::pair<T1, T2> &&pair) {}

/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
static LogicalResult processResults(PatternRewriter &rewriter,
                                    PDLResultList &results,
                                    std::tuple<Ts...> &&tuple) {}

/// Handle LogicalResult propagation.
inline LogicalResult processResults(PatternRewriter &rewriter,
                                    PDLResultList &results,
                                    LogicalResult &&result) {}
/// Handle a FailureOr<T> result; presumably propagates failure and otherwise
/// stores the contained value — confirm against the implementation.
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
                                    PDLResultList &results,
                                    FailureOr<T> &&result) {}

//===----------------------------------------------------------------------===//
// PDL Constraint Builder
//===----------------------------------------------------------------------===//

/// Process the arguments of a native constraint and invoke it. The return type
/// is whatever `fn` itself returns (`FnTraitsT::result_t`).
template <typename PDLFnT, std::size_t... I,
          typename FnTraitsT = llvm::function_traits<PDLFnT>>
typename FnTraitsT::result_t
processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
                               ArrayRef<PDLValue> values,
                               std::index_sequence<I...>) {}

/// Build a constraint function from the given function `ConstraintFnT`. This
/// allows for enabling the user to define simpler, more direct constraint
/// functions without needing to handle the low-level PDL plumbing.
///
/// The two overloads are selected by whether `ConstraintFnT` is convertible to
/// PDLConstraintFunction; the conditions are complementary, so exactly one is
/// viable for any given type.
///
/// If the constraint function is already in the correct form, we just forward
/// it directly.
template <typename ConstraintFnT>
std::enable_if_t<
    std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
    PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename ConstraintFnT>
std::enable_if_t<
    !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
    PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {}

//===----------------------------------------------------------------------===//
// PDL Rewrite Builder
//===----------------------------------------------------------------------===//

/// Process the arguments of a native rewrite and invoke it.
/// This overload handles the case of no return values (`fn` returns void), so
/// the result list is not used.
template <typename PDLFnT, std::size_t... I,
          typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
                 LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
                            PDLResultList &, ArrayRef<PDLValue> values,
                            std::index_sequence<I...>) {}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
          typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
                 LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
                            PDLResultList &results, ArrayRef<PDLValue> values,
                            std::index_sequence<I...>) {}

/// Build a rewrite function from the given function `RewriteFnT`. This
/// allows for enabling the user to define simpler, more direct rewrite
/// functions without needing to handle the low-level PDL plumbing.
///
/// As with buildConstraintFn, the two overloads are selected by complementary
/// `std::is_convertible` conditions, so exactly one is viable.
///
/// If the rewrite function is already in the correct form, we just forward
/// it directly.
template <typename RewriteFnT>
std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
                 PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename RewriteFnT>
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
                 PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {}

} // namespace pdl_function_builder
} // namespace detail

//===----------------------------------------------------------------------===//
// PDLPatternModule

/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. The PDL module
/// held by this class may contain any number of `pdl.pattern` operations.
class PDLPatternModule {};
} // namespace mlir

#else

namespace mlir {
// Stubs for when PDL in pattern rewrites is not enabled.

/// Placeholder for the PDL interpreter value type when PDL support in pattern
/// rewrites is compiled out.
class PDLValue {
public:
  /// With PDL disabled there is never an underlying value, so every cast
  /// attempt yields a null `T`.
  template <typename T>
  auto dyn_cast() const -> T {
    return nullptr;
  }
};
/// Empty placeholder for the PDL result list; it exists only so that
/// PDLConstraintFunction and PDLRewriteFunction can name it.
class PDLResultList {};
using PDLConstraintFunction = std::function<LogicalResult(
    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

/// Stand-in for the PDL pattern module when PDL support in pattern rewrites is
/// compiled out. Construction, merging, clearing and function registration all
/// accept their arguments and discard them; no PDL state is ever stored.
class PDLPatternModule {
public:
  PDLPatternModule() = default;

  /// Accepts (and immediately drops) a PDL module so call sites still compile.
  PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}

  /// There is no underlying module to take a context from, so reaching this
  /// is a programming error.
  MLIRContext *getContext() {
    llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
  }

  /// Merging is a no-op: neither module holds any PDL state.
  void mergeIn(PDLPatternModule && /*other*/) {}

  /// Clearing is a no-op for the same reason.
  void clear() {}

  /// Constraint registration is accepted and ignored.
  template <typename ConstraintFnT>
  void registerConstraintFunction(StringRef /*name*/,
                                  ConstraintFnT && /*constraintFn*/) {}

  /// Rewrite registration, in both the type-erased and the generic form, is
  /// accepted and ignored.
  void registerRewriteFunction(StringRef /*name*/,
                               PDLRewriteFunction /*rewriteFn*/) {}
  template <typename RewriteFnT>
  void registerRewriteFunction(StringRef /*name*/,
                               RewriteFnT && /*rewriteFn*/) {}

  /// Always the empty map, because registration never stores anything.
  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
    return constraintFunctions;
  }

private:
  /// Never populated; it exists so getConstraintFunctions() has something to
  /// return by reference.
  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
};

} // namespace mlir
#endif

#endif // MLIR_IR_PDLPATTERNMATCH_H