llvm/mlir/lib/Transforms/Utils/CommutativityUtils.cpp

//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a commutativity utility pattern and a function to
// populate this pattern. The function is intended to be used inside passes to
// simplify the matching of commutative operations by fixing the order of their
// operands.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/CommutativityUtils.h"

#include <queue>

usingnamespacemlir;

/// The possible "types" of ancestors. Here, an ancestor is an op or a block
/// argument present in the backward slice of a value.
enum AncestorType {};

/// Stores the "key" associated with an ancestor.
struct AncestorKey {};

/// Stores a commutative operand along with its BFS traversal information.
struct CommutativeOperand {};

/// Sorts the operands of `op` in ascending order of the "key" associated with
/// each operand iff `op` is commutative. This is a stable sort.
///
/// After the application of this pattern, since the commutative operands now
/// have a deterministic order in which they occur in an op, the matching of
/// large DAGs becomes much simpler, i.e., it requires far fewer checks to be
/// written by a user in their pattern matching function.
///
/// Some examples of such a sorting:
///
/// Assume that the sorting is being applied to `foo.commutative`, which is a
/// commutative op.
///
/// Example 1:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.commutative %1, %2
///
/// Here,
/// 1. The key associated with %1 is:
///     `{
///       {CONSTANT_OP, "foo.const"}
///      }`
/// 2. The key associated with %2 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
///
/// The key of %2 < the key of %1
/// Thus, the sorted `foo.commutative` is:
/// %3 = foo.commutative %2, %1
///
/// Example 2:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.mul %2, %1
/// %4 = foo.add %2, %1
/// %5 = foo.commutative %1, %2, %3, %4
///
/// Here,
/// 1. The key associated with %1 is:
///     `{
///       {CONSTANT_OP, "foo.const"}
///      }`
/// 2. The key associated with %2 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
/// 3. The key associated with %3 is:
///     `{
///       {NON_CONSTANT_OP, "foo.mul"},
///       {NON_CONSTANT_OP, "foo.mul"},
///       {CONSTANT_OP, "foo.const"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
/// 4. The key associated with %4 is:
///     `{
///       {NON_CONSTANT_OP, "foo.add"},
///       {NON_CONSTANT_OP, "foo.mul"},
///       {CONSTANT_OP, "foo.const"},
///       {BLOCK_ARGUMENT, ""},
///       {BLOCK_ARGUMENT, ""}
///      }`
///
/// Thus, the sorted `foo.commutative` is:
/// %5 = foo.commutative %4, %3, %2, %1
class SortCommutativeOperands : public RewritePattern {
  // NOTE(review): the class body is empty, so it declares none of the
  // operand-sorting logic described in the comment above (no constructor, no
  // match/rewrite override). Confirm whether this definition was truncated.
};

/// Per the file header, this is the function that populates `patterns` with the
/// commutativity utility pattern implemented in this file.
///
/// NOTE(review): the body is empty, so as written this function adds nothing
/// to `patterns`. Presumably it should register `SortCommutativeOperands`;
/// confirm.
void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
  // Currently a no-op; see the NOTE above.
}