from typing import Optional, List
# Base class for an instruction. To implement a basic instruction that does
# not impact control flow, create a new class inheriting from this one.
class Instruction:
    """A single instruction parsed from one line of textual SPIR-V.

    The line is split on whitespace. When the second token is "=", the first
    token is the result register and the third token is the opcode;
    otherwise the first token is the opcode. Every remaining token is an
    operand.
    """

    # Name of the output register, or None if the instruction has none.
    _result: Optional[str]
    # The instruction opcode (e.g. "OpIAdd").
    _opcode: str
    # All operand tokens, excluding the result register and the opcode.
    _operands: List[str]

    def __init__(self, line: str):
        self.line = line
        words = line.split()
        has_result = len(words) > 1 and words[1] == "="
        # Position of the opcode: "%r = OpX ..." versus "OpX ...".
        opcode_index = 2 if has_result else 0
        self._result = words[0] if has_result else None
        # An empty line or a dangling "%r =" raises IndexError here.
        self._opcode = words[opcode_index]
        self._operands = words[opcode_index + 1 :]

    def __str__(self):
        if self._result is not None:
            return f"{self._result:3} = {self._opcode} {self._operands}"
        return f" {self._opcode} {self._operands}"

    def opcode(self) -> str:
        """Return the instruction opcode."""
        return self._opcode

    def operands(self) -> List[str]:
        """Return the operand tokens (result register and opcode excluded)."""
        return self._operands

    def output_register(self) -> str:
        """Return the output register.

        Only valid when has_output_register() is true; asserts otherwise.
        """
        assert self._result is not None
        return self._result

    def has_output_register(self) -> bool:
        """Return True if this instruction writes an output register."""
        return self._result is not None

    def static_execution(self, lane):
        """Initialize state tied to this instruction before execution begins.

        The base implementation does nothing. Subclasses override it, for
        example to store a lane's ID into a built-in input register.
        """

    def runtime_execution(self, module, lane):
        """Execute this instruction for `lane`.

        Do not override this method; override _impl and/or _advance_ip.
        """
        self._impl(module, lane)
        self._advance_ip(module, lane)

    def _impl(self, module, lane):
        """Perform the instruction's effect.

        Every instruction that can run must override this. 'Static'
        instructions such as OpConstant keep this default because they are
        never executed at runtime; calling it raises RuntimeError.
        """
        raise RuntimeError(f"Unimplemented instruction {self}")

    def _advance_ip(self, module, lane):
        """Point `lane` at the next instruction.

        Instructions that redirect control flow (e.g. OpBranch) override this.
        """
        lane.set_ip(lane.ip() + 1)
# These instructions are parsed, but never executed.
class OpEntryPoint(Instruction):
    """Parsed only; inherits Instruction's defaults and has no runtime implementation."""
class OpFunction(Instruction):
    """Parsed only; inherits Instruction's defaults and has no runtime implementation."""
class OpFunctionEnd(Instruction):
    """Parsed only; inherits Instruction's defaults and has no runtime implementation."""
class OpLabel(Instruction):
    """Parsed only; inherits Instruction's defaults and has no runtime implementation."""
class OpVariable(Instruction):
    """Parsed only; inherits Instruction's defaults and has no runtime implementation."""
class OpName(Instruction):
    """`OpName %reg "name"`: attaches a debug name to a register."""

    def name(self) -> str:
        """Return the name string without its surrounding double quotes.

        The name is sliced from the raw source line, between its first and
        last double quote. Instruction splits lines on whitespace, so a name
        containing spaces would otherwise be broken across several operand
        tokens (e.g. `"my var"` -> `"my`, `var"`) and returned incorrectly.
        """
        start = self.line.find('"')
        end = self.line.rfind('"')
        if start == -1 or end == start:
            # No complete quoted literal on the line: fall back to the
            # token-based extraction.
            return self._operands[1][1:-1]
        return self.line[start + 1 : end]

    def decoratedRegister(self) -> str:
        """Return the register that the name is attached to."""
        return self._operands[0]
# The only decoration we use is BuiltIn, to initialize the values;
# LinkageAttributes is accepted and ignored.
class OpDecorate(Instruction):
    def static_execution(self, lane):
        """Apply this decoration before execution begins.

        `OpDecorate %reg BuiltIn SubgroupLocalInvocationId` stores
        lane.tid() into %reg. `LinkageAttributes` decorations are ignored.

        Raises:
            RuntimeError: for any other decoration. The check uses an
                explicit raise rather than `assert` so it still runs under
                `python -O`. Otherwise an unsupported decoration would
                silently overwrite its register with the lane's tid.
        """
        decoration = self._operands[1]
        if decoration == "LinkageAttributes":
            return
        builtin = self._operands[2] if len(self._operands) > 2 else None
        if decoration != "BuiltIn" or builtin != "SubgroupLocalInvocationId":
            raise RuntimeError(f"Unsupported decoration {self}")
        lane.set_register(self._operands[0], lane.tid())
# Constants
class OpConstant(Instruction):
    """`%r = OpConstant %type <literal>`: seeds %r with the literal's value."""

    def static_execution(self, lane):
        """Store the constant's value into the result register.

        Integer literals become Python ints, as before. Literals that int()
        rejects, such as decimal floating-point constants ("0.5", "1e-05"),
        are parsed with float() instead. Previously such a constant raised
        ValueError during this pre-execution pass, even if the constant was
        never used. Literals that float() also rejects still raise
        ValueError.
        """
        literal = self._operands[1]
        try:
            value = int(literal)
        except ValueError:
            value = float(literal)
        lane.set_register(self._result, value)
class OpConstantTrue(OpConstant):
    """`%r = OpConstantTrue %bool`: seeds %r with Python True."""

    def static_execution(self, lane):
        value = True
        lane.set_register(self._result, value)
class OpConstantFalse(OpConstant):
    """`%r = OpConstantFalse %bool`: seeds %r with Python False."""

    def static_execution(self, lane):
        value = False
        lane.set_register(self._result, value)
class OpConstantComposite(OpConstant):
    """`%r = OpConstantComposite %type %c0 %c1 ...`.

    Seeds %r with a list of the constituent registers' current values, in
    operand order. The type operand is skipped.
    """

    def static_execution(self, lane):
        constituents = self._operands[1:]
        values = [lane.get_register(reg) for reg in constituents]
        lane.set_register(self._result, values)
# Control flow instructions
class OpFunctionCall(Instruction):
    """`%r = OpFunctionCall %type %fn <args...>`.

    Looks up %fn's entry with module.get_function_entry and hands it, along
    with the result register, to lane.do_call. Operands after %fn (the call
    arguments) are not passed on here.
    """

    def _impl(self, module, lane):
        """No computation; the call happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        callee = self._operands[1]
        lane.do_call(module.get_function_entry(callee), self._result)
class OpReturn(Instruction):
    """`OpReturn`: returns from the current function with no value."""

    def _impl(self, module, lane):
        """No computation; the return happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        no_value = None
        lane.do_return(no_value)
class OpReturnValue(Instruction):
    """`OpReturnValue %value`: returns the value held in %value."""

    def _impl(self, module, lane):
        """No computation; the return happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        returned = lane.get_register(self._operands[0])
        lane.do_return(returned)
class OpBranch(Instruction):
    """`OpBranch %label`: unconditional jump.

    The lane's IP is set to the entry of basic block %label, as resolved by
    module.get_bb_entry.
    """

    def _impl(self, module, lane):
        """No computation; the jump happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        # Replaces the default "next instruction" advance.
        lane.set_ip(module.get_bb_entry(self._operands[0]))
class OpBranchConditional(Instruction):
    """`OpBranchConditional %cond %true_label %false_label`.

    Jumps to the true label when %cond's value is truthy, otherwise to the
    false label. Any operands after the false label are ignored.
    """

    def _impl(self, module, lane):
        """No computation; the jump happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        taken = lane.get_register(self._operands[0])
        target = self._operands[1] if taken else self._operands[2]
        lane.set_ip(module.get_bb_entry(target))
class OpSwitch(Instruction):
    """`OpSwitch %selector %default <literal> <label> <literal> <label> ...`.

    Jumps to the label paired with the first literal equal to the selector's
    value, or to %default when none match. Literals are parsed with int().
    """

    def _impl(self, module, lane):
        """No computation; the jump happens in _advance_ip."""

    def _advance_ip(self, module, lane):
        selector = lane.get_register(self._operands[0])
        target = self._operands[1]
        for pos in range(2, len(self._operands), 2):
            literal, label = int(self._operands[pos]), self._operands[pos + 1]
            if selector == literal:
                target = label
                break
        lane.set_ip(module.get_bb_entry(target))
class OpUnreachable(Instruction):
    """`OpUnreachable`: reaching it at runtime is an error."""

    def _impl(self, module, lane):
        message = "This instruction should never be executed."
        raise RuntimeError(message)
# Convergence instructions
class MergeInstruction(Instruction):
    """Common base for structured merge headers (loop/selection merges).

    The first operand is the merge label. When there are at least three
    operands, the second one is treated as the continue label.
    """

    def merge_location(self):
        """Return the merge block label."""
        return self._operands[0]

    def continue_location(self):
        """Return the continue label, or None with fewer than three operands."""
        if len(self._operands) >= 3:
            return self._operands[1]
        return None

    def _impl(self, module, lane):
        # Convergence bookkeeping is delegated entirely to the lane.
        lane.handle_convergence_header(self)
class OpLoopMerge(MergeInstruction):
    """`OpLoopMerge %merge %continue <control>`; behavior comes from MergeInstruction."""
class OpSelectionMerge(MergeInstruction):
    """`OpSelectionMerge %merge <control>`; behavior comes from MergeInstruction."""
# Other instructions
class OpBitcast(Instruction):
    """`%r = OpBitcast %type %value`; only the `%int` result type is supported.

    NOTE(review): int() converts the value rather than reinterpreting its
    bits, so this is not a true bitcast for float sources or out-of-range
    unsigned values.
    """

    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC.
        if self._operands[0] != "%int":
            raise RuntimeError("Unsupported OpBitcast operand")
        source = lane.get_register(self._operands[1])
        lane.set_register(self._result, int(source))
class OpAccessChain(Instruction):
    """`%r = OpAccessChain %type %base %idx0 %idx1 ...`.

    Python's dynamic typing keeps this simple: starting from %base's value,
    each index register's value subscripts the current element. As long as
    the SPIR-V is legal this works (structs are stored as tuples). %r
    receives the addressed value itself, not a pointer to it.
    """

    def _impl(self, module, lane):
        current = lane.get_register(self._operands[1])
        for index_reg in self._operands[2:]:
            index = lane.get_register(index_reg)
            current = current[index]
        lane.set_register(self._result, current)
class OpCompositeConstruct(Instruction):
    """`%r = OpCompositeConstruct %type %c0 %c1 ...`: %r becomes a list of
    the constituents' values, in operand order."""

    def _impl(self, module, lane):
        parts = [lane.get_register(reg) for reg in self._operands[1:]]
        lane.set_register(self._result, parts)
class OpCompositeExtract(Instruction):
    """`%r = OpCompositeExtract %type %composite <lit0> <lit1> ...`.

    Indexes into %composite's value with each literal (parsed with int()),
    in order, and stores the final element in %r.
    """

    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for literal in self._operands[2:]:
            element = element[int(literal)]
        lane.set_register(self._result, element)
class OpStore(Instruction):
    """`OpStore %pointer %object`: copies %object's value into register %pointer."""

    def _impl(self, module, lane):
        pointer, obj = self._operands[0], self._operands[1]
        lane.set_register(pointer, lane.get_register(obj))
class OpLoad(Instruction):
    """`%r = OpLoad %type %pointer`: copies register %pointer's value into %r."""

    def _impl(self, module, lane):
        loaded = lane.get_register(self._operands[1])
        lane.set_register(self._result, loaded)
class OpIAdd(Instruction):
    """`%r = OpIAdd %type %a %b`: %r = a + b (Python arithmetic, no wrap-around)."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        total = lhs + rhs
        lane.set_register(self._result, total)
class OpISub(Instruction):
    """`%r = OpISub %type %a %b`: %r = a - b (Python arithmetic, no wrap-around)."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        difference = lhs - rhs
        lane.set_register(self._result, difference)
class OpIMul(Instruction):
    """`%r = OpIMul %type %a %b`: %r = a * b (Python arithmetic, no wrap-around)."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        product = lhs * rhs
        lane.set_register(self._result, product)
class OpLogicalNot(Instruction):
    """`%r = OpLogicalNot %bool %x`: %r = not x."""

    def _impl(self, module, lane):
        operand = lane.get_register(self._operands[1])
        negated = not operand
        lane.set_register(self._result, negated)
class _LessThan(Instruction):
    """Shared `%r = Op?LessThan %bool %a %b` implementation: %r = a < b."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        is_less = lhs < rhs
        lane.set_register(self._result, is_less)
class _GreaterThan(Instruction):
    """Shared `%r = Op?GreaterThan %bool %a %b` implementation: %r = a > b."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        is_greater = lhs > rhs
        lane.set_register(self._result, is_greater)
class OpSLessThan(_LessThan):
    """Signed `<`; the comparison is inherited from _LessThan."""
class OpULessThan(_LessThan):
    """Unsigned `<`; uses _LessThan's plain Python comparison.

    NOTE(review): this only matches unsigned semantics when both register
    values are non-negative.
    """
class OpSGreaterThan(_GreaterThan):
    """Signed `>`; the comparison is inherited from _GreaterThan."""
class OpUGreaterThan(_GreaterThan):
    """Unsigned `>`; uses _GreaterThan's plain Python comparison.

    NOTE(review): this only matches unsigned semantics when both register
    values are non-negative.
    """
class OpIEqual(Instruction):
    """`%r = OpIEqual %bool %a %b`: %r = (a == b)."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        same = lhs == rhs
        lane.set_register(self._result, same)
class OpINotEqual(Instruction):
    """`%r = OpINotEqual %bool %a %b`: %r = (a != b)."""

    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        differs = lhs != rhs
        lane.set_register(self._result, differs)
class OpPhi(Instruction):
    """`%r = OpPhi %type %v0 %label0 %v1 %label1 ...`.

    Copies the value paired with the block the lane arrived from (per
    lane.get_previous_bb_name()) into %r.

    Raises:
        RuntimeError: if no pair names the previous block.
    """

    def _impl(self, module, lane):
        predecessor = lane.get_previous_bb_name()
        pairs = self._operands[1:]
        for pos in range(0, len(pairs), 2):
            value_reg, label = pairs[pos], pairs[pos + 1]
            if label == predecessor:
                lane.set_register(self._result, lane.get_register(value_reg))
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")
class OpSelect(Instruction):
    """`%r = OpSelect %type %cond %a %b`: %r = a if cond is truthy else b."""

    def _impl(self, module, lane):
        cond = lane.get_register(self._operands[1])
        chosen = self._operands[2] if cond else self._operands[3]
        lane.set_register(self._result, lane.get_register(chosen))
# Wave intrinsics
class OpGroupNonUniformBroadcastFirst(Instruction):
    """`%r = OpGroupNonUniformBroadcastFirst %type %scope %value`.

    The scope register must hold 3 (Subgroup in the SPIR-V Scope enum);
    this is asserted. Only the first active lane acts: it passes its %value
    to lane.broadcast_register for %r. Other lanes do nothing here.
    """

    def _impl(self, module, lane):
        scope = lane.get_register(self._operands[1])
        assert scope == 3
        if not lane.is_first_active_lane():
            return
        lane.broadcast_register(self._result, lane.get_register(self._operands[2]))
class OpGroupNonUniformElect(Instruction):
    """`%r = OpGroupNonUniformElect %bool %scope`.

    %r is True only in the first active lane. The scope operand is not
    inspected.
    """

    def _impl(self, module, lane):
        elected = lane.is_first_active_lane()
        lane.set_register(self._result, elected)