#!/usr/bin/env python3
from __future__ import annotations
from dataclasses import dataclass
from instructions import *
from typing import Any, Iterable, Callable, Optional, Tuple, List, Dict
import argparse
import fileinput
import inspect
import re
import sys
# Accepted format for the --expects flag: a comma-separated list of
# non-negative decimal integers. Whitespace around the values is tolerated so
# that the "1, 2, 3" form shown in the flag's help text is accepted; main()
# strips every value before converting it to an int.
RE_EXPECTS = re.compile(r"^\s*([0-9]+\s*,\s*)*[0-9]+\s*$")
# Opcodes that are recognised but skipped because they are not required to
# simulate a module. Built once instead of on every parseInstruction() call.
_IGNORED_OPCODES = frozenset(
    {
        "OpCapability",
        "OpMemoryModel",
        "OpExecutionMode",
        "OpExtension",
        "OpSource",
        "OpTypeInt",
        "OpTypeStruct",
        "OpTypeFloat",
        "OpTypeBool",
        "OpTypeVoid",
        "OpTypeFunction",
        "OpTypePointer",
        "OpTypeArray",
    }
)


# Parse the SPIR-V instructions. Some instructions are ignored because
# not required to simulate this module.
# Instructions are to be implemented in instructions.py
#
# `i` is a raw instruction exposing opcode() and a `line` attribute.
# Returns None for an ignored opcode, otherwise an instance of the class named
# after the opcode in the `instructions` module, constructed from i.line.
# Raises RuntimeError when the `instructions` module has no attribute named
# after the opcode, or when that attribute is not a class.
def parseInstruction(i):
    opcode = i.opcode()
    if opcode in _IGNORED_OPCODES:
        return None

    try:
        Type = getattr(sys.modules["instructions"], opcode)
    except AttributeError as err:
        # Keep the lookup failure attached as the explicit cause.
        raise RuntimeError(f"Unsupported instruction {i}") from err

    if not inspect.isclass(Type):
        raise RuntimeError(
            f"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?"
        )
    return Type(i.line)
# Split a sequence of instructions into pieces delimited by instructions of
# type splitType. A delimiter always opens a new piece and is kept as the
# first element of that piece:
#  - two consecutive delimiters give two pieces: one holding only the first
#    delimiter, one holding the second delimiter and what follows it.
#  - if the very first instruction is a delimiter, the first piece starts
#    with it (no empty piece is emitted in front of it).
# An empty input yields a single empty piece ([[]]).
def splitInstructions(
    splitType: type, instructions: Iterable[Instruction]
) -> List[List[Instruction]]:
    pieces: List[List[Instruction]] = []
    current: List[Instruction] = []
    for item in instructions:
        # Only close the running piece when it already holds something.
        if current and isinstance(item, splitType):
            pieces.append(current)
            current = []
        current.append(item)
    pieces.append(current)
    return pieces
# A basic block of the simulated program.
# It is built from a list of instructions whose first element must be an
# OpLabel; the label provides the block name and is not kept in the body.
class BasicBlock:
    def __init__(self, instructions) -> None:
        label = instructions[0]
        assert isinstance(label, OpLabel)
        # Block name: the register defined by the leading OpLabel.
        self._name = label.output_register()
        # Body of the block: everything that follows the OpLabel.
        self._instructions = instructions[1:]

    # Returns the name (label register) of this basic block.
    def name(self):
        return self._name

    # Returns the instruction stored at `index` in the body of this block.
    def __getitem__(self, index: int) -> Instruction:
        return self._instructions[index]

    # Returns the number of instructions in the body, i.e. without the
    # leading OpLabel.
    def __len__(self):
        return len(self._instructions)

    # Prints the block name followed by one line per instruction.
    def dump(self):
        lines = [f" {self._name}:"]
        lines.extend(f" {entry}" for entry in self._instructions)
        print("\n".join(lines))
# Defines a Function in the simulator.
# A function is built from the instructions spanning an OpFunction up to and
# including the matching OpFunctionEnd. OpVariable instructions are collected
# apart; the remaining instructions are grouped into basic blocks, each one
# starting at an OpLabel.
class Function:
    def __init__(self, instructions) -> None:
        assert isinstance(instructions[0], OpFunction)
        assert isinstance(instructions[-1], OpFunctionEnd)

        # The name of the function (name of the register returned by OpFunction).
        self._name: str = instructions[0].output_register()
        # The variables local to this function.
        self._variables: List[OpVariable] = [
            x for x in instructions if isinstance(x, OpVariable)
        ]

        body = [x for x in instructions[1:-1] if not isinstance(x, OpVariable)]
        # The list of basic blocks that belongs to this function.
        # splitInstructions() returns one empty piece for an empty body (a
        # function without any block): skip it instead of building a
        # BasicBlock from nothing, which would fail on instructions[0].
        self._basic_blocks: List[BasicBlock] = [
            BasicBlock(piece) for piece in splitInstructions(OpLabel, body) if piece
        ]

    # Returns the name of this function.
    def name(self) -> str:
        return self._name

    # Returns the basic block at index in this function.
    def __getitem__(self, index: int) -> BasicBlock:
        return self._basic_blocks[index]

    # Returns the index of the basic block with the given name if found,
    # -1 otherwise.
    def get_bb_index(self, name) -> int:
        for index, block in enumerate(self._basic_blocks):
            if block.name() == name:
                return index
        return -1

    # Prints the variables, then every basic block of this function.
    def dump(self):
        print(" Variables:")
        for var in self._variables:
            print(f" {var}")
        print(" Blocks:")
        for bb in self._basic_blocks:
            bb.dump()
# An instruction pointer of the simulator: a function, the index of a basic
# block in that function, and the index of an instruction in that block.
# Equality is the dataclass field-wise comparison; the hash is derived from
# the function name and both indices.
@dataclass
class InstructionPointer:
    # The function this IP points into.
    function: Function
    # Index, in `function`, of the basic block this IP points into.
    basic_block: int
    # Index, in that basic block, of the instruction this IP points to.
    instruction_index: int

    def __str__(self):
        block = self.bb()
        current = block[self.instruction_index]
        return f"{block.name()}:{self.instruction_index} in {self.function.name()} | {current}"

    def __hash__(self):
        key = (self.function.name(), self.basic_block, self.instruction_index)
        return hash(key)

    # Returns the basic block this IP points into.
    def bb(self) -> BasicBlock:
        return self.function[self.basic_block]

    # Returns the instruction this IP points to.
    def instruction(self):
        return self.bb()[self.instruction_index]

    # Returns a new IP moved forward by `value` instructions inside the same
    # basic block. The target must still be inside the block: stepping past
    # the last instruction trips the assertion.
    def __add__(self, value: int):
        target = self.instruction_index + value
        assert target < len(self.bb())
        return InstructionPointer(self.function, self.basic_block, target)
# A lane (thread) of a wave in this simulator. A lane owns its registers,
# its instruction pointer and its call stack, and forwards wave-wide
# operations to the wave it belongs to.
class Lane:
    # Registers of this lane, indexed by register name.
    _registers: Dict[str, Any]
    # Current IP of this lane; None until set_ip() is first called.
    _ip: Optional[InstructionPointer]
    # True until the lane returns from its outermost call.
    _running: bool
    # The wave this lane belongs to.
    _wave: Wave
    # Call stack of this lane, one tuple per pending call:
    #  - the IP to resume at once the callee returns,
    #  - the callback storing the returned value into the right register.
    _callstack: List[Tuple[InstructionPointer, Callable[[Any], None]]]
    _previous_bb: Optional[BasicBlock]
    _current_bb: Optional[BasicBlock]

    def __init__(self, wave: Wave, tid: int) -> None:
        self._wave = wave
        # The index of this lane in the wave.
        self._tid = tid
        self._registers = dict()
        self._callstack = []
        self._running = True
        self._ip = None
        # The basic block this lane is executing, and the one it was
        # executing before entering it.
        self._current_bb = None
        self._previous_bb = None

    # Returns the lane/thread ID of this lane in its wave.
    def tid(self) -> int:
        return self._tid

    # Returns True if this lane has the index the wave reports as its first
    # active lane.
    def is_first_active_lane(self) -> bool:
        first = self._wave.get_first_active_lane_index()
        return first == self._tid

    # Asks the wave to store `value` in `register` for all its active lanes.
    def broadcast_register(self, register: str, value: Any) -> None:
        self._wave.broadcast_register(register, value)

    # Returns the IP this lane is currently at. The IP must have been set.
    def ip(self) -> InstructionPointer:
        current = self._ip
        assert current is not None
        return current

    # Returns True while this lane is running, i.e. not dead.
    # An inactive lane is still running.
    def running(self) -> bool:
        return self._running

    # Stores `value` in the register `name` of this lane.
    def set_register(self, name: str, value: Any) -> None:
        self._registers[name] = value

    # Returns the value held by register `name` in this lane.
    # With allow_undef set, an unknown register yields None instead of
    # failing with a KeyError.
    def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
        if allow_undef:
            return self._registers.get(name)
        return self._registers[name]

    # Moves this lane to `ip`. When this enters another basic block, the
    # block being left is remembered as the previous one.
    def set_ip(self, ip: InstructionPointer) -> None:
        target_bb = ip.bb()
        if target_bb != self._current_bb:
            self._previous_bb, self._current_bb = self._current_bb, target_bb
        self._ip = ip

    # Returns the name of the basic block this lane was in before the
    # current one.
    def get_previous_bb_name(self):
        return self._previous_bb.name()

    # Forwards a merge instruction met by this lane to the wave.
    def handle_convergence_header(self, instruction):
        self._wave.handle_convergence_header(self, instruction)

    # Calls the code at `ip`. The value later passed to do_return() is
    # stored in `output_register`, and execution resumes right after the
    # current instruction (or nowhere if this lane had no IP yet).
    def do_call(self, ip, output_register):
        if self._ip is None:
            return_ip = None
        else:
            return_ip = self._ip + 1

        def store_result(value):
            self.set_register(output_register, value)

        self._callstack.append((return_ip, store_result))
        self.set_ip(ip)

    # Returns `value` from the innermost pending call. Returning from the
    # outermost call stops the lane.
    def do_return(self, value):
        return_ip, store_result = self._callstack.pop()
        store_result(value)
        if self._callstack:
            self.set_ip(return_ip)
        else:
            self._running = False
# Represents the SPIR-V module in the simulator.
# A module is built from the parsed instruction list and holds: the
# instructions preceding the first OpFunction (the prolog), the functions,
# the instructions requiring a static execution, and the OpName tables.
class Module:
    _functions: Dict[str, Function]
    _prolog: List[Instruction]
    _globals: List[Instruction]
    _name2reg: Dict[str, str]
    _reg2name: Dict[str, str]

    def __init__(self, instructions) -> None:
        # The instructions are traversed several times: materialize them so
        # that any iterable is accepted.
        instructions = list(instructions)
        chunks = splitInstructions(OpFunction, instructions)
        # splitInstructions() starts the first piece with the delimiter when
        # the very first instruction is one: in that case there is no prolog
        # and every piece is a function.
        if chunks[0] and isinstance(chunks[0][0], OpFunction):
            prolog, function_chunks = [], chunks
        else:
            prolog, function_chunks = chunks[0], chunks[1:]

        # The instructions located outside of all functions.
        self._prolog = prolog
        # The functions in this module, indexed by the OpFunction register.
        self._functions = {}
        # Every OpVariable and OpConstant-derived instruction found in the
        # whole instruction list; initialize() runs their static execution.
        self._globals = [
            x for x in instructions if isinstance(x, (OpVariable, OpConstant))
        ]

        # Helper dictionaries to get real names of registers, or registers by names.
        self._name2reg = {}
        self._reg2name = {}
        for instruction in instructions:
            if isinstance(instruction, OpName):
                name = instruction.name()
                reg = instruction.decoratedRegister()
                self._name2reg[name] = reg
                self._reg2name[reg] = name

        for chunk in function_chunks:
            function = Function(chunk)
            assert function.name() not in self._functions
            self._functions[function.name()] = function

    # Returns the register matching "name" if any, None otherwise.
    # This assumes names are unique.
    def getRegisterFromName(self, name):
        return self._name2reg.get(name)

    # Returns the name given to "register" if any, None otherwise.
    def getNameFromRegister(self, register):
        return self._reg2name.get(register)

    # Initialize the module before wave execution begins: runs the static
    # execution of the collected variables/constants, then of the OpDecorate
    # instructions of the prolog, on the given lane.
    def initialize(self, lane):
        for instruction in self._globals:
            instruction.static_execution(lane)

        # Initialize builtins
        for instruction in self._prolog:
            if isinstance(instruction, OpDecorate):
                instruction.static_execution(lane)

    # Executes the instruction `ip` points to on the given lane.
    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
        ip.instruction().runtime_execution(self, lane)

    # Returns the first valid IP for the function defined by the given register.
    # Raises RuntimeError if no function is defined by this register.
    def get_function_entry(self, register: str) -> InstructionPointer:
        if register not in self._functions:
            raise RuntimeError(f"Function defining {register} not found.")
        return InstructionPointer(self._functions[register], 0, 0)

    # Returns the first valid IP for the basic block defined by register.
    # Raises RuntimeError if no function has a basic block with this name.
    def get_bb_entry(self, register: str) -> InstructionPointer:
        for function in self._functions.values():
            index = function.get_bb_index(register)
            if index != -1:
                return InstructionPointer(function, index, 0)
        raise RuntimeError(f"Instruction defining {register} not found.")

    # Returns one entry per function of this module: the name given by its
    # OpName, or None when the function has no OpName.
    # NOTE(review): a register-name fallback was once documented here but is
    # not implemented; functions are looked up through their OpName only.
    def get_function_names(self):
        return [self.getNameFromRegister(reg) for reg in self._functions]

    # Returns the output registers of the instructions collected in _globals.
    def variables(self) -> Iterable:
        return [x.output_register() for x in self._globals]

    # Prints the collected globals, then either all functions (no argument)
    # or only the function named `function_name`.
    def dump(self, function_name: Optional[str] = None):
        print("Module:")
        print(" globals:")
        for instruction in self._globals:
            print(f" {instruction}")

        if function_name is None:
            print(" functions:")
            for register, function in self._functions.items():
                name = self.getNameFromRegister(register)
                print(f" Function {register} ({name})")
                function.dump()
            return

        register = self.getRegisterFromName(function_name)
        print(f" function {register} ({function_name}):")
        # The name can be unknown, or can designate a register which is not a
        # function (any OpName lands in the table): report both cases instead
        # of failing on the dictionary lookup.
        function = self._functions.get(register)
        if function is None:
            print(" error: cannot find function.")
        else:
            function.dump()
# Defines a convergence requirement for the simulation:
# the set of lanes impacted by a merge, the merge target and, possibly, the
# associated continue target.
@dataclass
class ConvergenceRequirement:
    # IP of the first instruction of the merge block.
    mergeTarget: InstructionPointer
    # IP of the first instruction of the continue block, or None when the
    # merge instruction provides no continue location.
    continueTarget: Optional[InstructionPointer]
    # Thread IDs (Lane.tid()) of the lanes that reached the merge instruction.
    impactedLanes: set[int]
# The scheduled work of a wave: maps an instruction pointer to the list of
# lanes waiting to execute the instruction it points to.
Task = Dict[InstructionPointer, List[Lane]]
# Defines a Lane group/Wave in the simulator.
# The wave owns the lanes, schedules the instructions they have to execute
# and enforces the convergence requirements registered by merge instructions.
class Wave:
    # The module this wave will execute.
    _module: Module
    # The lanes this wave will be composed of.
    _lanes: List[Lane]
    # The instructions scheduled for execution.
    _tasks: Task
    # The actual requirements to comply with when executing instructions.
    # E.g: the set of lanes required to merge before executing the merge block.
    _convergence_requirements: List[ConvergenceRequirement]
    # The indices of the active lanes for the current executing instruction.
    _active_lane_indices: set[int]

    def __init__(self, module, wave_size: int) -> None:
        assert wave_size > 0
        self._module = module
        self._lanes = [Lane(self, i) for i in range(wave_size)]
        self._tasks = {}
        self._convergence_requirements = []
        # The indices of the active lanes for the current executing instruction.
        self._active_lane_indices = set()

    # Returns True if the given IP can be executed for the given list of lanes.
    # Returns False when a convergence requirement targeting this IP still
    # waits for a running lane that has not reached the expected target.
    def _is_task_candidate(self, ip: InstructionPointer, lanes: List[Lane]):
        # Lanes which must not block this task: dead lanes, plus lanes already
        # waiting at the merge/continue target of another requirement.
        merged_lanes: set[int] = set()
        for lane in self._lanes:
            if not lane.running():
                merged_lanes.add(lane.tid())

        for requirement in self._convergence_requirements:
            # This task is not executing a merge or continue target.
            # Adding all lanes at those points into the ignore list.
            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
                for tid in requirement.impactedLanes:
                    if self._lanes[tid].ip() == requirement.mergeTarget:
                        merged_lanes.add(tid)
                    if self._lanes[tid].ip() == requirement.continueTarget:
                        merged_lanes.add(tid)
                continue

            # This task is executing the current requirement continue/merge
            # target.
            for tid in requirement.impactedLanes:
                lane = self._lanes[tid]
                if not lane.running():
                    continue

                if lane.tid() in merged_lanes:
                    continue

                if ip == requirement.mergeTarget:
                    # Merge block: every impacted lane must be at the merge.
                    if lane.ip() != requirement.mergeTarget:
                        return False
                else:
                    # Continue block: impacted lanes must be either at the
                    # merge or at the continue target.
                    if (
                        lane.ip() != requirement.mergeTarget
                        and lane.ip() != requirement.continueTarget
                    ):
                        return False
        return True

    # Returns the next task we can schedule, and removes it from the
    # scheduled tasks. This must always return a task: RuntimeError is raised
    # when no scheduled task is runnable.
    # Calling this when all lanes are dead is invalid.
    def _get_next_runnable_task(self) -> Tuple[InstructionPointer, List[Lane]]:
        for ip, lanes in self._tasks.items():
            if not lanes:
                continue
            if self._is_task_candidate(ip, lanes):
                # Safe while iterating: the loop is left immediately.
                del self._tasks[ip]
                return (ip, lanes)
        raise RuntimeError("No task to execute. Deadlock?")

    # Handle an encountered merge instruction for the given lane: the lane is
    # added to the requirement already targeting this merge block, or a new
    # requirement is created.
    def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
        for requirement in self._convergence_requirements:
            if requirement.mergeTarget == mergeTarget:
                requirement.impactedLanes.add(lane.tid())
                return

        continueTarget = None
        if instruction.continue_location():
            continueTarget = self._module.get_bb_entry(instruction.continue_location())
        requirement = ConvergenceRequirement(
            mergeTarget, continueTarget, set([lane.tid()])
        )
        self._convergence_requirements.append(requirement)

    # Returns true if some instructions are scheduled for execution.
    def _has_tasks(self) -> bool:
        return bool(self._tasks)

    # Returns the index of the first active lane right now.
    def get_first_active_lane_index(self) -> int:
        return min(self._active_lane_indices)

    # Broadcast the given value to all active lane registers.
    def broadcast_register(self, register: str, value: Any) -> None:
        for tid in self._active_lane_indices:
            self._lanes[tid].set_register(register, value)

    # Returns the entrypoint of the function associated with 'name'.
    # Raises RuntimeError if no register carries this name, or if the
    # register does not define a function.
    def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
        register = self._module.getRegisterFromName(name)
        if register is None:
            # Raised explicitly: an assert would vanish under `python -O`.
            raise RuntimeError(f"Function '{name}' not found.")
        return self._module.get_function_entry(register)

    # Run the wave on the function 'function_name' until all lanes are dead.
    # If verbose is True, execution trace is printed.
    # Returns the value returned by the function for each lane.
    def run(self, function_name: str, verbose: bool = False) -> List[Any]:
        for t in self._lanes:
            self._module.initialize(t)

        entry_ip = self._get_function_entry_from_name(function_name)
        for t in self._lanes:
            t.do_call(entry_ip, "__shader_output__")

        # All lanes start at the entry point. Schedule a copy of the lane
        # list so the scheduler never shares a list with self._lanes.
        self._tasks[self._lanes[0].ip()] = list(self._lanes)
        while self._has_tasks():
            ip, lanes = self._get_next_runnable_task()
            self._active_lane_indices = {x.tid() for x in lanes}
            if verbose:
                print(
                    f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
                )

            for lane in lanes:
                self._module.execute_one_instruction(lane, ip)
                if not lane.running():
                    continue
                # Reschedule the lane at the IP the instruction left it at.
                self._tasks.setdefault(lane.ip(), []).append(lane)

            if verbose and ip.instruction().has_output_register():
                register = ip.instruction().output_register()
                values = [x.get_register(register, allow_undef=True) for x in lanes]
                print(f" {register:3} = {values}")

        return [lane.get_register("__shader_output__") for lane in self._lanes]

    # Prints the value held by `register` in each lane of the wave.
    def dump_register(self, register: str) -> None:
        for lane in self._lanes:
            print(
                f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
            )
# Command-line interface of the simulator. The arguments are parsed when
# these statements run, and kept in the module-level `args` read by main().
parser = argparse.ArgumentParser(
    description="simulator",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-i",
    "--input",
    default="-",
    required=False,
    help="Text SPIR-V to read from",
)
parser.add_argument("-f", "--function", help="Function to execute")
parser.add_argument(
    "-w",
    "--wave",
    default=32,
    required=False,
    help="Wave size",
)
parser.add_argument(
    "-e",
    "--expects",
    help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="verbose")
args = parser.parse_args()
# Loads the textual SPIR-V module stored in `filename` ("-" reads the
# standard input) and returns the list of parsed instructions.
# Blank lines and lines starting with ';' (comments) are skipped, and so are
# the instructions parseInstruction() ignores.
# Returns an empty list when filename is None or when the file cannot be
# read; main() reports an empty result as an invalid input.
def load_instructions(filename: str):
    if filename is None:
        return []

    if filename.strip() != "-":
        try:
            with open(filename, "r") as f:
                lines = f.read().split("\n")
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: invalid path or
            # undecodable content (UnicodeDecodeError).
            return []
    else:
        lines = sys.stdin.readlines()

    instructions = []
    for line in lines:
        # Remove leading/trailing whitespaces.
        line = line.strip()
        # Skip blank lines and comments.
        if not line or line.startswith(";"):
            continue
        out = parseInstruction(Instruction(line))
        # Identity check: does not rely on the instruction's own __eq__.
        if out is not None:
            instructions.append(out)
    return instructions
# Entry point of the simulator: validates the command line held in `args`,
# loads the module, runs the requested function on a wave and compares the
# value returned by each lane with the expected ones.
# Always terminates through sys.exit(): 0 when the results match, 1 on any
# invalid argument, invalid input or mismatch. Diagnostics go to stderr.
def main():
    if args.expects is None or not RE_EXPECTS.match(args.expects):
        print("Invalid format for --expects/-e flag.", file=sys.stderr)
        sys.exit(1)
    if args.function is None:
        print("Invalid format for --function/-f flag.", file=sys.stderr)
        sys.exit(1)
    try:
        wave_size = int(args.wave)
    except ValueError:
        print("Invalid format for --wave/-w flag.", file=sys.stderr)
        sys.exit(1)

    expected_results = [int(x.strip()) for x in args.expects.split(",")]
    # One expected value is required per lane.
    if len(expected_results) != wave_size:
        print("Wave size != expected result array size", file=sys.stderr)
        sys.exit(1)

    instructions = load_instructions(args.input)
    if len(instructions) == 0:
        print("Invalid input. Expected a text SPIR-V module.", file=sys.stderr)
        sys.exit(1)

    module = Module(instructions)
    if args.verbose:
        module.dump()
        module.dump(args.function)

    function_names = module.get_function_names()
    if args.function not in function_names:
        print(
            f"'{args.function}' function not found. Known functions are:",
            file=sys.stderr,
        )
        for name in function_names:
            print(f" - {name}", file=sys.stderr)
        sys.exit(1)

    wave = Wave(module, wave_size)
    results = wave.run(args.function, verbose=args.verbose)

    if expected_results != results:
        print("Expected != Observed", file=sys.stderr)
        print(f"{expected_results} != {results}", file=sys.stderr)
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    main()