llvm/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp

//===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//===----------------------------------------------------------------------===//

#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include <queue>
#include <stack>
#include <unordered_set>

using namespace llvm;
using namespace SPIRV;

namespace llvm {

// Declaration only: the definition is not in this part of the file.
// NOTE(review): presumably provided by the pass-registration macros for
// SPIRVStructurizer; confirm against the end of this file.
void initializeSPIRVStructurizerPass(PassRegistry &);

namespace {

// Unordered set of basic blocks; used below to describe the blocks that belong
// to a construct and to track already-visited blocks.
using BlockSet = std::unordered_set<BasicBlock *>;
// A CFG edge, stored as a (source block, destination block) pair.
using Edge = std::pair<BasicBlock *, BasicBlock *>;

// Convenience wrapper around PartialOrderingVisitor: builds a visitor for the
// function owning |Start| and runs its partial order visit from |Start|,
// forwarding |Op| as the per-block callback.
void partialOrderVisit(BasicBlock &Start,
                       std::function<bool(BasicBlock *)> Op) {
  Function &Parent = *Start.getParent();
  PartialOrderingVisitor Visitor(Parent);
  Visitor.partialOrderVisit(Start, Op);
}

// Searches the convergence region tree rooted at `Node` (depth-first, parent
// before children) for the region whose entry block is `BB`.
// Returns that region, or nullptr when no region in the tree starts at `BB`.
const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
                                            BasicBlock *BB) {
  if (BB == Node->Entry)
    return Node;

  for (auto *Sub : Node->Children) {
    if (const auto *Found = getRegionForHeader(Sub, BB))
      return Found;
  }

  return nullptr;
}

// Returns the unique block outside of the convergence region `CR` that is
// targeted by one of the region's exit blocks, or nullptr when the region has
// no such target. Asserts that there is at most one distinct target.
BasicBlock *getExitFor(const ConvergenceRegion *CR) {
  std::unordered_set<BasicBlock *> Outside;
  for (BasicBlock *Exiting : CR->Exits) {
    for (BasicBlock *Next : successors(Exiting)) {
      // Successors still inside the region are not exit targets.
      if (CR->Blocks.count(Next) != 0)
        continue;
      Outside.insert(Next);
    }
  }

  assert(Outside.size() <= 1);
  return Outside.size() == 0 ? nullptr : *Outside.begin();
}

// If I is a spv_loop_merge or spv_selection_merge intrinsic call, returns the
// basic block referenced by its first operand (the merge block). Returns
// nullptr for any other instruction.
BasicBlock *getDesignatedMergeBlock(Instruction *I) {
  auto *Call = dyn_cast<IntrinsicInst>(I);
  if (!Call)
    return nullptr;

  const auto ID = Call->getIntrinsicID();
  const bool IsMerge =
      ID == Intrinsic::spv_loop_merge || ID == Intrinsic::spv_selection_merge;
  if (!IsMerge)
    return nullptr;

  // Operand 0 of both merge intrinsics is the address of the merge block.
  return cast<BlockAddress>(Call->getOperand(0))->getBasicBlock();
}

// If I is a spv_loop_merge intrinsic call, returns the basic block referenced
// by its second operand (the continue block). Returns nullptr for any other
// instruction, including spv_selection_merge.
BasicBlock *getDesignatedContinueBlock(Instruction *I) {
  auto *Call = dyn_cast<IntrinsicInst>(I);
  if (!Call || Call->getIntrinsicID() != Intrinsic::spv_loop_merge)
    return nullptr;

  // Operand 1 of the loop merge intrinsic is the address of the continue block.
  return cast<BlockAddress>(Call->getOperand(1))->getBasicBlock();
}

// Returns true if at least one merge instruction in Header designates Merge
// as its merge block. Note: despite the name, both selection merge and loop
// merge instructions are considered (see getDesignatedMergeBlock).
bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
  BasicBlock *Wanted = &Merge;
  for (Instruction &Inst : Header)
    if (getDesignatedMergeBlock(&Inst) == Wanted)
      return true;
  return false;
}

// Returns true if BB contains at least one OpLoopMerge instruction, i.e. an
// instruction for which a continue block is designated.
bool hasLoopMergeInstruction(BasicBlock &BB) {
  for (Instruction &Inst : BB) {
    if (getDesignatedContinueBlock(&Inst) != nullptr)
      return true;
  }
  return false;
}

// Returns true if I is an OpSelectionMerge or OpLoopMerge instruction, false
// otherwise.
bool isMergeInstruction(Instruction *I) {
  BasicBlock *Target = getDesignatedMergeBlock(I);
  return Target != nullptr;
}

// Collects every block of F that contains at least one OpLoopMerge or
// OpSelectionMerge instruction.
SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
  SmallPtrSet<BasicBlock *, 2> Headers;
  for (BasicBlock &Block : F) {
    for (Instruction &Inst : Block) {
      if (getDesignatedMergeBlock(&Inst) == nullptr)
        continue;
      Headers.insert(&Block);
    }
  }
  return Headers;
}

// Collects every basic block designated as merge block by at least one
// OpSelectionMerge/OpLoopMerge instruction found in |F|.
SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
  SmallPtrSet<BasicBlock *, 2> Merges;
  for (BasicBlock &Block : F)
    for (Instruction &Inst : Block)
      if (BasicBlock *Target = getDesignatedMergeBlock(&Inst))
        Merges.insert(Target);
  return Merges;
}

// Returns the merge instructions of BB, in the order they appear in the block.
// Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge
// instruction, but this can happen while we structurize the CFG.
std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
  std::vector<Instruction *> Found;
  for (Instruction &Inst : BB) {
    if (!isMergeInstruction(&Inst))
      continue;
    Found.push_back(&Inst);
  }
  return Found;
}

// Collects every basic block designated as continue target by at least one
// OpLoopMerge instruction found in |F|.
SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
  SmallPtrSet<BasicBlock *, 2> Continues;
  for (BasicBlock &Block : F)
    for (Instruction &Inst : Block)
      if (BasicBlock *Target = getDesignatedContinueBlock(&Inst))
        Continues.insert(Target);
  return Continues;
}

// Walks the CFG from |Start| using an explicit stack, calling |op| once on
// each block reached. When |op| returns false, the successors of that block are
// not explored from it. A block is marked as seen when it is pushed, so it is
// never visited twice.
void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
  std::stack<BasicBlock *> Pending;
  SmallPtrSet<BasicBlock *, 8> Discovered;

  Pending.push(&Start);
  Discovered.insert(&Start);
  while (!Pending.empty()) {
    BasicBlock *Current = Pending.top();
    Pending.pop();

    // A false return value prunes the traversal below this block.
    if (!op(Current))
      continue;

    for (BasicBlock *Next : successors(Current)) {
      // insert() reports whether the block was newly added to the set.
      if (!Discovered.insert(Next).second)
        continue;
      Pending.push(Next);
    }
  }
}

// Replaces the conditional and unconditional branch targets of |BB| by
// |NewTarget| if the target was |OldTarget|. This function also makes sure the
// associated merge instruction gets updated accordingly.
//
// |BB| must be terminated by a BranchInst (enforced by the cast below).
// If, after the replacement, a conditional branch ends up with two identical
// successors, it is turned into an unconditional branch, and a selection merge
// instruction located right before it (if any) is erased.
void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
                            BasicBlock *NewTarget) {
  auto *BI = cast<BranchInst>(BB->getTerminator());

  // 1. Replace all matching successors.
  for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
    if (BI->getSuccessor(i) == OldTarget)
      BI->setSuccessor(i, NewTarget);
  }

  // Branch was unconditional, no fixup required.
  if (BI->isUnconditional())
    return;

  // Branch had 2 successors, maybe now both are the same?
  if (BI->getSuccessor(0) != BI->getSuccessor(1))
    return;

  // Note: we may end up here because the original IR had such branches.
  // This means Target is not necessarily equal to NewTarget.
  // 2. Both sides are identical: insert an unconditional branch right before
  // the conditional one, then drop the conditional one.
  IRBuilder<> Builder(BB);
  Builder.SetInsertPoint(BI);
  Builder.CreateBr(BI->getSuccessor(0));
  BI->eraseFromParent();

  // The branch was the only instruction, nothing else to do.
  if (BB->size() == 1)
    return;

  // Otherwise, we need to check: was there an OpSelectionMerge before this
  // branch? If we removed the OpBranchConditional, we must also remove the
  // OpSelectionMerge. This is not valid for OpLoopMerge:
  // only the instruction immediately preceding the new terminator is checked.
  IntrinsicInst *II =
      dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
  if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
    return;

  // Grab the merge operand before erasing the intrinsic, so the constant can
  // be destroyed if the intrinsic was its last user.
  Constant *C = cast<Constant>(II->getOperand(0));
  II->eraseFromParent();
  if (!C->isConstantUsed())
    C->destroyConstant();
}

// Redirects every edge |BB| -> |OldTarget| to |NewTarget| by patching the
// terminator of |BB|:
//  - return: nothing to do.
//  - branch: delegated to replaceIfBranchTargets, which also fixes the
//    associated merge instruction.
//  - switch: each matching successor is replaced; the switch itself is not
//    simplified. See also: simplifyBranches.
// Any other terminator type triggers an assertion.
void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
                          BasicBlock *NewTarget) {
  Instruction *Term = BB->getTerminator();

  // Returns have no successor to patch.
  if (isa<ReturnInst>(Term))
    return;

  if (isa<BranchInst>(Term)) {
    replaceIfBranchTargets(BB, OldTarget, NewTarget);
    return;
  }

  auto *Switch = dyn_cast<SwitchInst>(Term);
  if (Switch == nullptr) {
    assert(false && "Unhandled terminator type.");
    return;
  }

  for (size_t Idx = 0; Idx < Switch->getNumSuccessors(); ++Idx) {
    if (Switch->getSuccessor(Idx) != OldTarget)
      continue;
    Switch->setSuccessor(Idx, NewTarget);
  }
}

// For each OpPhi instruction in |BB| having |OldSrc| as incoming block,
// replaces that incoming block with |NewSrc| (at the index reported by
// getBasicBlockIndex). This function does not simplify the OpPhi instruction
// once transformed.
void replacePhiTargets(BasicBlock *BB, BasicBlock *OldSrc, BasicBlock *NewSrc) {
  for (PHINode &Node : BB->phis()) {
    const int Slot = Node.getBasicBlockIndex(OldSrc);
    // -1 means |OldSrc| is not an incoming block of this phi.
    if (Slot != -1)
      Node.setIncomingBlock(Slot, NewSrc);
  }
}

} // anonymous namespace

// Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
// adding merge instructions when required.
class SPIRVStructurizer : public FunctionPass {

  // Forward declaration: ConstructList below refers to the type before its
  // definition.
  struct DivergentConstruct;
  // Represents a list of condition/loops/switch constructs.
  // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
  // constructs.
  // Each element is owned by the list through a unique_ptr.
  using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;

  // Represents a divergent construct in the SPIR-V sense.
  // Such constructs are represented by a header (entry), a merge block (exit),
  // and possibly a continue block (back-edge). A construct can contain other
  // constructs, but their boundaries do not cross.
  struct DivergentConstruct {
    // Block holding the merge instruction that opens the construct. Left null
    // on the root node, which does not represent any real construct.
    BasicBlock *Header = nullptr;
    // Merge block designated by the header's merge instruction.
    BasicBlock *Merge = nullptr;
    // Continue block designated by the header's merge instruction; stays null
    // when that instruction is not an OpLoopMerge.
    BasicBlock *Continue = nullptr;

    // Enclosing construct (non-owning). Null on the root node.
    DivergentConstruct *Parent = nullptr;
    // Constructs nested in this one (owning).
    ConstructList Children;
  };

  // An helper class to clean the construct boundaries.
  // It is used to gather the list of blocks that should belong to each
  // divergent construct, and possibly modify CFG edges when exits would cross
  // the boundary of multiple constructs.
  struct Splitter {
    // Function being processed; new blocks are created in it.
    Function &F;
    // Loop information handed over by the caller. Not read by the Splitter
    // methods visible in this part of the file.
    LoopInfo &LI;
    // Dominator and post-dominator trees of |F|; rebuilt by invalidate().
    DomTreeBuilder::BBDomTree DT;
    DomTreeBuilder::BBPostDomTree PDT;

    // Binds the splitter to |F| and |LI|, then computes both dominator trees
    // through invalidate().
    Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }

    // Recomputes the dominator and post-dominator trees from the current
    // state of |F|.
    void invalidate() {
      DT.recalculate(F);
      PDT.recalculate(F);
    }

    // Gathers, in partial order starting at |Header|, the blocks that belong to
    // the SPIR-V loop construct (continue construct included): blocks
    // dominated by |Header| which are neither |Merge| nor dominated by |Merge|.
    // The traversal does not go past a rejected block.
    std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
                                                     BasicBlock *Merge) {
      assert(DT.dominates(Header, Merge));

      std::vector<BasicBlock *> Blocks;
      auto Collect = [&](BasicBlock *BB) {
        const bool IsOutside = BB == Merge || DT.dominates(Merge, BB) ||
                               !DT.dominates(Header, BB);
        if (IsOutside)
          return false;
        Blocks.push_back(BB);
        return true;
      };
      partialOrderVisit(*Header, Collect);
      return Blocks;
    }

    // Gathers, in partial order starting at the header of |Node|, the blocks
    // that belong to the SPIR-V selection construct described by |Node|.
    // A block is rejected (and the traversal stops there) when it is:
    //  - the merge block of |Node|, or the merge/continue block of any
    //    enclosing construct;
    //  - dominated by the merge block of |Node|;
    //  - not dominated by the header of |Node|.
    std::vector<BasicBlock *>
    getSelectionConstructBlocks(DivergentConstruct *Node) {
      assert(DT.dominates(Node->Header, Node->Merge));

      BlockSet Excluded;
      Excluded.insert(Node->Merge);
      // Walk up the construct tree to collect the boundaries of all ancestors.
      DivergentConstruct *Ancestor = Node->Parent;
      while (Ancestor != nullptr) {
        Excluded.insert(Ancestor->Merge);
        if (Ancestor->Continue)
          Excluded.insert(Ancestor->Continue);
        Ancestor = Ancestor->Parent;
      }

      std::vector<BasicBlock *> Blocks;
      auto Collect = [&](BasicBlock *BB) {
        if (Excluded.count(BB) != 0)
          return false;
        const bool IsOutside = DT.dominates(Node->Merge, BB) ||
                               !DT.dominates(Node->Header, BB);
        if (IsOutside)
          return false;
        Blocks.push_back(BB);
        return true;
      };
      partialOrderVisit(*Node->Header, Collect);
      return Blocks;
    }

    // Gathers, in partial order starting at |Header|, the blocks that belong to
    // the SPIR-V switch construct: the blocks dominated by the switch header,
    // excluding the merge block and the blocks it dominates. The traversal does
    // not go past a rejected block.
    std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
                                                       BasicBlock *Merge) {
      assert(DT.dominates(Header, Merge));

      std::vector<BasicBlock *> Blocks;
      auto Collect = [&](BasicBlock *BB) {
        const bool IsInside = DT.dominates(Header, BB) &&
                              !DT.dominates(Merge, BB) && BB != Merge;
        if (!IsInside)
          return false;
        Blocks.push_back(BB);
        return true;
      };
      partialOrderVisit(*Header, Collect);
      return Blocks;
    }

    // Gathers, in partial order starting at |Target|, the blocks that belong to
    // the SPIR-V case construct: the blocks dominated by the OpSwitch Target
    // or Default block |Target|, excluding the switch merge block |Merge| and
    // the blocks it dominates. The traversal does not go past a rejected block.
    std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
                                                     BasicBlock *Merge) {
      assert(DT.dominates(Target, Merge));

      std::vector<BasicBlock *> Blocks;
      auto Collect = [&](BasicBlock *BB) {
        const bool IsInside = DT.dominates(Target, BB) &&
                              !DT.dominates(Merge, BB) && BB != Merge;
        if (!IsInside)
          return false;
        Blocks.push_back(BB);
        return true;
      };
      partialOrderVisit(*Target, Collect);
      return Blocks;
    }

    // Splits the given edges by recreating proxy nodes so that the destination
    // OpPhi instruction can still be viable.
    //
    // clang-format off
    //
    // In SPIR-V, constructs must have a single exit/merge.
    // Given nodes A and B in the construct, a node C outside, and the following edges.
    //  A -> C
    //  B -> C
    //
    // In such cases, we must create a new exit node D, that belong to the construct to make is viable:
    // A -> D -> C
    // B -> D -> C
    //
    // But if C had a phi node, adding such proxy-block breaks it. In such case, we must add 1 new block per
    // exit, and patchup the phi node:
    // A -> D -> D1 -> C
    // B -> D -> D2 -> C
    //
    // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty, just used as
    // source operands for C's phi node.
    //
    // clang-format on
    std::vector<Edge>
    createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
      std::unordered_map<BasicBlock *, BasicBlock *> Seen;
      std::vector<Edge> Output;
      Output.reserve(Edges.size());

      for (auto &[Src, Dst] : Edges) {
        auto [iterator, inserted] = Seen.insert({Src, Dst});
        if (inserted) {
          Output.emplace_back(Src, Dst);
          continue;
        }

        // The exact same edge was already seen. Ignoring.
        if (iterator->second == Dst)
          continue;

        // The same Src block branches to 2 distinct blocks. This will be an
        // issue for the generated OpPhi. Creating alias block.
        BasicBlock *NewSrc =
            BasicBlock::Create(F.getContext(), "new.exit.src", &F);
        replaceBranchTargets(Src, Dst, NewSrc);
        replacePhiTargets(Dst, Src, NewSrc);

        IRBuilder<> Builder(NewSrc);
        Builder.CreateBr(Dst);

        Seen.emplace(NewSrc, Dst);
        Output.emplace_back(NewSrc, Dst);
      }

      return Output;
    }

    // Given a construct defined by |Header|, and a list of exiting edges
    // |Edges|, creates a new single exit node, fixing up those edges.
    //
    // All sources of the (fixed-up) edges are redirected to a new block
    // "new.exit", which then dispatches to the original destinations:
    //  - 1 distinct destination: unconditional branch.
    //  - 2 distinct destinations: conditional branch on a phi-selected index.
    //  - more: switch on a phi-selected index.
    // Returns the new exit block. |Header| is not read by this function.
    BasicBlock *createSingleExitNode(BasicBlock *Header,
                                     std::vector<Edge> &Edges) {
      auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
      IRBuilder<> ExitBuilder(NewExit);

      // Distinct destinations, in first-seen order, and the i32 constant used
      // to identify each of them in the dispatch phi.
      std::vector<BasicBlock *> Dsts;
      std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;

      // Given 2 edges: Src1 -> Dst, Src2 -> Dst:
      // If Dst has an PHI node, and Src1 and Src2 are both operands, both Src1
      // and Src2 cannot be hidden by NewExit. Create 2 new nodes: Alias1,
      // Alias2 to which NewExit will branch before going to Dst. Then, patchup
      // Dst PHI node to look for Alias1 and Alias2.
      std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);

      // Assign increasing indices (0, 1, ...) to each distinct destination.
      for (auto &[Src, Dst] : FixedEdges) {
        if (DstToIndex.count(Dst) != 0)
          continue;
        DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
        Dsts.push_back(Dst);
      }

      // Single destination: no dispatch value is needed.
      if (Dsts.size() == 1) {
        for (auto &[Src, Dst] : FixedEdges) {
          replaceBranchTargets(Src, Dst, NewExit);
          replacePhiTargets(Dst, Src, NewExit);
        }
        ExitBuilder.CreateBr(Dsts[0]);
        return NewExit;
      }

      // Several destinations: the phi records, for each incoming source, the
      // index of the destination it was originally branching to.
      PHINode *PhiNode =
          ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(), FixedEdges.size());

      for (auto &[Src, Dst] : FixedEdges) {
        PhiNode->addIncoming(DstToIndex[Dst], Src);
        replaceBranchTargets(Src, Dst, NewExit);
        replacePhiTargets(Dst, Src, NewExit);
      }

      // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
      // OpBranch are allowed to exist without a new OpSelectionMerge if one of
      // the branch is the parent's merge node, while OpSwitches are not.
      if (Dsts.size() == 2) {
        Value *Condition = ExitBuilder.CreateCmp(CmpInst::ICMP_EQ,
                                                 DstToIndex[Dsts[0]], PhiNode);
        ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
        return NewExit;
      }

      // General case: the first destination is the switch default, every
      // other destination gets a case keyed by its index.
      SwitchInst *Sw =
          ExitBuilder.CreateSwitch(PhiNode, Dsts[0], Dsts.size() - 1);
      for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
        Sw->addCase(DstToIndex[*It], *It);
      }
      return NewExit;
    }
  };

  /// Returns the value from |TargetToValue| matching the successor the
  /// terminator of |BB| will branch to.
  ///  - return terminator: nullptr.
  ///  - branch with no mapped successor: nullptr.
  ///  - branch with exactly one mapped successor: the value mapped to it.
  ///  - conditional branch with both successors mapped: a select on the
  ///    branch condition, created right before the terminator.
  /// Other terminator types are not supported.
  Value *createExitVariable(
      BasicBlock *BB,
      const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
    Instruction *Term = BB->getTerminator();
    if (isa<ReturnInst>(Term))
      return nullptr;

    IRBuilder<> Builder(BB);
    Builder.SetInsertPoint(Term);

    auto *Branch = dyn_cast<BranchInst>(Term);
    if (Branch == nullptr) {
      // TODO: add support for switch cases.
      llvm_unreachable("Unhandled terminator type.");
    }

    BasicBlock *TrueTarget = Branch->getSuccessor(0);
    BasicBlock *FalseTarget =
        Branch->isConditional() ? Branch->getSuccessor(1) : nullptr;

    // lookup() yields nullptr when the target has no associated value.
    Value *TrueValue = TargetToValue.lookup(TrueTarget);
    Value *FalseValue = TargetToValue.lookup(FalseTarget);

    if (TrueValue != nullptr && FalseValue != nullptr)
      return Builder.CreateSelect(Branch->getCondition(), TrueValue,
                                  FalseValue);
    return TrueValue != nullptr ? TrueValue : FalseValue;
  }

  // Appends to F a new basic block named "new.exit" whose only instruction is
  // an OpUnreachable, and returns it.
  BasicBlock *CreateUnreachable(Function &F) {
    BasicBlock *Dead = BasicBlock::Create(F.getContext(), "new.exit", &F);
    IRBuilder<> DeadBuilder(Dead);
    DeadBuilder.CreateUnreachable();
    return Dead;
  }

  // Add OpLoopMerge instruction on cycles.
  //
  // For each loop header of F which is also the entry of a convergence region,
  // inserts a spv_loop_merge intrinsic before the header's terminator. The
  // merge block is the single exit target of the region; the continue block is
  // the loop latch. Returns true if at least one intrinsic was added.
  bool addMergeForLoops(Function &F) {
    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto *TopLevelRegion =
        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
            .getRegionInfo()
            .getTopLevelRegion();

    bool Modified = false;
    for (auto &BB : F) {
      // Not a loop header. Ignoring for now.
      if (!LI.isLoopHeader(&BB))
        continue;
      auto *L = LI.getLoopFor(&BB);

      // This loop header is not the entrance of a convergence region. Ignoring
      // this block.
      auto *CR = getRegionForHeader(TopLevelRegion, &BB);
      if (CR == nullptr)
        continue;

      IRBuilder<> Builder(&BB);

      auto *Merge = getExitFor(CR);
      // We are indeed in a loop, but there are no exits (infinite loop).
      // This could be caused by a bad shader, but also could be an artifact
      // from an earlier optimization. It is not always clear if structurally
      // reachable means runtime reachable, so we cannot error-out. What we must
      // do however is to make it legal on the SPIR-V point of view, hence
      // adding an unreachable merge block.
      if (Merge == nullptr) {
        // dyn_cast (rather than cast) so the assertion below is the one
        // reporting a non-branch terminator, with its explanatory message.
        BranchInst *Br = dyn_cast<BranchInst>(BB.getTerminator());
        assert(Br &&
               "This assumes the branch is not a switch. Maybe that's wrong?");
        assert(Br->isUnconditional());

        // Replace `br X` with `br false, Merge, X`: the new merge block is
        // structurally reachable, but never taken.
        Merge = CreateUnreachable(F);
        Builder.SetInsertPoint(Br);
        Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0));
        Br->eraseFromParent();
      }

      // getLoopLatch() returns null when the loop has several latches; the
      // latch is dereferenced below, so fail early with a clear message.
      auto *Continue = L->getLoopLatch();
      assert(Continue && "Expected the loop to have a single latch.");

      Builder.SetInsertPoint(BB.getTerminator());
      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
      auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
      SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};

      Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args});
      Modified = true;
    }

    return Modified;
  }

  // Adds an OpSelectionMerge to the immediate dominator of each node with an
  // in-degree of 2 or more which is not already the merge target of a merge
  // instruction located in that dominator. Blocks holding an OpLoopMerge are
  // skipped when they have at most 2 predecessors.
  // Returns true if at least one instruction was added.
  bool addMergeForNodesWithMultiplePredecessors(Function &F) {
    DomTreeBuilder::BBDomTree DT;
    DT.recalculate(F);

    bool Modified = false;
    for (auto &BB : F) {
      const auto NumPreds = pred_size(&BB);
      if (NumPreds <= 1)
        continue;

      if (NumPreds <= 2 && hasLoopMergeInstruction(BB))
        continue;

      auto *IDom = DT.getNode(&BB)->getIDom();
      assert(IDom);
      BasicBlock *Header = IDom->getBlock();

      // The dominator already designates this block as a merge target.
      if (isDefinedAsSelectionMergeBy(*Header, BB))
        continue;

      IRBuilder<> Builder(Header);
      Builder.SetInsertPoint(Header->getTerminator());

      auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
      SmallVector<Value *, 1> Args = {MergeAddress};
      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});

      Modified = true;
    }

    return Modified;
  }

  // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
  // them to put the "largest" first. A merge instruction is defined as larger
  // than another when its target merge block post-dominates the other target's
  // merge block. (This ordering should match the nesting ordering of the source
  // HLSL).
  //
  // The ordering between merge blocks is delegated to
  // PartialOrderingVisitor::compare. Returns false when |Block| has at most one
  // merge instruction (nothing to do), true otherwise.
  bool sortSelectionMerge(Function &F, BasicBlock &Block) {
    std::vector<Instruction *> MergeInstructions;
    for (Instruction &I : Block)
      if (isMergeInstruction(&I))
        MergeInstructions.push_back(&I);

    if (MergeInstructions.size() <= 1)
      return false;

    Instruction *InsertionPoint = *MergeInstructions.begin();

    PartialOrderingVisitor Visitor(F);
    std::sort(MergeInstructions.begin(), MergeInstructions.end(),
              [&Visitor](Instruction *Left, Instruction *Right) {
                BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
                BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
                // std::sort requires a strict weak ordering: two instructions
                // designating the same merge block (including an instruction
                // compared with itself) must be reported as equivalent.
                // Without this guard, `!compare(R, L)` would answer true in
                // both directions for such a pair.
                if (LeftMerge == RightMerge)
                  return false;
                return !Visitor.compare(RightMerge, LeftMerge);
              });

    // Each instruction is moved before the previously placed one, so the final
    // order in the block is the reverse of the sorted order.
    for (Instruction *I : MergeInstructions) {
      I->moveBefore(InsertionPoint);
      InsertionPoint = I;
    }

    return true;
  }

  // Applies sortSelectionMerge to every block of |F|, so that blocks holding
  // several merge instructions get them ordered.
  // Returns true if sortSelectionMerge reported a change for at least one
  // block.
  bool sortSelectionMergeHeaders(Function &F) {
    bool Changed = false;
    for (BasicBlock &Block : F) {
      if (sortSelectionMerge(F, Block))
        Changed = true;
    }
    return Changed;
  }

  // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge
  // instructions so each basic block contains only a single merge instruction.
  //
  // Each block is split right before its 2nd, 3rd, ... merge instruction. When
  // the merge instruction left in the upper half is not an OpLoopMerge, the
  // unconditional branch created by the split is replaced with a conditional
  // branch (always-true condition, false side going to a new unreachable
  // block).
  // Returns true if at least one block was split.
  bool splitBlocksWithMultipleHeaders(Function &F) {
    // Gather the work list first: splitting adds blocks to F.
    std::stack<BasicBlock *> Work;
    for (auto &BB : F) {
      std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
      if (MergeInstructions.size() <= 1)
        continue;
      Work.push(&BB);
    }

    const bool Modified = Work.size() > 0;
    while (Work.size() > 0) {
      BasicBlock *Header = Work.top();
      Work.pop();

      std::vector<Instruction *> MergeInstructions =
          getMergeInstructions(*Header);
      for (unsigned i = 1; i < MergeInstructions.size(); i++) {
        BasicBlock *NewBlock =
            Header->splitBasicBlock(MergeInstructions[i], "new.header");

        // After the split, |Header| holds exactly one merge instruction:
        // MergeInstructions[i - 1]. That is the one deciding which kind of
        // terminator |Header| needs, not the first instruction of the
        // original block (they only coincide on the first iteration).
        if (getDesignatedContinueBlock(MergeInstructions[i - 1]) == nullptr) {
          BasicBlock *Unreachable = CreateUnreachable(F);

          BranchInst *BI = cast<BranchInst>(Header->getTerminator());
          IRBuilder<> Builder(Header);
          Builder.SetInsertPoint(BI);
          Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
          BI->eraseFromParent();
        }

        Header = NewBlock;
      }
    }

    return Modified;
  }

  // Adds an OpSelectionMerge to each block with an out-degree >= 2 which
  // doesn't already have a merge instruction.
  //
  // Successors which are already the merge or continue target of an existing
  // merge instruction are not counted. When 2 or more successors remain, the
  // first of them is designated as the merge block.
  // Returns true if at least one instruction was added.
  bool addMergeForDivergentBlocks(Function &F) {
    bool Modified = false;

    // Computed once up-front: merge instructions added by this function are
    // not reflected in those sets.
    auto MergeBlocks = getMergeBlocks(F);
    auto ContinueBlocks = getContinueBlocks(F);

    for (auto &BB : F) {
      if (getMergeInstructions(BB).size() != 0)
        continue;

      std::vector<BasicBlock *> Candidates;
      for (BasicBlock *Successor : successors(&BB)) {
        if (MergeBlocks.contains(Successor))
          continue;
        if (ContinueBlocks.contains(Successor))
          continue;
        Candidates.push_back(Successor);
      }

      if (Candidates.size() <= 1)
        continue;

      Modified = true;
      BasicBlock *Merge = Candidates[0];

      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
      SmallVector<Value *, 1> Args = {MergeAddress};
      IRBuilder<> Builder(&BB);
      Builder.SetInsertPoint(BB.getTerminator());
      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
    }

    return Modified;
  }

  // Gather all the exit nodes for the construct header by |Header| and
  // containing the blocks |Construct|.
  std::vector<Edge> getExitsFrom(const BlockSet &Construct,
                                 BasicBlock &Header) {
    std::vector<Edge> Output;
    visit(Header, [&](BasicBlock *Item) {
      if (Construct.count(Item) == 0)
        return false;

      for (BasicBlock *Successor : successors(Item)) {
        if (Construct.count(Successor) == 0)
          Output.emplace_back(Item, Successor);
      }
      return true;
    });

    return Output;
  }

  // Builds the divergent construct tree reachable from |BB|, recursively.
  // Each block is processed once (tracked in |Visited|).
  //  - A block without merge instruction belongs to |Parent|: its successors
  //    are explored with the same parent.
  //  - A block with a merge instruction opens a new construct: the merge block
  //    is explored as part of |Parent|, while the continue block (if any) and
  //    the successors are explored as part of the new construct.
  // If |Parent| is not null, the new construct is attached to its children.
  void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
                                   BasicBlock *BB, DivergentConstruct *Parent) {
    // insert() reports whether the block was newly added to the set.
    if (!Visited.insert(BB).second)
      return;

    std::vector<Instruction *> Merges = getMergeInstructions(*BB);
    if (Merges.empty()) {
      for (BasicBlock *Next : successors(BB))
        constructDivergentConstruct(Visited, S, Next, Parent);
      return;
    }

    assert(Merges.size() == 1);
    Instruction *MergeInst = Merges.front();

    auto Construct = std::make_unique<DivergentConstruct>();
    Construct->Header = BB;
    Construct->Merge = getDesignatedMergeBlock(MergeInst);
    Construct->Continue = getDesignatedContinueBlock(MergeInst);
    Construct->Parent = Parent;

    // The merge block is outside of the new construct.
    constructDivergentConstruct(Visited, S, Construct->Merge, Parent);
    if (Construct->Continue)
      constructDivergentConstruct(Visited, S, Construct->Continue,
                                  Construct.get());

    for (BasicBlock *Next : successors(BB))
      constructDivergentConstruct(Visited, S, Next, Construct.get());

    if (Parent)
      Parent->Children.emplace_back(std::move(Construct));
  }

  // Returns the set of blocks belonging to the divergent construct |Node|.
  // Constructs with a continue block are handled as loop constructs, all the
  // others as selection constructs.
  BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
    assert(Node->Header && Node->Merge);

    const bool IsLoop = Node->Continue != nullptr;
    std::vector<BasicBlock *> Blocks =
        IsLoop ? S.getLoopConstructBlocks(Node->Header, Node->Merge)
               : S.getSelectionConstructBlocks(Node);
    return BlockSet(Blocks.begin(), Blocks.end());
  }

  // Fixes the construct |Node| (children first) so its exits respect the rules
  // defined by the SPIR-V spec. When the construct shares its merge block with
  // its parent's merge/continue block, or has an exit edge going somewhere
  // else than its own merge or continue block, all its exit edges are routed
  // through a new single exit block, which becomes the construct's merge
  // block. Returns true if the CFG was modified.
  bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
    bool Modified = false;
    for (auto &Child : Node->Children) {
      if (fixupConstruct(S, Child.get()))
        Modified = true;
    }

    DivergentConstruct *Outer = Node->Parent;

    // No parent: this is the root, which does not represent any real
    // construct, just a way to access the first level of the forest.
    // Parent without header: the parent is the root, so this is a top-level
    // construct. There can be multiple exits, but all are guaranteed to exit
    // at most 1 construct since we are at first level.
    if (Outer == nullptr || Outer->Header == nullptr)
      return Modified;

    // Health check for the structure.
    assert(Node->Header && Node->Merge);
    assert(Outer->Header && Outer->Merge);

    BlockSet ConstructBlocks = getConstructBlocks(S, Node);
    std::vector<Edge> ExitEdges = getExitsFrom(ConstructBlocks, *Node->Header);

    // No edges exiting the construct.
    if (ExitEdges.empty())
      return Modified;

    bool NeedsFix =
        Node->Merge == Outer->Merge || Node->Merge == Outer->Continue;
    for (const Edge &Exit : ExitEdges) {
      BasicBlock *Dst = Exit.second;
      // Allowed exits:
      //  - branching to the merge block of this construct (breaking from a
      //    selection construct, or from the innermost loop).
      //  - branching to the continue target of this construct (entering the
      //    innermost loop's continue construct).
      // TODO: what about cases branching to another case in the switch? Seems
      // to work, but need to double check.
      if (Dst != Node->Merge && Dst != Node->Continue)
        NeedsFix = true;
    }

    if (!NeedsFix)
      return Modified;

    // Create a single exit node gathering all exit edges.
    BasicBlock *NewExit = S.createSingleExitNode(Node->Header, ExitEdges);

    // Fixup this construct's merge node to point to the new exit.
    // Note: this algorithm fixes inner-most divergence construct first. So
    // recursive structures sharing a single merge node are fixed from the
    // inside toward the outside.
    auto Merges = getMergeInstructions(*Node->Header);
    assert(Merges.size() == 1);
    Instruction *MergeInst = Merges[0];
    BlockAddress *OldAddress = cast<BlockAddress>(MergeInst->getOperand(0));
    if (OldAddress->getBasicBlock() == Node->Merge)
      MergeInst->setOperand(0,
                            BlockAddress::get(NewExit->getParent(), NewExit));

    // Clean up of the possible dangling BlockAddress operands to prevent MIR
    // comments about "address of removed block taken".
    if (!OldAddress->isConstantUsed())
      OldAddress->destroyConstant();

    Node->Merge = NewExit;
    // Regenerate the dom trees.
    S.invalidate();
    return true;
  }

  // Driver for the critical-edge splitting step. Builds a Splitter over |F|
  // using the function's LoopInfo, populates a DivergentConstruct tree
  // starting from the first block of |F|, then hands that tree to
  // fixupConstruct(). Returns whatever fixupConstruct() reports.
  bool splitCriticalEdges(Function &F) {
    Splitter EdgeSplitter(F, getAnalysis<LoopInfoWrapperPass>().getLoopInfo());

    // Root of the construct tree, filled in by constructDivergentConstruct().
    DivergentConstruct TreeRoot;
    // Scratch set handed to constructDivergentConstruct() (presumably the
    // blocks it has already walked).
    BlockSet SeenBlocks;
    constructDivergentConstruct(SeenBlocks, EdgeSplitter, &*F.begin(),
                                &TreeRoot);

    return fixupConstruct(EdgeSplitter, &TreeRoot);
  }

  // Replaces degenerate switch terminators in |F| with plain branches:
  //  - a switch with no case becomes an unconditional branch to its default
  //    destination.
  //  - a switch with exactly one case becomes a conditional branch on
  //    `condition == case value`, going to the case successor when true and
  //    to the default destination otherwise.
  // Switches with two or more cases, and every non-switch terminator
  // (including conditional branches), are left untouched.
  // Returns true if at least one switch was rewritten.
  bool simplifyBranches(Function &F) {
    bool Changed = false;

    for (BasicBlock &Block : F) {
      auto *Switch = dyn_cast<SwitchInst>(Block.getTerminator());
      if (Switch == nullptr || Switch->getNumCases() > 1)
        continue;

      // The replacement terminator is emitted right before the switch, which
      // is erased afterwards.
      IRBuilder<> Builder(&Block);
      Builder.SetInsertPoint(Switch);

      if (Switch->getNumCases() == 1) {
        auto OnlyCase = Switch->case_begin();
        Value *MatchesCase = Builder.CreateCmp(
            CmpInst::ICMP_EQ, Switch->getCondition(), OnlyCase->getCaseValue());
        Builder.CreateCondBr(MatchesCase, OnlyCase->getCaseSuccessor(),
                             Switch->getDefaultDest());
      } else {
        Builder.CreateBr(Switch->getDefaultDest());
      }

      Switch->eraseFromParent();
      Changed = true;
    }

    return Changed;
  }

  // Guarantees that the destinations of every switch in |F| are pairwise
  // distinct (the default destination included). Whenever a case targets a
  // block already used by the default edge or by an earlier case, that case
  // is redirected to a fresh "new.sw.case" block containing only an
  // unconditional branch to the original destination.
  // Returns true if at least one case was redirected.
  bool splitSwitchCases(Function &F) {
    bool Changed = false;

    for (BasicBlock &Block : F) {
      auto *Switch = dyn_cast<SwitchInst>(Block.getTerminator());
      if (Switch == nullptr)
        continue;

      // Destinations already taken by the default edge or a previous case.
      BlockSet Claimed;
      Claimed.insert(Switch->getDefaultDest());

      for (auto CaseIt = Switch->case_begin(); CaseIt != Switch->case_end();) {
        BasicBlock *Dest = CaseIt->getCaseSuccessor();
        if (Claimed.insert(Dest).second) {
          // First edge reaching this destination: nothing to do.
          ++CaseIt;
          continue;
        }

        // Duplicate destination: route this case through a trampoline block.
        BasicBlock *Trampoline =
            BasicBlock::Create(F.getContext(), "new.sw.case", &F);
        IRBuilder<> Builder(Trampoline);
        Builder.CreateBr(Dest);

        // Re-add the case with the new destination, then drop the old entry.
        // The iterator returned by removeCase() is not advanced here, so the
        // loop examines it on the next iteration.
        Switch->addCase(CaseIt->getCaseValue(), Trampoline);
        CaseIt = Switch->removeCase(CaseIt);
        Changed = true;
      }
    }

    return Changed;
  }

  // Returns true when at least one PHI node in a successor of |BB| lists |BB|
  // as an incoming block, false otherwise.
  bool IsRequiredForPhiNode(BasicBlock *BB) {
    for (BasicBlock *Succ : successors(BB))
      for (PHINode &Node : Succ->phis())
        // getBasicBlockIndex() yields a negative value (-1) when |BB| is not
        // an incoming block of the PHI.
        if (Node.getBasicBlockIndex(BB) >= 0)
          return true;

    return false;
  }

  // Removes "proxy" blocks from |F|: blocks made of a single instruction that
  // is not a return, that have a unique successor, that are neither merge nor
  // continue blocks, and that no successor PHI node references. Each
  // predecessor of such a block is retargeted to the block's successor, and
  // the block is erased once the scan is complete.
  // Returns true if at least one block was erased.
  // NOTE(review): the entry block is not explicitly excluded; confirm it can
  // never consist of a lone unconditional branch at this point.
  bool removeUselessBlocks(Function &F) {
    auto Merges = getMergeBlocks(F);
    auto Continues = getContinueBlocks(F);

    // Erasure is deferred so the function's block list is stable while it is
    // being iterated.
    std::vector<BasicBlock *> Doomed;

    for (BasicBlock &Block : F) {
      // Only blocks holding nothing but a non-return terminator qualify.
      if (Block.size() != 1 || isa<ReturnInst>(Block.getTerminator()))
        continue;

      const bool IsMergeOrContinue =
          Merges.count(&Block) != 0 || Continues.count(&Block) != 0;
      if (IsMergeOrContinue || IsRequiredForPhiNode(&Block))
        continue;

      BasicBlock *Next = Block.getUniqueSuccessor();
      if (Next == nullptr)
        continue;

      // Snapshot the predecessors first: retargeting branches changes the
      // predecessor list being walked.
      auto PredRange = predecessors(&Block);
      std::vector<BasicBlock *> Preds(PredRange.begin(), PredRange.end());
      for (BasicBlock *Pred : Preds)
        replaceBranchTargets(Pred, &Block, Next);

      Doomed.push_back(&Block);
    }

    for (BasicBlock *Block : Doomed)
      Block->eraseFromParent();

    return !Doomed.empty();
  }

  // Adds a selection-merge header to branching blocks of |F| that still lack
  // one.
  //
  // A block is a candidate when it is not already in the header set, has at
  // least two successors, and more than one of those successors is neither a
  // merge, a continue, nor a header block. For each candidate that passes the
  // visit() check below, this:
  //   - takes the candidate's immediate post-dominator as the merge block,
  //   - splits that block to obtain a fresh "new.merge" block, and
  //   - inserts an spv_selection_merge intrinsic call referencing the new
  //     block right before the candidate's terminator.
  // Returns true if at least one header was added.
  bool addHeaderToRemainingDivergentDAG(Function &F) {
    bool Modified = false;

    // Block sets gathered once, before any modification made by this function.
    auto MergeBlocks = getMergeBlocks(F);
    auto ContinueBlocks = getContinueBlocks(F);
    auto HeaderBlocks = getHeaderBlocks(F);

    // Dominator and post-dominator trees are built locally for |F|.
    // NOTE(review): they (and the three sets above) are computed once, while
    // the loop below splits blocks; blocks created by the split are not
    // reflected in them. Confirm this is acceptable.
    DomTreeBuilder::BBDomTree DT;
    DomTreeBuilder::BBPostDomTree PDT;
    PDT.recalculate(F);
    DT.recalculate(F);

    for (BasicBlock &BB : F) {
      // Blocks that already are headers need nothing more.
      if (HeaderBlocks.count(&BB) != 0)
        continue;
      // Only branching blocks can start a divergent construct.
      if (succ_size(&BB) < 2)
        continue;

      // Count the successors that are not already a merge, continue or header
      // block.
      size_t CandidateEdges = 0;
      for (BasicBlock *Successor : successors(&BB)) {
        if (MergeBlocks.count(Successor) != 0 ||
            ContinueBlocks.count(Successor) != 0)
          continue;
        if (HeaderBlocks.count(Successor) != 0)
          continue;
        CandidateEdges += 1;
      }

      // With at most one such edge, no header is added for this block.
      if (CandidateEdges <= 1)
        continue;

      BasicBlock *Header = &BB;
      // The merge block is the immediate post-dominator of the header.
      // NOTE(review): assumes both the tree node and its immediate dominator's
      // block are non-null (e.g. the header does not post-dominate to a
      // virtual exit) -- confirm.
      BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock();

      // Walk from the header looking for merge/continue/header blocks; finding
      // one marks this candidate as unusable.
      // NOTE(review): presumably visit() starts at |Header| and a false return
      // means "do not go past this node". If so, the first test below is true
      // for |Header| itself (a block dominates itself), which would end the
      // walk immediately and leave HasBadBlock false -- confirm the intended
      // conditions.
      bool HasBadBlock = false;
      visit(*Header, [&](const BasicBlock *Node) {
        if (DT.dominates(Header, Node))
          return false;
        if (PDT.dominates(Merge, Node))
          return false;
        if (Node == Header || Node == Merge)
          return true;

        HasBadBlock |= MergeBlocks.count(Node) != 0 ||
                       ContinueBlocks.count(Node) != 0 ||
                       HeaderBlocks.count(Node) != 0;
        return !HasBadBlock;
      });

      if (HasBadBlock)
        continue;

      Modified = true;
      // Split the merge block right before its terminator, or before the
      // instruction preceding the terminator when isMergeInstruction() accepts
      // it, so that instruction stays with the terminator.
      // NOTE(review): getPrevNode() is null when the terminator is the only
      // instruction; presumably isMergeInstruction() tolerates that -- confirm.
      Instruction *SplitInstruction = Merge->getTerminator();
      if (isMergeInstruction(SplitInstruction->getPrevNode()))
        SplitInstruction = SplitInstruction->getPrevNode();
      // The block returned by the split is the one used as merge target.
      BasicBlock *NewMerge =
          Merge->splitBasicBlockBefore(SplitInstruction, "new.merge");

      // Emit the selection-merge marker just before the header's terminator.
      IRBuilder<> Builder(Header);
      Builder.SetInsertPoint(Header->getTerminator());

      // The intrinsic's single argument is the block address of the new merge
      // block.
      auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
      SmallVector<Value *, 1> Args = {MergeAddress};
      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
    }

    return Modified;
  }

public:
  // Pass identifier; handed to the FunctionPass base-class constructor.
  static char ID;

  // Constructs the pass with its identifier and calls the pass initializer on
  // the global PassRegistry.
  SPIRVStructurizer() : FunctionPass(ID) {
    initializeSPIRVStructurizerPass(*PassRegistry::getPassRegistry());
  }

  // Pass entry point: runs every structurization step on |F|, in order.
  // Returns true if any step reported a modification.
  bool runOnFunction(Function &F) override {
    bool Modified = false;

    // In LLVM, switches are allowed to have several cases branching to the
    // same basic block. This is allowed in SPIR-V, but can make structurizing
    // SPIR-V harder, so first remove such edge cases.
    Modified |= splitSwitchCases(F);

    // LLVM also allows switches to have only a default destination, or a
    // single case. Rewrite those as plain branches now. (Conditional branches
    // with both sides jumping to the same block are legal in LLVM as well,
    // but simplifyBranches only rewrites switches.)
    Modified |= simplifyBranches(F);

    // At this stage, we should have a reducible CFG with cycles.
    // STEP 1: Adding OpLoopMerge instructions to loop headers.
    Modified |= addMergeForLoops(F);

    // STEP 2: Adding OpSelectionMerge to each node with an in-degree >= 2.
    Modified |= addMergeForNodesWithMultiplePredecessors(F);

    // STEP 3:
    // Sort selection merges, the largest construct goes first.
    // This simplifies the next step.
    Modified |= sortSelectionMergeHeaders(F);

    // STEP 4: At this stage, we can have a single basic block with multiple
    // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each
    // BB has a single merge instruction.
    Modified |= splitBlocksWithMultipleHeaders(F);

    // STEP 5: In the previous steps, we added merge blocks to the loops and
    // natural merge blocks (in-degree >= 2). What remains are conditions with
    // an exiting branch (return, unreachable). In such cases, we must start
    // from the header, and add headers to divergent constructs with no
    // headers.
    Modified |= addMergeForDivergentBlocks(F);

    // STEP 6: At this stage, we have several divergent constructs defined by
    // a header and a merge block. But their boundaries have no constraints: a
    // construct exit could be outside of the parent's construct exit. Such
    // edges are called critical edges. What we need is to split those edges
    // into several parts, each part exiting the parent's construct by its
    // merge block.
    Modified |= splitCriticalEdges(F);

    // STEP 7: The previous steps possibly created a lot of "proxy" blocks:
    // blocks with a single unconditional branch, used to create a valid
    // divergent construct tree. Some nodes are still required (e.g. nodes
    // allowing a valid exit through the parent's merge block). But some are
    // left-overs of past transformations, and could cause actual validation
    // issues. E.g. the SPIR-V spec allows a construct to break to the parent's
    // loop construct without an OpSelectionMerge, but this requires a straight
    // jump. If a proxy block lies between the conditional branch and the
    // parent's merge, the CFG is not valid.
    Modified |= removeUselessBlocks(F);

    // STEP 8: Final fix-up steps: our tree boundaries are correct, but some
    // blocks are branching with no header. Those are often simple conditional
    // branches with 1 or 2 returning edges. Adding a header for those.
    Modified |= addHeaderToRemainingDivergentDAG(F);

    // STEP 9: Sort basic blocks to match both the LLVM & SPIR-V requirements.
    Modified |= sortBlocks(F);

    return Modified;
  }

  // Declares the analyses this pass requires and the ones it reports as
  // preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Analyses that must be available when the pass runs.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();

    // The convergence-region analysis is declared as still valid afterwards.
    AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
    // Forward to the base class for anything it declares itself.
    FunctionPass::getAnalysisUsage(AU);
  }
};
} // namespace llvm

// Definition of the pass identifier declared in SPIRVStructurizer.
char SPIRVStructurizer::ID = 0;

// Legacy pass-manager registration, listing the passes that must be
// initialized before this one. The same argument string is used in the BEGIN
// and END macros; the string given to INITIALIZE_PASS_END is the one stored in
// the PassInfo, so "structurize" remains the registered name.
INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "structurize", "structurize SPIRV",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVStructurizer, "structurize", "structurize SPIRV",
                    false, false)

// Factory for the structurizer pass: allocates a new SPIRVStructurizer and
// returns it through its FunctionPass base pointer.
FunctionPass *llvm::createSPIRVStructurizerPass() {
  FunctionPass *Structurizer = new SPIRVStructurizer();
  return Structurizer;
}