llvm/mlir/include/mlir/IR/PatternBase.td

//===-- PatternBase.td - Base pattern definition file ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains all of the base constructs for defining DRR (Declarative
// Rewrite Rule) patterns.
//
//===----------------------------------------------------------------------===//

#ifndef PATTERNBASE_TD
#define PATTERNBASE_TD

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Pattern definitions
//===----------------------------------------------------------------------===//

// Marker used to identify the delta value added to the default benefit value.
// It is used as the operator of the `benefitAdded` DAG argument of `Pattern`
// and `Pat`, e.g. `(addBenefit 0)` (the default) or `(addBenefit 10)`.
def addBenefit;

// Base class for op+ -> op+ rewrite rules. These allow declaratively
// specifying rewrite rules.
//
// A rewrite rule contains two components: a source pattern and one or more
// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`.
//
// The `node` is normally an MLIR op, but it can also be one of the directives
// listed later in this file.
//
// ## Symbol binding
//
// In the source pattern, `argN` can be used to specify matchers (e.g., using
// type/attribute type constraints, etc.) and bound to a name for later use.
// We can also bind names to op instances to reference them later in
// multi-entity constraints. Operands in the source pattern can have
// the same name. This binds one operand to the name while verifying
// that the rest are all equal to it.
//
// In the result pattern, `argN` can be used to refer to a previously bound
// name, with potential transformations (e.g., using tAttr, etc.). `argN` can
// itself be a nested DAG node. We can also bind names to ops to reference
// them later in other result patterns.
//
// For example,
//
// ```
// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1, $arg0),
//               [(OneResultOp2:$op2 $arg0, $arg1),
//                (OneResultOp3 $op2 (OneResultOp4))],
//               [(HasStaticShapePred $op1)]>;
// ```
//
// The first `$arg0` and `$arg1` are bound to the first and second arguments
// of `OneResultOp1` and used later to build `OneResultOp2`. The second
// `$arg0` is verified to be equal to the first `$arg0` operand.
// `$op1` is bound to `OneResultOp1` and used to check whether the result's
// shape is static. `$op2` is bound to `OneResultOp2` and used to
// build `OneResultOp3`.
//
// ## Multi-result op
//
// To create multi-result ops in result patterns, you can use a syntax similar
// to single-result ops, and the op will act as a value pack for all of its
// results:
//
// ```
// def : Pattern<(ThreeResultOp ...),
//               [(TwoResultOp ...), (OneResultOp ...)]>;
// ```
//
// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`,
// and `OneResultOp` will replace the third.
//
// You can also use `$<name>__N` to explicitly access the N-th result.
// ```
// def : Pattern<(FiveResultOp ...),
//               [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0),
//                (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>;
// ```
//
// Then the values generated by `FiveResultOp` will be replaced by
//
// * `FiveResultOp`#0: `TwoResultOp1`#1
// * `FiveResultOp`#1: `TwoResultOp1`#0
// * `FiveResultOp`#2: `TwoResultOp2`#0
// * `FiveResultOp`#3: `TwoResultOp2`#1
// * `FiveResultOp`#4: `TwoResultOp2`#1
class Pattern<dag source, list<dag> results, list<dag> preds = [],
  list<dag> supplemental_results = [],
  dag benefitAdded = (addBenefit 0)> {
  dag sourcePattern = source;
  // Result patterns. Each result pattern is expected to replace one result
  // of the root op in the source pattern. If there are more result patterns
  // than needed to replace the source op, only the last N results generated
  // by the last N result patterns are used to replace an N-result source op.
  // This lets the leading result patterns generate additional ops that help
  // build the results used for replacement.
  list<dag> resultPatterns = results;
  // Multi-entity constraints. Each constraint here involves multiple entities
  // matched in the source pattern and places further constraints on them as
  // a whole.
  list<dag> constraints = preds;
  // Optional patterns that are executed after the result patterns. Like the
  // leading (auxiliary) result patterns, they are not used for replacement.
  // These patterns can be used to invoke additional code after the result
  // patterns, e.g. to copy the attributes from the source op to the result
  // ops.
  list<dag> supplementalPatterns = supplemental_results;
  // The delta value added to the default benefit value. The default value is
  // the number of ops in the source pattern. If multiple rules match, the rule
  // with the highest final benefit value is applied first. This delta value
  // can be either positive or negative and is written with the `addBenefit`
  // marker, e.g. `(addBenefit 1)`.
  dag benefitDelta = benefitAdded;
}

// Form of a pattern which produces a single result. This is shorthand for a
// `Pattern` whose result list contains exactly one DAG (`[result]`); the
// constraints, supplemental patterns and benefit delta are forwarded as-is.
class Pat<dag pattern, dag result, list<dag> preds = [],
  list<dag> supplemental_results = [],
  dag benefitAdded = (addBenefit 0)> :
  Pattern<pattern, [result], preds, supplemental_results, benefitAdded>;

// Native code call wrapper. This allows invoking an arbitrary C++ expression
// to create an op operand/attribute or replace an op result.
//
// ## Placeholders
//
// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`,
// the wrapped expression can take special placeholders listed below:
//
// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
// * `$_self` will be replaced by the defining operation in a source pattern.
//   E.g., in `(OneArgOp (NativeCodeCall<"Foo($_self, &$0)"> I32Attr:$attr))`,
//   `$_self` will be replaced with the defining operation of the first
//   operand of `OneArgOp`.
//
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
// then positional placeholders are also supported; placeholder `$N` in the
// wrapped C++ expression will be replaced by `<argN>`.
//
// ## Bind multiple results
//
// To bind multiple results and access the N-th result with `$<name>__N`,
// specify the number of return values in the template. Note that only the
// `Value` type is supported for binding multiple results.

class NativeCodeCall<string expr, int returns = 1> {
  // The wrapped C++ expression; it may contain the placeholders described
  // above.
  string expression = expr;
  // Number of values the expression produces (1 by default, 0 for
  // `NativeCodeCallVoid`).
  int numReturns = returns;
}

// A `NativeCodeCall` whose expression produces no values (`numReturns` is 0).
class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;

// Source-pattern matcher that applies `::mlir::m_Constant` to result #0 of the
// matched defining operation (`$_self`), storing the matched constant through
// `&$0`, i.e. into the first positional argument of the DAG node, e.g.
// `(ConstantLikeMatcher <AttrConstraint>:$value)`. The `bool` returned by
// `::mlir::matchPattern` is converted to a `LogicalResult` via
// `::mlir::success(...)`.
def ConstantLikeMatcher : NativeCodeCall<"::mlir::success("
    "::mlir::matchPattern($_self->getResult(0), ::mlir::m_Constant(&$0)))">;

//===----------------------------------------------------------------------===//
// Rewrite directives
//===----------------------------------------------------------------------===//

// Directive used in result patterns to indicate that no new op is generated;
// instead, the matched DAG is replaced with an existing SSA value, e.g.
// `(replaceWithValue $value)`.
def replaceWithValue;

// Directive used in result patterns to specify the location of the generated
// op. This directive must be used as a trailing argument to op creation or
// native code calls, e.g. `(AnOp $val, (location $val))`.
//
// Usage:
// * Create a named location: `(location "myLocation")`
// * Copy the location of a captured symbol: `(location $arg)`
// * Create a fused location: `(location "metadata", $arg0, $arg1)`

def location;

// Directive used in result patterns to specify return types for a created op.
// This allows ops to be created without relying on type inference with
// `OpTraits` or an op builder with return type deduction.
//
// This directive must be used as a trailing argument to op creation.
//
// Specify one return type with a string literal:
//
// ```
// (AnOp $val, (returnType "$_builder.getI32Type()"))
// ```
//
// Pass a captured value to copy its type:
//
// ```
// (AnOp $val, (returnType $val))
// ```
//
// Pass a native code call inside a DAG to create a new type from arguments:
//
// ```
// (AnOp $val,
//       (returnType (NativeCodeCall<"$_builder.getTupleType({$0})"> $val)))
// ```
//
// To specify multiple return types, combine several of the forms above.

def returnType;

// Directive used to specify that operands may be matched in either order.
// When two adjacent operands are wrapped in `either`, the matcher tries both
// orderings of their constraints. Example:
//
// ```
// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
// ```
// The above pattern will accept both `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
//
// Only operands are supported with `either`. Note that giving an operation
// the `Commutative` trait does not make pattern matching behave as if
// `either` had been used.
def either;

// Directive used to match variadic operands. This directive only matches if
// the variadic operand has the same length as the number of specified
// sub-dags.
//
// ```
// (VariadicOp (variadic:$input1 $input1a, $input1b),
//             (variadic:$input2 $input2a, $input2b, $input2c),
//             $attr1, $attr2)
// ```
//
// The pattern above only matches if the `$input1` operand is of length 2,
// `$input2` is of length 3, and all sub-dags match respectively. The `$input1`
// symbol denotes the full variadic operand range. The `$input1a` symbol
// denotes the first operand in that range, `$input1b` the second, and so on.
def variadic;

//===----------------------------------------------------------------------===//
// Common value constraints
//===----------------------------------------------------------------------===//

// Constraint that holds when the constrained entity (`$_self`, presumably an
// `mlir::Value` bound in the source pattern — confirm at use sites) has no
// uses, as checked by `$_self.use_empty()`.
def HasNoUseOf: Constraint<
    CPred<"$_self.use_empty()">, "has no use">;

#endif // PATTERNBASE_TD