llvm/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp

//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target will be assigned a constant value, and a phi node + switch
// will allow the new exit target to re-route to the correct basic block.
//
//===----------------------------------------------------------------------===//

#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

namespace llvm {
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);

class SPIRVMergeRegionExitTargets : public FunctionPass {
public:
  static char ID;

  SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
    initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
  };

  // Gather all the successors of |BB|.
  // This function asserts if the terminator neither a branch, switch or return.
  std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
    std::unordered_set<BasicBlock *> output;
    auto *T = BB->getTerminator();

    if (auto *BI = dyn_cast<BranchInst>(T)) {
      output.insert(BI->getSuccessor(0));
      if (BI->isConditional())
        output.insert(BI->getSuccessor(1));
      return output;
    }

    if (auto *SI = dyn_cast<SwitchInst>(T)) {
      output.insert(SI->getDefaultDest());
      for (auto &Case : SI->cases())
        output.insert(Case.getCaseSuccessor());
      return output;
    }

    assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
    return output;
  }

  /// Create a value in BB set to the value associated with the branch the block
  /// terminator will take.
  llvm::Value *createExitVariable(
      BasicBlock *BB,
      const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
    auto *T = BB->getTerminator();
    if (isa<ReturnInst>(T))
      return nullptr;

    IRBuilder<> Builder(BB);
    Builder.SetInsertPoint(T);

    if (auto *BI = dyn_cast<BranchInst>(T)) {

      BasicBlock *LHSTarget = BI->getSuccessor(0);
      BasicBlock *RHSTarget =
          BI->isConditional() ? BI->getSuccessor(1) : nullptr;

      Value *LHS = TargetToValue.count(LHSTarget) != 0
                       ? TargetToValue.at(LHSTarget)
                       : nullptr;
      Value *RHS = TargetToValue.count(RHSTarget) != 0
                       ? TargetToValue.at(RHSTarget)
                       : nullptr;

      if (LHS == nullptr || RHS == nullptr)
        return LHS == nullptr ? RHS : LHS;
      return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
    }

    // TODO: add support for switch cases.
    llvm_unreachable("Unhandled terminator type.");
  }

  /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
  void replaceBranchTargets(BasicBlock *BB,
                            const SmallPtrSet<BasicBlock *, 4> &ToReplace,
                            BasicBlock *NewTarget) {
    auto *T = BB->getTerminator();
    if (isa<ReturnInst>(T))
      return;

    if (auto *BI = dyn_cast<BranchInst>(T)) {
      for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
        if (ToReplace.count(BI->getSuccessor(i)) != 0)
          BI->setSuccessor(i, NewTarget);
      }
      return;
    }

    if (auto *SI = dyn_cast<SwitchInst>(T)) {
      for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
        if (ToReplace.count(SI->getSuccessor(i)) != 0)
          SI->setSuccessor(i, NewTarget);
      }
      return;
    }

    assert(false && "Unhandled terminator type.");
  }

  // Run the pass on the given convergence region, ignoring the sub-regions.
  // Returns true if the CFG changed, false otherwise.
  bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
                                       const SPIRV::ConvergenceRegion *CR) {
    // Gather all the exit targets for this region.
    SmallPtrSet<BasicBlock *, 4> ExitTargets;
    for (BasicBlock *Exit : CR->Exits) {
      for (BasicBlock *Target : gatherSuccessors(Exit)) {
        if (CR->Blocks.count(Target) == 0)
          ExitTargets.insert(Target);
      }
    }

    // If we have zero or one exit target, nothing do to.
    if (ExitTargets.size() <= 1)
      return false;

    // Create the new single exit target.
    auto F = CR->Entry->getParent();
    auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
    IRBuilder<> Builder(NewExitTarget);

    // CodeGen output needs to be stable. Using the set as-is would order
    // the targets differently depending on the allocation pattern.
    // Sorting per basic-block ordering in the function.
    std::vector<BasicBlock *> SortedExitTargets;
    std::vector<BasicBlock *> SortedExits;
    for (BasicBlock &BB : *F) {
      if (ExitTargets.count(&BB) != 0)
        SortedExitTargets.push_back(&BB);
      if (CR->Exits.count(&BB) != 0)
        SortedExits.push_back(&BB);
    }

    // Creating one constant per distinct exit target. This will be route to the
    // correct target.
    DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
    for (BasicBlock *Target : SortedExitTargets)
      TargetToValue.insert(
          std::make_pair(Target, Builder.getInt32(TargetToValue.size())));

    // Creating one variable per exit node, set to the constant matching the
    // targeted external block.
    std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
    for (auto Exit : SortedExits) {
      llvm::Value *Value = createExitVariable(Exit, TargetToValue);
      ExitToVariable.emplace_back(std::make_pair(Exit, Value));
    }

    // Gather the correct value depending on the exit we came from.
    llvm::PHINode *node =
        Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
    for (auto [BB, Value] : ExitToVariable) {
      node->addIncoming(Value, BB);
    }

    // Creating the switch to jump to the correct exit target.
    llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
                                                SortedExitTargets.size() - 1);
    for (size_t i = 1; i < SortedExitTargets.size(); i++) {
      BasicBlock *BB = SortedExitTargets[i];
      Sw->addCase(TargetToValue[BB], BB);
    }

    // Fix exit branches to redirect to the new exit.
    for (auto Exit : CR->Exits)
      replaceBranchTargets(Exit, ExitTargets, NewExitTarget);

    return true;
  }

  /// Run the pass on the given convergence region and sub-regions (DFS).
  /// Returns true if a region/sub-region was modified, false otherwise.
  /// This returns as soon as one region/sub-region has been modified.
  bool runOnConvergenceRegion(LoopInfo &LI,
                              const SPIRV::ConvergenceRegion *CR) {
    for (auto *Child : CR->Children)
      if (runOnConvergenceRegion(LI, Child))
        return true;

    return runOnConvergenceRegionNoRecurse(LI, CR);
  }

#if !NDEBUG
  /// Validates each edge exiting the region has the same destination basic
  /// block.
  void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
    for (auto *Child : CR->Children)
      validateRegionExits(Child);

    std::unordered_set<BasicBlock *> ExitTargets;
    for (auto *Exit : CR->Exits) {
      auto Set = gatherSuccessors(Exit);
      for (auto *BB : Set) {
        if (CR->Blocks.count(BB) == 0)
          ExitTargets.insert(BB);
      }
    }

    assert(ExitTargets.size() <= 1);
  }
#endif

  virtual bool runOnFunction(Function &F) override {
    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    const auto *TopLevelRegion =
        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
            .getRegionInfo()
            .getTopLevelRegion();

    // FIXME: very inefficient method: each time a region is modified, we bubble
    // back up, and recompute the whole convergence region tree. Once the
    // algorithm is completed and test coverage good enough, rewrite this pass
    // to be efficient instead of simple.
    bool modified = false;
    while (runOnConvergenceRegion(LI, TopLevelRegion)) {
      TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
                           .getRegionInfo()
                           .getTopLevelRegion();
      modified = true;
    }

#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
    validateRegionExits(TopLevelRegion);
#endif
    return modified;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }
};
} // namespace llvm

char SPIRVMergeRegionExitTargets::ID = 0;

INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)

FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
  return new SPIRVMergeRegionExitTargets();
}