//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the Reduction Tree Pass class. It provides a framework for // the implementation of different reduction passes in the MLIR Reduce tool. It // allows for custom specification of the variant generation behavior. It // implements methods that define the different possible traversals of the // reduction tree. // //===----------------------------------------------------------------------===// #include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Reducer/Passes.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Reducer/Tester.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ManagedStatic.h" namespace mlir { #define GEN_PASS_DEF_REDUCTIONTREE #include "mlir/Reducer/Passes.h.inc" } // namespace mlir using namespace mlir; /// We implicitly number each operation in the region and if an operation's /// number falls into rangeToKeep, we need to keep it and apply the given /// rewrite patterns on it. static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef<ReductionNode::Range> rangeToKeep, bool eraseOpNotInRange) { … } /// We will apply the reducer patterns to the operations in the ranges specified /// by ReductionNode. Note that we are not able to remove an operation without /// replacing it with another valid operation. 
However, the validity of module /// reduction is based on the Tester provided by the user, and that means certain /// invalid modules are still of interest to the user. Thus we provide an /// alternative way to remove operations, which is using `eraseOpNotInRange` to /// erase the operations not in the range specified by ReductionNode. template <typename IteratorType> static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange) { … } template <typename IteratorType> static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test) { … } namespace { //===----------------------------------------------------------------------===// // Reduction Pattern Interface Collection //===----------------------------------------------------------------------===// class ReductionPatternInterfaceCollection : public DialectInterfaceCollection<DialectReductionPatternInterface> { … }; //===----------------------------------------------------------------------===// // ReductionTreePass //===----------------------------------------------------------------------===// /// This class defines the Reduction Tree Pass. It provides a framework /// to implement a reduction pass using a tree structure to keep track of the /// generated reduced variants. class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> { … }; } // namespace LogicalResult ReductionTreePass::initialize(MLIRContext *context) { … } void ReductionTreePass::runOnOperation() { … } LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) { … } std::unique_ptr<Pass> mlir::createReductionTreePass() { … }