//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file provides a simple and efficient mechanism for performing general // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's // include/llvm/IR/PatternMatch.h. // //===----------------------------------------------------------------------===// #ifndef MLIR_IR_MATCHERS_H #define MLIR_IR_MATCHERS_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/InferIntRangeInterface.h" namespace mlir { namespace detail { /// The matcher that matches a certain kind of Attribute and binds the value /// inside the Attribute. template < typename AttrClass, // Require AttrClass to be a derived class from Attribute and get its // value type typename ValueType = typename std::enable_if_t< std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType, // Require the ValueType is not void typename = std::enable_if_t<!std::is_void<ValueType>::value>> struct attr_value_binder { ValueType *bind_value; /// Creates a matcher instance that binds the value to bv if match succeeds. attr_value_binder(ValueType *bv) : … { … } bool match(Attribute attr) { … } }; /// The matcher that matches operations that have the `ConstantLike` trait. struct constant_op_matcher { … }; /// The matcher that matches operations that have the specified op name. struct NameOpMatcher { … }; /// The matcher that matches operations that have the specified attribute name. struct AttrOpMatcher { … }; /// The matcher that matches operations that have the `ConstantLike` trait, and /// binds the folded attribute value. 
template <typename AttrT>
struct constant_op_binder { … };

/// A matcher that matches operations that implement the
/// `InferIntRangeInterface` interface, and binds the inferred range.
struct infer_int_range_op_binder { … };

/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
struct AttrOpBinder { … };

/// The matcher that matches a constant scalar / vector splat / tensor splat
/// float Attribute or Operation and binds the constant float value.
struct constant_float_value_binder { … };

/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat float value that fulfills a predicate.
struct constant_float_predicate_matcher { … };

/// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer Attribute or Operation and binds the constant integer value.
struct constant_int_value_binder { … };

/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat integer value that fulfills a predicate.
struct constant_int_predicate_matcher { … };

/// A matcher that matches a given constant scalar / vector splat / tensor
/// splat integer value, or a constant integer range, that fulfills a
/// predicate.
struct constant_int_range_predicate_matcher { … };

/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher { … };

/// Trait to check whether T provides a `match` method taking a `MatchTarget`
/// (Value, Operation, or Attribute).
/// NOTE(review): as written here this is only the bare name
/// `has_compatible_matcher_t;`, not a declaration. Its uses below, as
/// `llvm::is_detected<detail::has_compatible_matcher_t, MatcherClass, ...>`,
/// require an alias template over <T, MatchTarget>. Confirm the full
/// definition is present.
has_compatible_matcher_t;

/// Statically switch to a Value matcher. This overload takes part in overload
/// resolution only when `MatcherClass` has a `match` method accepting a
/// `Value`, as detected through `has_compatible_matcher_t`.
template <typename MatcherClass>
std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
                                   MatcherClass, Value>::value,
                 bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { … }

/// Statically switch to an Operation matcher. This overload takes part in
/// overload resolution only when `MatcherClass` has a `match` method accepting
/// an `Operation *`.
template <typename MatcherClass>
std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
                                   MatcherClass, Operation *>::value,
                 bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { … }

/// Terminal matcher, always returns true.
struct AnyValueMatcher { … };

/// Terminal matcher, always returns true.
/// NOTE(review): unlike AnyValueMatcher, this one presumably also records the
/// matched value (as its name suggests); confirm.
struct AnyCapturedValueMatcher { … };

/// Binds to a specific value and matches it.
struct PatternMatcherValue { … };

/// Tuple-iteration helpers. `enumerate` presumably forwards to
/// `enumerateImpl` with an index sequence for `Tys...` and invokes `callback`
/// once per element. Confirm against the implementation.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) { … }

template <typename... Tys, typename CallbackT>
constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) { … }

/// RecursivePatternMatcher that composes matchers. It presumably matches an
/// `OpType` operation whose operands are matched, in order, by
/// `OperandMatchers...`. Confirm against the implementation.
template <typename OpType, typename... OperandMatchers>
struct RecursivePatternMatcher { … };

} // namespace detail

/// Matches a constant foldable operation.
inline detail::constant_op_matcher m_Constant() { … }

/// Matches an operation that has an attribute named `attrName`.
inline detail::AttrOpMatcher m_Attr(StringRef attrName) { … }

/// Matches an operation whose name is `opName`.
inline detail::NameOpMatcher m_Op(StringRef opName) { … }

/// Matches a constant foldable operation and writes its folded attribute value
/// to bind_value.
template <typename AttrT>
inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) { … }

/// Matches an operation that has an attribute named `attrName` and writes the
/// attribute value to bindValue.
template <typename AttrT>
inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName, AttrT *bindValue) { … }

/// Matches a constant scalar / vector splat / tensor splat float (both positive
/// and negative) zero.
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat float positive zero.
inline detail::constant_float_predicate_matcher m_PosZeroFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat float negative zero.
inline detail::constant_float_predicate_matcher m_NegZeroFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat float one.
inline detail::constant_float_predicate_matcher m_OneFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat float positive
/// infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat float negative
/// infinity.
inline detail::constant_float_predicate_matcher m_NegInfFloat() { … }

/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_predicate_matcher m_Zero() { … }

/// Matches a constant scalar / vector splat / tensor splat integer that is any
/// non-zero value.
inline detail::constant_int_predicate_matcher m_NonZero() { … }

/// Matches a constant scalar / vector splat / tensor splat integer, or an
/// unsigned integer range, that does not contain zero. Note that this matcher
/// interprets the target value as an unsigned integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() { … }

/// Matches a constant scalar / vector splat / tensor splat integer, or a
/// signed integer range, that does not contain zero. Note that this matcher
/// interprets the target value as a signed integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() { … }

/// Matches a constant scalar / vector splat / tensor splat integer, or a
/// signed integer range, that does not contain minus one. Note that this
/// matcher interprets the target value as a signed integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() { … }

/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_predicate_matcher m_One() { … }

/// Matches the given OpClass.
template <typename OpClass>
inline detail::op_matcher<OpClass> m_Op() { … }

/// Entry point for matching a pattern over a Value.
template <typename Pattern>
inline bool matchPattern(Value value, const Pattern &pattern) { … }

/// Entry point for matching a pattern over an Operation.
template <typename Pattern>
inline bool matchPattern(Operation *op, const Pattern &pattern) { … }

/// Entry point for matching a pattern over an Attribute. Returns `false`
/// when `attr` is null.
template <typename Pattern>
inline bool matchPattern(Attribute attr, const Pattern &pattern) { … }

/// Matches a constant holding a scalar/vector/tensor float (splat) and
/// writes the float value to bind_value.
inline detail::constant_float_value_binder
m_ConstantFloat(FloatAttr::ValueType *bind_value) { … }

/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_value_binder
m_ConstantInt(IntegerAttr::ValueType *bind_value) { … }

/// Builds a matcher for an `OpType` operation from the given operand
/// `matchers`. It presumably returns a
/// `detail::RecursivePatternMatcher<OpType, Matchers...>`; confirm against
/// the implementation.
template <typename OpType, typename... Matchers>
auto m_Op(Matchers... matchers) { … }

/// Value-level matchers. These are presumably intended as operand matchers for
/// the variadic `m_Op<OpType>(...)` above; confirm.
namespace matchers {
inline auto m_Any() { … }
inline auto m_Any(Value *val) { … }
inline auto m_Val(Value v) { … }
} // namespace matchers

} // namespace mlir

#endif // MLIR_IR_MATCHERS_H