//===-- SPIRVPostLegalizer.cpp - amend info after legalization --*- C++ -*-===//
//
// which may appear after the legalizer pass
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass partially applies pre-legalization logic to new instructions
// inserted as a result of legalization:
// - assigns SPIR-V types to registers for new instructions.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Target/TargetIntrinsicInfo.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
using namespace llvm;
namespace {
// Machine-function pass that runs the post-legalization fix-ups implemented in
// this file. It is kept in an anonymous namespace, so it is only visible within
// this translation unit.
class SPIRVPostLegalizer : public MachineFunctionPass {
public:
// Pass identifier handed to the MachineFunctionPass base-class constructor.
static char ID;
// Constructs the pass and registers it with the global PassRegistry.
SPIRVPostLegalizer() : MachineFunctionPass(ID) {
initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry());
}
// Entry point invoked for each machine function; defined out of line below.
bool runOnMachineFunction(MachineFunction &MF) override;
};
} // namespace
// Defined in SPIRVLegalizerInfo.cpp.
// Predicate on a machine opcode. In this file it is used both to decide whether
// an instruction may have been inserted by the legalizer and whether the
// pre-legalizer instruction pattern has to be restored for it.
extern bool isTypeFoldingSupported(unsigned Opcode);
namespace llvm {
// Defined in SPIRVPreLegalizer.cpp.
// Both helpers below are re-used here to give newly inserted instructions the
// same treatment the pre-legalizer applies.
// NOTE(review): insertAssignInstr presumably emits the "assign type"
// instruction for Reg; confirm against SPIRVPreLegalizer.cpp.
extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB,
MachineRegisterInfo &MRI);
// NOTE(review): exact rewriting done by processInstr is not visible from this
// file; see its definition in SPIRVPreLegalizer.cpp.
extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
} // namespace llvm
// Returns true when Opcode is one of the SPIR-V GET_* meta instructions
// (GET_ID, GET_fID, GET_pID, GET_vID, GET_vfID, GET_vpID).
static bool isMetaInstrGET(unsigned Opcode) {
  switch (Opcode) {
  case SPIRV::GET_ID:
  case SPIRV::GET_fID:
  case SPIRV::GET_pID:
  case SPIRV::GET_vID:
  case SPIRV::GET_vfID:
  case SPIRV::GET_vpID:
    return true;
  default:
    return false;
  }
}
// Returns true for opcodes this pass treats as candidates for instructions
// newly inserted by the legalizer: the integer and floating-point min/max
// generic opcodes, plus every opcode accepted by isTypeFoldingSupported().
static bool mayBeInserted(unsigned Opcode) {
  const bool IsMinOrMax =
      Opcode == TargetOpcode::G_SMAX || Opcode == TargetOpcode::G_UMAX ||
      Opcode == TargetOpcode::G_SMIN || Opcode == TargetOpcode::G_UMIN ||
      Opcode == TargetOpcode::G_FMINNUM ||
      Opcode == TargetOpcode::G_FMINIMUM ||
      Opcode == TargetOpcode::G_FMAXNUM || Opcode == TargetOpcode::G_FMAXIMUM;
  // Only consult the type-folding predicate when the opcode is not min/max.
  return IsMinOrMax || isTypeFoldingSupported(Opcode);
}
// Scans every instruction of MF and assigns SPIR-V type information to result
// registers that do not have a SPIR-V type registered in GR yet.
//
// Two kinds of instructions are handled:
//  * G_UNMERGE_VALUES: each untyped result gets the type registered for the
//    register in operand 1 of the source's OpTypeVector type instruction. A
//    source that is not a typed vector register is a fatal error.
//  * single-def instructions accepted by mayBeInserted() whose operand 1 is a
//    register: an untyped result inherits the SPIR-V type of operand 1. If the
//    opcode also supports type folding and the instruction is not already
//    followed by a GET_* meta instruction, insertAssignInstr() and
//    processInstr() are applied to it.
static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                             MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  for (MachineBasicBlock &Block : MF) {
    for (MachineInstr &MI : Block) {
      const unsigned Opc = MI.getOpcode();

      if (Opc == TargetOpcode::G_UNMERGE_VALUES) {
        // The value being unmerged is the last operand of the instruction.
        const MachineOperand &SrcOp = MI.getOperand(MI.getNumOperands() - 1);
        SPIRVType *SrcType = nullptr;
        if (SrcOp.isReg() && SrcOp.getReg().isValid())
          SrcType = GR->getSPIRVTypeForVReg(SrcOp.getReg());
        if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
          report_fatal_error(
              "cannot select G_UNMERGE_VALUES with a non-vector argument");

        SPIRVType *ElemType =
            GR->getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
        for (unsigned DefIdx = 0, NumDefs = MI.getNumDefs(); DefIdx < NumDefs;
             ++DefIdx) {
          const Register DefReg = MI.getOperand(DefIdx).getReg();
          if (GR->getSPIRVTypeForVReg(DefReg))
            continue; // Already typed, nothing to repair.
          // No "assign type" action happened for this result; do it now.
          MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
          MRI.setType(DefReg,
                      LLT::scalar(GR->getScalarOrVectorBitWidth(ElemType)));
          GR->assignSPIRVTypeToVReg(ElemType, DefReg, *GR->CurMF);
        }
        continue;
      }

      // Everything else is only of interest when it looks like
      // "one def = op(reg, ...)" with an opcode the legalizer may introduce.
      if (!mayBeInserted(Opc) || MI.getNumDefs() != 1 ||
          MI.getNumOperands() <= 1 || !MI.getOperand(1).isReg())
        continue;

      // The legalizer may have created this instruction together with new
      // registers; decorate them as if they had been introduced manually.
      const Register DstReg = MI.getOperand(0).getReg();
      SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg);
      if (!DstType) {
        // Untyped result: borrow the type of the first source operand.
        DstType = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (!DstType)
          continue; // The operand has no type either; leave it alone.
        MRI.setRegClass(DstReg, GR->getRegClass(DstType));
        MRI.setType(DstReg, GR->getRegType(DstType));
        GR->assignSPIRVTypeToVReg(DstType, DstReg, *GR->CurMF);
      }

      // Simple operations that are reduced by TableGen definitions need part
      // of the pre-legalizer rules applied here as well.
      if (!isTypeFoldingSupported(Opc))
        continue;

      // A GET_* meta instruction right after MI means it was already handled.
      const MachineInstr *Following = MI.getNextNode();
      if (Following && isMetaInstrGET(Following->getOpcode()))
        continue;

      // Restore the usual instruction pattern for the new instruction.
      insertAssignInstr(DstReg, nullptr, DstType, GR, MIB, MRI);
      processInstr(MI, MIB, MRI, GR);
    }
  }
}
// Do a preorder traversal of the CFG starting from the basic block |Start|.
// Calls |op| exactly once on each basic block reachable from |Start|.
// A block is marked as seen when it is pushed onto the work stack, so it is
// never queued twice. |MF| is not used by the traversal itself.
void visit(MachineFunction &MF, MachineBasicBlock &Start,
           std::function<void(MachineBasicBlock *)> op) {
  std::stack<MachineBasicBlock *> Worklist;
  SmallPtrSet<MachineBasicBlock *, 8> Discovered;

  Worklist.push(&Start);
  Discovered.insert(&Start);

  while (!Worklist.empty()) {
    MachineBasicBlock *Current = Worklist.top();
    Worklist.pop();
    op(Current);

    for (MachineBasicBlock *Next : Current->successors()) {
      // insert() reports whether the block was new; queue only new ones.
      if (Discovered.insert(Next).second)
        Worklist.push(Next);
    }
  }
}
// Convenience overload: runs the preorder CFG traversal from the first basic
// block of |MF| (the function's entry) and calls |op| on every block reached.
void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
  MachineBasicBlock &Entry = MF.front();
  visit(MF, Entry, op);
}
// Sorts basic blocks by dominance to respect the SPIR-V spec.
//
// Blocks are reordered by the position at which the preorder traversal from the
// entry block (see visit()) reaches them. A block is only discovered through a
// block that has already been visited, so every reachable block ends up after
// the blocks that dominate it.
//
// Blocks that are not reachable from the entry receive rank 0, the same rank
// as the entry block itself.
void sortBlocks(MachineFunction &MF) {
  // Nothing to order, and visit() would dereference MF.begin().
  if (MF.empty())
    return;

  // Traversal position of every block reachable from the entry.
  std::unordered_map<MachineBasicBlock *, size_t> Order;
  Order.reserve(MF.size());
  size_t Index = 0;
  visit(MF, [&Order, &Index](MachineBasicBlock *MBB) { Order[MBB] = Index++; });

  // Read-only rank lookup. Using find() instead of operator[] keeps the map
  // unchanged while the sort is running; missing blocks map to 0, which is the
  // value operator[] would have default-inserted.
  auto Rank = [&Order](MachineBasicBlock &MBB) -> size_t {
    auto It = Order.find(&MBB);
    return It == Order.end() ? 0 : It->second;
  };

  MF.sort([&Rank](MachineBasicBlock &LHS, MachineBasicBlock &RHS) {
    return Rank(LHS) < Rank(RHS);
  });
}
// Pass entry point. Points the SPIR-V global registry at MF, assigns types to
// registers of instructions added by legalization, and then reorders the basic
// blocks. Always reports the function as modified.
bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Fetch the type registry from the subtarget and bind it to this function.
  SPIRVGlobalRegistry *Registry =
      MF.getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
  Registry->setCurrentFunc(MF);

  MachineIRBuilder Builder(MF);
  processNewInstrs(MF, Registry, Builder);

  sortBlocks(MF);
  return true;
}
// Registers the pass under the DEBUG_TYPE name ("spirv-postlegalizer") with
// the description "SPIRV post legalizer"; both trailing flags are false.
INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
false)
// Definition of the static pass identifier declared in the class.
char SPIRVPostLegalizer::ID = 0;
// Factory for the pass: returns a newly heap-allocated SPIRVPostLegalizer.
// NOTE(review): ownership is presumably taken over by the pass manager the
// result is added to; confirm at the call site.
FunctionPass *llvm::createSPIRVPostLegalizerPass() {
return new SPIRVPostLegalizer();
}