llvm/mlir/include/mlir/IR/Matchers.h

//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a simple and efficient mechanism for performing general
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
// include/llvm/IR/PatternMatch.h.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_MATCHERS_H
#define MLIR_IR_MATCHERS_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"

namespace mlir {

namespace detail {

/// The matcher that matches a certain kind of Attribute and binds the value
/// inside the Attribute.
template <
    typename AttrClass,
    // Require AttrClass to be a derived class from Attribute and get its
    // value type
    typename ValueType = typename std::enable_if_t<
        std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
    // Require the ValueType is not void
    typename = std::enable_if_t<!std::is_void<ValueType>::value>>
struct attr_value_binder {
  ValueType *bind_value;

  /// Creates a matcher instance that binds the value to bv if match succeeds.
  attr_value_binder(ValueType *bv) :{}

  bool match(Attribute attr) {}
};

/// The matcher that matches operations that have the `ConstantLike` trait.
struct constant_op_matcher {};

/// The matcher that matches operations that have the specified op name.
struct NameOpMatcher {};

/// The matcher that matches operations that have the specified attribute name.
struct AttrOpMatcher {};

/// The matcher that matches operations that have the `ConstantLike` trait, and
/// binds the folded attribute value.
template <typename AttrT>
struct constant_op_binder {};

/// A matcher that matches operations that implement the
/// `InferIntRangeInterface` interface, and binds the inferred range.
struct infer_int_range_op_binder {};

/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
struct AttrOpBinder {};

/// The matcher that matches a constant scalar / vector splat / tensor splat
/// float Attribute or Operation and binds the constant float value.
struct constant_float_value_binder {};

/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat float value that fulfills a predicate.
struct constant_float_predicate_matcher {};

/// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer Attribute or Operation and binds the constant integer value.
struct constant_int_value_binder {};

/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat integer value that fulfills a predicate.
struct constant_int_predicate_matcher {};

/// A matcher that matches a given a constant scalar / vector splat / tensor
/// splat integer value or a constant integer range that fulfills a predicate.
struct constant_int_range_predicate_matcher {};

/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher {};

/// Trait to check whether T provides a 'match' method with type
/// `MatchTarget` (Value, Operation, or Attribute).
has_compatible_matcher_t;

/// Statically switch to a Value matcher.
template <typename MatcherClass>
std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
                                   MatcherClass, Value>::value,
                 bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {}

/// Statically switch to an Operation matcher.
template <typename MatcherClass>
std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
                                   MatcherClass, Operation *>::value,
                 bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {}

/// Terminal matcher, always returns true.
struct AnyValueMatcher {};

/// Terminal matcher, always returns true.
struct AnyCapturedValueMatcher {};

/// Binds to a specific value and matches it.
struct PatternMatcherValue {};

template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {}

template <typename... Tys, typename CallbackT>
constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {}

/// RecursivePatternMatcher that composes.
template <typename OpType, typename... OperandMatchers>
struct RecursivePatternMatcher {};

} // namespace detail

/// Matches a constant foldable operation.
inline detail::constant_op_matcher m_Constant() {}

/// Matches a named attribute operation.
inline detail::AttrOpMatcher m_Attr(StringRef attrName) {}

/// Matches a named operation.
inline detail::NameOpMatcher m_Op(StringRef opName) {}

/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
template <typename AttrT>
inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {}

/// Matches a named attribute operation and writes the value to bind_value.
template <typename AttrT>
inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
                                          AttrT *bindValue) {}

/// Matches a constant scalar / vector splat / tensor splat float (both positive
/// and negative) zero.
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {}

/// Matches a constant scalar / vector splat / tensor splat float positive zero.
inline detail::constant_float_predicate_matcher m_PosZeroFloat() {}

/// Matches a constant scalar / vector splat / tensor splat float negative zero.
inline detail::constant_float_predicate_matcher m_NegZeroFloat() {}

/// Matches a constant scalar / vector splat / tensor splat float ones.
inline detail::constant_float_predicate_matcher m_OneFloat() {}

/// Matches a constant scalar / vector splat / tensor splat float positive
/// infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() {}

/// Matches a constant scalar / vector splat / tensor splat float negative
/// infinity.
inline detail::constant_float_predicate_matcher m_NegInfFloat() {}

/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_predicate_matcher m_Zero() {}

/// Matches a constant scalar / vector splat / tensor splat integer that is any
/// non-zero value.
inline detail::constant_int_predicate_matcher m_NonZero() {}

/// Matches a constant scalar / vector splat / tensor splat integer or a
/// unsigned integer range that does not contain zero. Note that this matcher
/// interprets the target value as an unsigned integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() {}

/// Matches a constant scalar / vector splat / tensor splat integer or a
/// signed integer range that does not contain zero. Note that this matcher
/// interprets the target value as a signed integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() {}

/// Matches a constant scalar / vector splat / tensor splat integer or a
/// signed integer range that does not contain minus one. Note
/// that this matcher interprets the target value as a signed integer.
inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() {}

/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_predicate_matcher m_One() {}

/// Matches the given OpClass.
template <typename OpClass>
inline detail::op_matcher<OpClass> m_Op() {}

/// Entry point for matching a pattern over a Value.
template <typename Pattern>
inline bool matchPattern(Value value, const Pattern &pattern) {}

/// Entry point for matching a pattern over an Operation.
template <typename Pattern>
inline bool matchPattern(Operation *op, const Pattern &pattern) {}

/// Entry point for matching a pattern over an Attribute. Returns `false`
/// when `attr` is null.
template <typename Pattern>
inline bool matchPattern(Attribute attr, const Pattern &pattern) {}

/// Matches a constant holding a scalar/vector/tensor float (splat) and
/// writes the float value to bind_value.
inline detail::constant_float_value_binder
m_ConstantFloat(FloatAttr::ValueType *bind_value) {}

/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_value_binder
m_ConstantInt(IntegerAttr::ValueType *bind_value) {}

template <typename OpType, typename... Matchers>
auto m_Op(Matchers... matchers) {}

namespace matchers {
inline auto m_Any() {}
inline auto m_Any(Value *val) {}
inline auto m_Val(Value v) {}
} // namespace matchers

} // namespace mlir

#endif // MLIR_IR_MATCHERS_H