//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_IR_PDLPATTERNMATCH_H #define MLIR_IR_PDLPATTERNMATCH_H #include "mlir/Config/mlir-config.h" #if MLIR_ENABLE_PDL_IN_PATTERNMATCH #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" namespace mlir { //===----------------------------------------------------------------------===// // PDL Patterns //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // PDLValue /// Storage type of byte-code interpreter values. These are passed to constraint /// functions as arguments. class PDLValue { … }; inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { … } inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { … } //===----------------------------------------------------------------------===// // PDLResultList /// The class represents a list of PDL results, returned by a native rewrite /// method. It provides the mechanism with which to pass PDLValues back to the /// PDL bytecode. class PDLResultList { … }; //===----------------------------------------------------------------------===// // PDLPatternConfig /// An individual configuration for a pattern, which can be accessed by native /// functions via the PDLPatternConfigSet. This allows for injecting additional /// configuration into PDL patterns that is specific to certain compilation /// flows. class PDLPatternConfig { … }; /// This class provides a base class for users implementing a type of pattern /// configuration. 
template <typename T> class PDLPatternConfigBase : public PDLPatternConfig { … }; /// This class contains a set of configurations for a specific pattern. /// Configurations are uniqued by TypeID, meaning that only one configuration of /// each type is allowed. class PDLPatternConfigSet { … }; //===----------------------------------------------------------------------===// // PDLPatternModule /// A generic PDL pattern constraint function. This function applies a /// constraint to a given set of opaque PDLValue entities. Returns success if /// the constraint successfully held, failure otherwise. PDLConstraintFunction; /// A native PDL rewrite function. This function performs a rewrite on the /// given set of values. Any results from this rewrite that should be passed /// back to PDL should be added to the provided result list. This method is only /// invoked when the corresponding match was successful. Returns failure if an /// invariant of the rewrite was broken (certain rewriters may recover from /// partial pattern application). PDLRewriteFunction; namespace detail { namespace pdl_function_builder { /// A utility variable that always resolves to false. This is useful for static /// asserts that are always false, but only should fire in certain templated /// constructs. For example, if a templated function should never be called, the /// function could be defined as: /// /// template <typename T> /// void foo() { /// static_assert(always_false<T>, "This function should never be called"); /// } /// always_false; //===----------------------------------------------------------------------===// // PDL Function Builder: Type Processing //===----------------------------------------------------------------------===// /// This struct provides a convenient way to determine how to process a given /// type as either a PDL parameter, or a result value. 
This allows for /// supporting complex types in constraint and rewrite functions, without /// requiring the user to hand-write the necessary glue code themselves. /// Specializations of this class should implement the following methods to /// enable support as a PDL argument or result type: /// /// static LogicalResult verifyAsArg( /// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, /// size_t argIdx); /// /// * This method verifies that the given PDLValue is valid for use as a /// value of `T`. /// /// static T processAsArg(PDLValue pdlValue); /// /// * This method processes the given PDLValue as a value of `T`. /// /// static void processAsResult(PatternRewriter &, PDLResultList &results, /// const T &value); /// /// * This method processes the given value of `T` as the result of a /// function invocation. The method should package the value into an /// appropriate form and append it to the given result list. /// /// If the type `T` is based on a higher order value, consider using /// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify /// the implementation. /// template <typename T, typename Enable = void> struct ProcessPDLValue; /// This struct provides a simplified model for processing types that are based /// on another type, e.g. APInt is based on the handling for IntegerAttr. This /// allows for building the necessary processing functions on top of the base /// value instead of a PDLValue. Derived users should implement the following /// (which subsume the ProcessPDLValue variants): /// /// static LogicalResult verifyAsArg( /// function_ref<LogicalResult(const Twine &)> errorFn, /// const BaseT &baseValue, size_t argIdx); /// /// * This method verifies that the given PDLValue is valid for use as a /// value of `T`. /// /// static T processAsArg(BaseT baseValue); /// /// * This method processes the given base value as a value of `T`. 
/// template <typename T, typename BaseT> struct ProcessPDLValueBasedOn { … }; /// This struct provides a simplified model for processing types that have /// "builtin" PDLValue support: /// * Attribute, Operation *, Type, TypeRange, ValueRange template <typename T> struct ProcessBuiltinPDLValue { … }; /// This struct provides a simplified model for processing types that inherit /// from builtin PDLValue types. For example, derived attributes like /// IntegerAttr, derived types like IntegerType, derived operations like /// ModuleOp, Interfaces, etc. template <typename T, typename BaseT> struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> { … }; //===----------------------------------------------------------------------===// // Attribute template <> struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> { … }; ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Attribute, T>::value>>; /// Handling for various Attribute value types. template <> struct ProcessPDLValue<StringRef> : public ProcessPDLValueBasedOn<StringRef, StringAttr> { … }; template <> struct ProcessPDLValue<std::string> : public ProcessPDLValueBasedOn<std::string, StringAttr> { … }; //===----------------------------------------------------------------------===// // Operation template <> struct ProcessPDLValue<Operation *> : public ProcessBuiltinPDLValue<Operation *> { … }; ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>; //===----------------------------------------------------------------------===// // Type template <> struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> { … }; ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>; //===----------------------------------------------------------------------===// // TypeRange template <> struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> { … }; template <> struct ProcessPDLValue<ValueTypeRange<OperandRange>> { … }; template <> 
struct ProcessPDLValue<ValueTypeRange<ResultRange>> { … }; ProcessPDLValue<SmallVector<Type, N>>; //===----------------------------------------------------------------------===// // Value template <> struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> { … }; //===----------------------------------------------------------------------===// // ValueRange template <> struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> { … }; template <> struct ProcessPDLValue<OperandRange> { … }; template <> struct ProcessPDLValue<ResultRange> { … }; ProcessPDLValue<SmallVector<Value, N>>; //===----------------------------------------------------------------------===// // PDL Function Builder: Argument Handling //===----------------------------------------------------------------------===// /// Validate the given PDLValues match the constraints defined by the argument /// types of the given function. In the case of failure, a match failure /// diagnostic is emitted. /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL /// does not currently preserve Constraint application ordering. template <typename PDLFnT, std::size_t... I> LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, std::index_sequence<I...>) { … } /// Assert that the given PDLValues match the constraints defined by the /// arguments of the given function. In the case of failure, a fatal error /// is emitted. template <typename PDLFnT, std::size_t... I> void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, std::index_sequence<I...>) { … } //===----------------------------------------------------------------------===// // PDL Function Builder: Results Handling //===----------------------------------------------------------------------===// /// Store a single result within the result list. 
template <typename T> static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, T &&value) { … } /// Store a std::pair<> as individual results within the result list. template <typename T1, typename T2> static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, std::pair<T1, T2> &&pair) { … } /// Store a std::tuple<> as individual results within the result list. template <typename... Ts> static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, std::tuple<Ts...> &&tuple) { … } /// Handle LogicalResult propagation. inline LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, LogicalResult &&result) { … } template <typename T> static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, FailureOr<T> &&result) { … } //===----------------------------------------------------------------------===// // PDL Constraint Builder //===----------------------------------------------------------------------===// /// Process the arguments of a native constraint and invoke it. template <typename PDLFnT, std::size_t... I, typename FnTraitsT = llvm::function_traits<PDLFnT>> typename FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef<PDLValue> values, std::index_sequence<I...>) { … } /// Build a constraint function from the given function `ConstraintFnT`. This /// allows for enabling the user to define simpler, more direct constraint /// functions without needing to handle the low-level PDL goop. /// /// If the constraint function is already in the correct form, we just forward /// it directly. template <typename ConstraintFnT> std::enable_if_t< std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, PDLConstraintFunction> buildConstraintFn(ConstraintFnT &&constraintFn) { … } /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form /// we desire. 
template <typename ConstraintFnT> std::enable_if_t< !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, PDLConstraintFunction> buildConstraintFn(ConstraintFnT &&constraintFn) { … } //===----------------------------------------------------------------------===// // PDL Rewrite Builder //===----------------------------------------------------------------------===// /// Process the arguments of a native rewrite and invoke it. /// This overload handles the case of no return values. template <typename PDLFnT, std::size_t... I, typename FnTraitsT = llvm::function_traits<PDLFnT>> std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value, LogicalResult> processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef<PDLValue> values, std::index_sequence<I...>) { … } /// This overload handles the case of return values, which need to be packaged /// into the result list. template <typename PDLFnT, std::size_t... I, typename FnTraitsT = llvm::function_traits<PDLFnT>> std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value, LogicalResult> processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &results, ArrayRef<PDLValue> values, std::index_sequence<I...>) { … } /// Build a rewrite function from the given function `RewriteFnT`. This /// allows for enabling the user to define simpler, more direct rewrite /// functions without needing to handle the low-level PDL goop. /// /// If the rewrite function is already in the correct form, we just forward /// it directly. template <typename RewriteFnT> std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, PDLRewriteFunction> buildRewriteFn(RewriteFnT &&rewriteFn) { … } /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form /// we desire. 
template <typename RewriteFnT>
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
                 PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) { … }

} // namespace pdl_function_builder
} // namespace detail

//===----------------------------------------------------------------------===//
// PDLPatternModule

/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. The PDL module
/// held by this object may contain any number of `pdl.pattern` operations.
class PDLPatternModule { … };

} // namespace mlir

#else // !MLIR_ENABLE_PDL_IN_PATTERNMATCH

namespace mlir {
// Stubs for when PDL in pattern rewrites is not enabled. They declare the same
// names with empty behavior, presumably so that pattern code referring to them
// still compiles when PDL support is compiled out.
//
// NOTE(review): the #includes at the top of this file are only inside the
// `#if MLIR_ENABLE_PDL_IN_PATTERNMATCH` branch. This branch therefore relies on
// the including file having already declared LogicalResult, PatternRewriter,
// ArrayRef, StringRef, MLIRContext, OwningOpRef, ModuleOp, llvm::StringMap,
// llvm_unreachable and std::function. Confirm this against the includer.

/// Placeholder for the interpreter value type; it holds no data.
class PDLValue {
public:
  /// Returns `nullptr` converted to `T`, so a cast never yields a usable
  /// value. `T` must be constructible from `nullptr` (for example a pointer or
  /// a nullable handle); instantiating it with any other `T` fails to compile.
  template <typename T>
  T dyn_cast() const {
    return nullptr;
  }
};

/// Placeholder result list; it carries no results.
class PDLResultList {};

/// Native constraint and rewrite callback types. Both share the signature
/// (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>) -> LogicalResult.
/// NOTE(review): presumably these must match the signatures used when PDL is
/// enabled, so that callers compile either way. Keep them in sync.
using PDLConstraintFunction = std::function<LogicalResult(
    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

/// No-op stand-in for the PDL pattern module. It stores no module and
/// discards everything it is given: modules, constraints and rewrites.
class PDLPatternModule {
public:
  PDLPatternModule() = default;

  /// Accepts the PDL module and immediately discards it.
  PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}

  /// This stub stores no module, so it has no context to return. Reaching
  /// this call is a programming error (marked with llvm_unreachable).
  MLIRContext *getContext() {
    llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
  }

  /// No-op: there is no state to merge.
  void mergeIn(PDLPatternModule &&other) {}

  /// No-op: there is no state to clear.
  void clear() {}

  /// The registration functions below accept their arguments and ignore them;
  /// nothing is stored.
  template <typename ConstraintFnT>
  void registerConstraintFunction(StringRef name,
                                  ConstraintFnT &&constraintFn) {}
  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
  template <typename RewriteFnT>
  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}

  /// Always returns an empty map, because no member of this stub inserts into
  /// `constraintFunctions`.
  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
    return constraintFunctions;
  }

private:
  /// Never populated. It exists only so that getConstraintFunctions() can
  /// return a reference.
  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
};

} // namespace mlir
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

#endif // MLIR_IR_PDLPATTERNMATCH_H