llvm/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PredicateTree.h"
#include "RootOrdering.h"

#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>

#define DEBUG_TYPE

usingnamespacemlir;
usingnamespacemlir::pdl_to_pdl_interp;

//===----------------------------------------------------------------------===//
// Predicate List Building
//===----------------------------------------------------------------------===//

static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos);

/// Compares the depths of two positions.
static bool comparePosDepth(Position *lhs, Position *rhs) {}

/// Returns the number of non-range elements within `values`.
static unsigned getNumNonRangeValues(ValueRange values) {}

static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              AttributePosition *pos) {}

/// Collect all of the predicates for the given operand position.
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
                                     Value val, PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs,
                                     Position *pos) {}

static void
getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
                  PredicateBuilder &builder,
                  DenseMap<Value, Position *> &inputs, OperationPosition *pos,
                  std::optional<unsigned> ignoreOperand = std::nullopt) {}

static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              TypePosition *pos) {}

/// Collect the tree predicates anchored at the given value.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos) {}

static void getAttributePredicates(pdl::AttributeOp op,
                                   std::vector<PositionalPredicate> &predList,
                                   PredicateBuilder &builder,
                                   DenseMap<Value, Position *> &inputs) {}

static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
                                    std::vector<PositionalPredicate> &predList,
                                    PredicateBuilder &builder,
                                    DenseMap<Value, Position *> &inputs) {}

static void getResultPredicates(pdl::ResultOp op,
                                std::vector<PositionalPredicate> &predList,
                                PredicateBuilder &builder,
                                DenseMap<Value, Position *> &inputs) {}

static void getResultPredicates(pdl::ResultsOp op,
                                std::vector<PositionalPredicate> &predList,
                                PredicateBuilder &builder,
                                DenseMap<Value, Position *> &inputs) {}

static void getTypePredicates(Value typeValue,
                              function_ref<Attribute()> typeAttrFn,
                              PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs) {}

/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
                                 std::vector<PositionalPredicate> &predList,
                                 PredicateBuilder &builder,
                                 DenseMap<Value, Position *> &inputs) {}

namespace {

/// An op accepting a value at an optional index.
struct OpIndex {};

/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
ParentMaps;

} // namespace

/// Given a pattern, determines the set of roots present in this pattern.
/// These are the operations whose results are not consumed by other operations.
static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {}

/// Given a list of candidate roots, builds the cost graph for connecting them.
/// The graph is formed by traversing the DAG of operations starting from each
/// root and marking the depth of each connector value (operand). Then we join
/// the candidate roots based on the common connector values, taking the one
/// with the minimum depth. Along the way, we compute, for each candidate root,
/// a mapping from each operation (in the DAG underneath this root) to its
/// parent operation and the corresponding operand index.
static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
                           ParentMaps &parentMaps) {}

/// Returns true if the operand at the given index needs to be queried using an
/// operand group, i.e., if it is variadic itself or follows a variadic operand.
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {}

/// Visit a node during upward traversal.
static void visitUpward(std::vector<PositionalPredicate> &predList,
                        OpIndex opIndex, PredicateBuilder &builder,
                        DenseMap<Value, Position *> &valueToPosition,
                        Position *&pos, unsigned rootID) {}

/// Given a pattern operation, build the set of matcher predicates necessary to
/// match this pattern.
static Value buildPredicateList(pdl::PatternOp pattern,
                                PredicateBuilder &builder,
                                std::vector<PositionalPredicate> &predList,
                                DenseMap<Value, Position *> &valueToPosition) {}

//===----------------------------------------------------------------------===//
// Pattern Predicate Tree Merging
//===----------------------------------------------------------------------===//

namespace {

/// This class represents a specific predicate applied to a position, and
/// provides hashing and ordering operators. This class allows for computing a
/// frequence sum and ordering predicates based on a cost model.
struct OrderedPredicate {};

/// A DenseMapInfo for OrderedPredicate based solely on the position and
/// question.
struct OrderedPredicateDenseInfo {};

/// This class wraps a set of ordered predicates that are used within a specific
/// pattern operation.
struct OrderedPredicateList {};
} // namespace

/// Returns true if the given matcher refers to the same predicate as the given
/// ordered predicate. This means that the position and questions of the two
/// match.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {}

/// Get or insert a child matcher for the given parent switch node, given a
/// predicate and parent pattern.
std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
                                               OrderedPredicate *predicate,
                                               pdl::PatternOp pattern) {}

/// Build the matcher CFG by "pushing" patterns through by sorted predicate
/// order. A pattern will traverse as far as possible using common predicates
/// and then either diverge from the CFG or reach the end of a branch and start
/// creating new nodes.
static void propagatePattern(std::unique_ptr<MatcherNode> &node,
                             OrderedPredicateList &list,
                             std::vector<OrderedPredicate *>::iterator current,
                             std::vector<OrderedPredicate *>::iterator end) {}

/// Fold any switch nodes nested under `node` to boolean nodes when possible.
/// `node` is updated in-place if it is a switch.
static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {}

/// Insert an exit node at the end of the failure path of the `root`.
static void insertExitNode(std::unique_ptr<MatcherNode> *root) {}

/// Sorts the range begin/end with the partial order given by cmp.
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {}

/// Returns true if 'b' depends on a result of 'a'.
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {}

/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
                                 DenseMap<Value, Position *> &valueToPosition) {}

//===----------------------------------------------------------------------===//
// MatcherNode
//===----------------------------------------------------------------------===//

MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
                         std::unique_ptr<MatcherNode> failureNode)
    :{}

//===----------------------------------------------------------------------===//
// BoolNode
//===----------------------------------------------------------------------===//

BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
                   std::unique_ptr<MatcherNode> successNode,
                   std::unique_ptr<MatcherNode> failureNode)
    :{}

//===----------------------------------------------------------------------===//
// SuccessNode
//===----------------------------------------------------------------------===//

SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
                         std::unique_ptr<MatcherNode> failureNode)
    :{}

//===----------------------------------------------------------------------===//
// SwitchNode
//===----------------------------------------------------------------------===//

SwitchNode::SwitchNode(Position *position, Qualifier *question)
    :{}