//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a commutativity utility pattern and a function to // populate this pattern. The function is intended to be used inside passes to // simplify the matching of commutative operations by fixing the order of their // operands. // //===----------------------------------------------------------------------===// #include "mlir/Transforms/CommutativityUtils.h" #include <queue> using namespace mlir; /// The possible "types" of ancestors. Here, an ancestor is an op or a block /// argument present in the backward slice of a value. enum AncestorType { … }; /// Stores the "key" associated with an ancestor. struct AncestorKey { … }; /// Stores a commutative operand along with its BFS traversal information. struct CommutativeOperand { … }; /// Sorts the operands of `op` in ascending order of the "key" associated with /// each operand iff `op` is commutative. This is a stable sort. /// /// After the application of this pattern, since the commutative operands now /// have a deterministic order in which they occur in an op, the matching of /// large DAGs becomes much simpler, i.e., it requires far fewer checks /// to be written by a user in their pattern-matching function. /// /// Some examples of such a sorting: /// /// Assume that the sorting is being applied to `foo.commutative`, which is a /// commutative op. /// /// Example 1: /// /// %1 = foo.const 0 /// %2 = foo.mul <block argument>, <block argument> /// %3 = foo.commutative %1, %2 /// /// Here, /// 1. The key associated with %1 is: /// `{ /// {CONSTANT_OP, "foo.const"} /// }` /// 2. 
The key associated with %2 is: /// `{ /// {NON_CONSTANT_OP, "foo.mul"}, /// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// /// The key of %2 < the key of %1. /// Thus, the sorted `foo.commutative` is: /// %3 = foo.commutative %2, %1 /// /// Example 2: /// /// %1 = foo.const 0 /// %2 = foo.mul <block argument>, <block argument> /// %3 = foo.mul %2, %1 /// %4 = foo.add %2, %1 /// %5 = foo.commutative %1, %2, %3, %4 /// /// Here, /// 1. The key associated with %1 is: /// `{ /// {CONSTANT_OP, "foo.const"} /// }` /// 2. The key associated with %2 is: /// `{ /// {NON_CONSTANT_OP, "foo.mul"}, /// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// 3. The key associated with %3 is: /// `{ /// {NON_CONSTANT_OP, "foo.mul"}, /// {NON_CONSTANT_OP, "foo.mul"}, /// {CONSTANT_OP, "foo.const"}, /// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// 4. The key associated with %4 is: /// `{ /// {NON_CONSTANT_OP, "foo.add"}, /// {NON_CONSTANT_OP, "foo.mul"}, /// {CONSTANT_OP, "foo.const"}, /// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// /// Thus, the sorted `foo.commutative` is: /// %5 = foo.commutative %4, %3, %2, %1 class SortCommutativeOperands : public RewritePattern { … }; void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { … }