#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Script for updating SPIR-V dialect by scraping information from SPIR-V
# HTML and JSON specs from the Internet.
#
# For example, to define the enum attribute for SPIR-V memory model:
#
# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
# --new-enum MemoryModel
#
# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.
import itertools
import math
import re
import requests
import textwrap
import yaml
# Default HTML spec scraped for per-instruction documentation.
SPIRV_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
)
# Core JSON grammar; always fetched because it provides the operand kinds.
SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json"
# OpenCL extended instruction set counterparts, selected by --gen-cl-ops.
SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html"
SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json"
# Text separating consecutive op definitions in the ops .td file; the file is
# split and re-joined on this exact string.
AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
# Marker text delimiting the auto-generated sections of the base .td file. The
# updaters split the file on a marker and expect exactly three pieces, i.e.
# the marker must occur twice (begin and end of the section).
AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
AUTOGEN_OPCODE_SECTION_MARKER = (
    "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
)
def get_spirv_doc_from_html_spec(url, settings):
    """Extracts instruction documentation from SPIR-V HTML spec.

    Arguments:
      - url: URL of the HTML spec to scrape; None selects SPIRV_HTML_SPEC_URL.
      - settings: object whose `gen_cl_ops` attribute selects the page layout
        of the OpenCL extended instruction set instead of the core spec.

    Returns:
      - A dict mapping from the instruction name (the `id` of the anchor in the
        instruction's table) to its documentation text.

    Raises:
      - requests.HTTPError if the server answers with an error status.
    """
    if url is None:
        url = SPIRV_HTML_SPEC_URL

    response = requests.get(url)
    # Fail here on an HTTP error rather than scraping an error page, which
    # would only surface later as an opaque attribute error on None.
    response.raise_for_status()
    spec = response.content

    from bs4 import BeautifulSoup

    spirv = BeautifulSoup(spec, "html.parser")

    # The two specs only differ in where the instruction tables live and in
    # whether the documentation cell wraps its content in a <p> element.
    if settings.gen_cl_ops:
        anchor_tag, anchor_id, section_class = "h2", "_binary_form", "sect2"
    else:
        anchor_tag, anchor_id, section_class = "h3", "_instructions_3", "sect3"

    section_anchor = spirv.find(anchor_tag, {"id": anchor_id})

    doc = {}
    for section in section_anchor.parent.find_all("div", {"class": section_class}):
        for table in section.find_all("table"):
            inst_html = table.tbody.tr.td
            if not settings.gen_cl_ops:
                inst_html = inst_html.p
            opname = inst_html.a["id"]
            # Ignore the first line, which is just the opname.
            doc[opname] = inst_html.text.split("\n", 1)[1].strip()

    return doc
def get_spirv_grammar_from_json_spec(url):
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    The core grammar at SPIRV_JSON_SPEC_URL is always fetched because the
    operand kinds are taken from it.

    Arguments:
      - url: URL of an extended instruction set grammar whose instructions
        should be returned instead of the core ones; None for the core spec.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar

    Raises:
      - requests.HTTPError if either request answers with an error status.
    """
    import json

    response = requests.get(SPIRV_JSON_SPEC_URL)
    # Surface HTTP failures directly instead of as a JSON decoding error.
    response.raise_for_status()
    spirv = json.loads(response.content)

    if url is None:
        return spirv["operand_kinds"], spirv["instructions"]

    response_ext = requests.get(url)
    response_ext.raise_for_status()
    spirv_ext = json.loads(response_ext.content)

    return spirv["operand_kinds"], spirv_ext["instructions"]
def split_list_into_sublists(items):
    """Split the list of items into multiple sublists.

    This is to make sure the string composed from each sublist won't exceed
    80 characters. Each item is budgeted its own length plus 2 (for the
    ", " separator). An item that is too long on its own still gets a sublist
    of its own; no empty sublist is ever produced.

    Arguments:
      - items: a list of strings

    Returns:
      - A list of non-empty sublists preserving the original item order.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        chunk_len += len(item) + 2
        # Only flush a chunk that actually holds something; otherwise an
        # oversized first item would emit an empty sublist (a blank row).
        if chunk_len > 80 and chunk:
            chunks.append(chunk)
            chunk = []
            chunk_len = len(item) + 2
        chunk.append(item)

    if chunk:
        chunks.append(chunk)

    return chunks
def uniquify_enum_cases(lst):
    """Prunes duplicate enum cases from the list.

    Cases sharing the same value are considered duplicates of each other. For
    each value the lexicographically smallest symbol is kept, which is
    typically the symbol without an extension suffix. The one exception is
    "HlslSemanticGOOGLE", whose alias is preferred instead.

    Arguments:
      - lst: List whose elements are to be uniqued. Each element is a
        (symbol, value) pair. The list itself is left untouched.

    Returns:
      - A list with all duplicates removed. The elements are sorted according
        to value.
      - A map from each dropped (duplicated) symbol to the symbol kept for the
        same value.
    """
    uniqued_cases = []
    duplicated_cases = {}

    # Sort a copy according to the value so the caller's list is not mutated.
    cases = sorted(lst, key=lambda x: x[1])

    # Then group them according to the value
    for _, group in itertools.groupby(cases, key=lambda x: x[1]):
        # For each value, sort according to the enumerant symbol.
        sorted_group = sorted(group, key=lambda x: x[0])
        # Keep the "smallest" case, which is typically the symbol without
        # extension suffix. But we have special cases that we want to fix.
        case = sorted_group[0]
        if case[0] == "HlslSemanticGOOGLE":
            assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
            case = sorted_group[1]
        # Map every dropped symbol to the kept one. Doing this after the final
        # choice avoids leaving a stale entry that maps the kept symbol to a
        # dropped one.
        for candidate in sorted_group:
            if candidate[0] != case[0]:
                duplicated_cases[candidate[0]] = case[0]
        uniqued_cases.append(case)

    return uniqued_cases, duplicated_cases
def toposort(dag, sort_fn):
    """Topologically sorts the given dag.

    Nodes are emitted batch by batch: each batch consists of all nodes that
    have no incoming nodes left once the previous batches are removed.

    Arguments:
      - dag: a dict mapping from a node to the set of its incoming nodes.
      - sort_fn: a key function for ordering nodes in the same batch.

    Returns:
      A list containing topologically sorted nodes.
    """
    remaining = dag
    ordered = []

    while True:
        # Nodes whose incoming set is empty are ready to be emitted.
        ready = set(node for node, incoming in remaining.items() if not incoming)
        if not ready:
            break
        ordered += sorted(ready, key=sort_fn)
        # Drop the emitted nodes, both as nodes and as incoming edges.
        remaining = {
            node: (incoming - ready)
            for node, incoming in remaining.items()
            if node not in ready
        }

    # Anything left over can never become ready.
    assert not remaining, "found cyclic dependency"

    return ordered
def toposort_capabilities(all_cases, capability_mapping):
    """Returns topologically sorted capability (symbol, value) pairs.

    Arguments:
      - all_cases: all capability cases (containing symbol, value, and implied
        capabilities).
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      A list containing topologically sorted capability (symbol, value) pairs.
      Duplicated symbols (the keys of capability_mapping) are left out.
    """
    value_of = {}
    incoming = {}

    for entry in all_cases:
        symbol = entry["enumerant"]
        # Values are recorded for every symbol, duplicated or not.
        value_of[symbol] = entry["value"]
        if symbol not in capability_mapping:
            # Capabilities implied by this one, with duplicates canonicalized.
            implied = entry.get("capabilities", [])
            incoming[symbol] = set(capability_mapping.get(dep, dep) for dep in implied)

    ordered = toposort(incoming, lambda symbol: value_of[symbol])

    # Attach the capability's value as the second component of the pair.
    return [(symbol, value_of[symbol]) for symbol in ordered]
def get_capability_mapping(operand_kinds):
    """Returns the capability mapping from duplicated cases to canonicalized ones.

    Arguments:
      - operand_kinds: all operand kinds' grammar spec

    Returns:
      - A map mapping from duplicated capability symbols to the canonicalized
        symbol chosen for SPIRVBase.td.
    """
    # Find the operand kind for capability; the last match wins.
    matches = [kind for kind in operand_kinds if kind["kind"] == "Capability"]
    cap_kind = matches[-1] if matches else {}

    symbol_value_pairs = []
    for case in cap_kind["enumerants"]:
        symbol_value_pairs.append((case["enumerant"], case["value"]))

    # Only the duplicate-to-canonical map is of interest here.
    return uniquify_enum_cases(symbol_value_pairs)[1]
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
    """Returns the availability specification string for the given enum case.

    Arguments:
      - enum_case: the enum case to generate availability spec for. It may
        contain 'version', 'lastVersion', 'extensions', or 'capabilities'.
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
      - for_op: bool value indicating whether this is the availability spec for
        an op itself.
      - for_cap: bool value indicating whether this is the availability spec
        for capabilities themselves.

    Returns:
      - A `let availability = [...];` string (for ops) or a
        `list<Availability> availability = [...];` string (otherwise) if there
        is an availability spec, preceded by an `implies` list for
        capabilities; an empty string if there is nothing to specify.
    """
    assert not (for_op and for_cap), "cannot set both for_op and for_cap"

    DEFAULT_MIN_VERSION = "MinVersion<SPIRV_V_1_0>"
    DEFAULT_MAX_VERSION = "MaxVersion<SPIRV_V_1_6>"
    DEFAULT_CAP = "Capability<[]>"
    DEFAULT_EXT = "Extension<[]>"

    # Minimal version; the grammar spells "no version" as the string "None".
    raw_min = enum_case.get("version", "")
    if raw_min and raw_min != "None":
        min_version = "MinVersion<SPIRV_V_{}>".format(raw_min.replace(".", "_"))
    else:
        min_version = ""

    # Maximal version.
    raw_max = enum_case.get("lastVersion", "")
    if raw_max:
        max_version = "MaxVersion<SPIRV_V_{}>".format(raw_max.replace(".", "_"))
    else:
        max_version = ""

    # Extensions.
    ext_names = enum_case.get("extensions", [])
    exts = ""
    if ext_names:
        exts = "Extension<[{}]>".format(", ".join(sorted(set(ext_names))))
        # We need to strip the minimal version requirement if this symbol is
        # available via an extension, which means *any* SPIR-V version can
        # support it as long as the extension is provided. The grammar's
        # 'version' field under such case should be interpreted as this symbol
        # is introduced as a core symbol since the given version, rather than
        # a minimal version requirement.
        min_version = ""

    # Capabilities.
    cap_names = enum_case.get("capabilities", [])
    caps = ""
    implies = ""
    if cap_names:
        canonical = set(capability_mapping.get(c, c) for c in cap_names)
        cap_list = ", ".join("SPIRV_C_{}".format(c) for c in sorted(canonical))
        if for_cap:
            # If this is generating the availability for capabilities, we need
            # to put the capability "requirements" in implies field because
            # now the "capabilities" field in the source grammar means so.
            implies = "list<I32EnumAttrCase> implies = [{}];".format(cap_list)
        else:
            caps = "Capability<[{}]>".format(cap_list)

    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op:
        min_version = min_version or DEFAULT_MIN_VERSION
        max_version = max_version or DEFAULT_MAX_VERSION
        exts = exts or DEFAULT_EXT
        caps = caps or DEFAULT_CAP

    requirements = [min_version, max_version, exts, caps]
    # For ops, because we have a default in SPIRV_Op class, omit if the spec
    # is the same.
    matches_op_default = for_op and requirements == [
        DEFAULT_MIN_VERSION,
        DEFAULT_MAX_VERSION,
        DEFAULT_EXT,
        DEFAULT_CAP,
    ]

    avail = ""
    # Compose availability spec if any of the requirements is not empty.
    if any(requirements) and not matches_op_default:
        joined_spec = ",\n ".join(req for req in requirements if req)
        avail = "{} availability = [\n {}\n ];".format(
            "let" if for_op else "list<Availability>", joined_spec
        )

    separator = "\n " if implies and avail else ""
    return "{}{}{}".format(implies, separator, avail)
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
    """Generates the TableGen EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: the operand kind's grammar from the JSON spec
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      - The operand kind's name
      - A string containing the TableGen EnumAttr definition

      Both are empty strings if the operand kind has no enumerants.
    """
    if "enumerants" not in operand_kind:
        return "", ""

    kind_name = operand_kind["kind"]
    enumerants = operand_kind["enumerants"]
    is_bit_enum = operand_kind["category"] == "BitEnum"
    is_capability = kind_name == "Capability"
    # The acronym is made of the upper-case letters of the kind name.
    kind_acronym = "".join(ch for ch in kind_name if "A" <= ch <= "Z")

    # Returns a symbol for the given case of this kind. Dim is handled
    # specially to avoid having numbers as the start of symbols, which does
    # not play well with C++ and the MLIR parser.
    def get_case_symbol(case_name):
        if kind_name == "Dim" and case_name in ("1D", "2D", "3D"):
            return "Dim{}".format(case_name)
        return case_name

    name_to_case_dict = {case["enumerant"]: case for case in enumerants}

    if is_capability:
        # Special treatment for capability cases: we need to sort them
        # topologically because a capability can refer to another via the
        # 'implies' field.
        kind_cases = toposort_capabilities(enumerants, capability_mapping)
    else:
        symbol_value_pairs = [(case["enumerant"], case["value"]) for case in enumerants]
        kind_cases, _ = uniquify_enum_cases(symbol_value_pairs)

    max_len = max(len(symbol) for (symbol, _) in kind_cases)

    # Generate the definition for each enum case
    case_category = "I32Bit" if is_bit_enum else "I32"
    case_fmt = (
        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
    )
    case_lines = []
    for name, raw_value in kind_cases:
        # Bit enum values are hexadecimal strings in the grammar.
        number = int(raw_value, base=16) if is_bit_enum else int(raw_value)
        avail = get_availability_spec(
            name_to_case_dict[name], capability_mapping, False, is_capability
        )
        if not is_bit_enum:
            suffix, value_part = "", ", {}".format(number)
        elif number == 0:
            suffix, value_part = "None", ""
        else:
            # Bit cases are defined by their bit position rather than mask.
            suffix, value_part = "Bit", ", {}".format(int(math.log2(number)))

        case_lines.append(
            case_fmt.format(
                category=case_category,
                suffix=suffix,
                acronym=kind_acronym,
                case_name=name,
                symbol=get_case_symbol(name),
                case_value_part=value_part,
                avail=" {{\n {}\n}}".format(avail) if avail else ";",
                colon=":",
                offset=(max_len + 1 - len(name)),
            )
        )
    case_defs = "\n".join(case_lines)

    # Generate the list of enum case names, split into rows short enough to
    # fit on a line and concatenated into multiple lines.
    qualified_names = [
        "SPIRV_{acronym}_{symbol}".format(acronym=kind_acronym, symbol=symbol)
        for (symbol, _) in kind_cases
    ]
    rows = split_list_into_sublists(qualified_names)
    case_names = ",\n".join("{:6}".format("") + ", ".join(row) for row in rows)

    # Generate the enum attribute definition
    kind_category = "Bit" if is_bit_enum else "I32"
    enum_attr = """def SPIRV_{name}Attr :
SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
{cases}
]>;""".format(
        name=kind_name,
        snake_name=snake_casify(kind_name),
        category=kind_category,
        cases=case_names,
    )

    return kind_name, case_defs + "\n\n" + enum_attr
def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode

    Arguments:
      - instructions: a non-empty list of instructions' grammar, each providing
        'opname' and 'opcode'

    Returns:
      - A string containing the TableGen SPIRV_OpCode definition
    """
    opnames = [inst["opname"] for inst in instructions]
    max_len = max(len(opname) for opname in opnames)

    # One enum case definition per instruction, with the colons aligned.
    def_fmt_str = (
        "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
    )
    opcode_defs = []
    for inst in instructions:
        opname = inst["opname"]
        opcode_defs.append(
            def_fmt_str.format(
                name=opname,
                value=inst["opcode"],
                colon=":",
                offset=(max_len + 1 - len(opname)),
            )
        )
    opcode_str = "\n".join(opcode_defs)

    # The list of all case names, wrapped into rows.
    case_names = ["SPIRV_OC_{name}".format(name=opname) for opname in opnames]
    rows = split_list_into_sublists(case_names)
    opcode_list = ",\n".join("{:6}".format("") + ", ".join(row) for row in rows)

    enum_attr = (
        "def SPIRV_OpcodeAttr :\n"
        ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
        '"opcode", [\n'
        "{lst}\n"
        " ]>;".format(name="Opcode", lst=opcode_list)
    )

    return opcode_str + "\n\n" + enum_attr
def map_cap_to_opnames(instructions):
    """Maps capabilities to instructions enabled by those capabilities

    Instructions without a 'capabilities' entry are filed under the
    placeholder key "0_core_0".

    Arguments:
      - instructions: a list containing a subset of SPIR-V instructions'
        grammar

    Returns:
      - A map with keys representing capabilities and values of lists of
        instructions enabled by the corresponding key
    """
    cap_to_inst = {}

    for inst in instructions:
        for cap in inst.get("capabilities", ["0_core_0"]):
            cap_to_inst.setdefault(cap, []).append(inst["opname"])

    return cap_to_inst
def gen_instr_coverage_report(path, instructions):
    """Dumps to standard output a YAML report of current instruction coverage

    Only capabilities that still have at least one unsupported instruction
    appear in the report.

    Arguments:
      - path: the path to SPIRBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)

    # Opcodes already defined in the auto-generated section. A set keeps the
    # per-instruction membership tests below cheap.
    prefix = "def SPIRV_OC_"
    existing_opcodes = {
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    }

    existing_instructions = [
        inst for inst in instructions if inst["opname"] in existing_opcodes
    ]
    remaining_instructions = [
        inst for inst in instructions if inst["opname"] not in existing_opcodes
    ]

    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

    report = {}
    for cap, unsupported in rem_cap_to_instr.items():
        supported = ex_cap_to_instr.get(cap, [])
        # `unsupported` is never empty here, so the ratio is well defined.
        coverage = len(supported) / (len(supported) + len(unsupported))
        report[cap] = {
            "Supported Instructions": supported,
            "Unsupported Instructions": unsupported,
            "Coverage": "{}%".format(int(coverage * 100)),
        }

    print(yaml.dump(report))
def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRBase.td with new generated opcode cases.

    The opcode section between the two AUTOGEN_OPCODE_SECTION_MARKER
    occurrences is regenerated for the union of the already defined opcodes
    and the requested new ones; the file is rewritten in place.

    Arguments:
      - path: the path to SPIRBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add (left unmodified)
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Extend opcode list with existing list
    prefix = "def SPIRV_OC_"
    existing_opcodes = [
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    ]
    # Build a fresh set so the caller's list is not mutated and membership
    # tests are cheap.
    wanted_opnames = set(filter_list) | set(existing_opcodes)

    # Generate the opcode for all wanted instructions in SPIR-V
    filter_instrs = [
        inst for inst in instructions if inst["opname"] in wanted_opnames
    ]
    # Sort instruction based on opcode
    filter_instrs.sort(key=lambda inst: inst["opcode"])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode
    content = (
        content[0]
        + AUTOGEN_OPCODE_SECTION_MARKER
        + "\n\n"
        + opcode
        + "\n\n// End "
        + AUTOGEN_OPCODE_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRBase.td with new generated enum definitions.

    The enum section between the two AUTOGEN_ENUM_SECTION_MARKER occurrences
    is regenerated for the union of the already defined enums and the
    requested new ones; the file is rewritten in place.

    Arguments:
      - path: the path to SPIRBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add (left unmodified)
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Extend filter list with existing enum definitions
    prefix = "def SPIRV_"
    suffix = "Attr"
    existing_kinds = [
        k[len(prefix) : -len(suffix)]
        for k in re.findall(prefix + r"\w+" + suffix, content[1])
    ]
    # Build a fresh set so the caller's list is not mutated and membership
    # tests are cheap.
    wanted_kinds = set(filter_list) | set(existing_kinds)

    capability_mapping = get_capability_mapping(operand_kinds)

    # Generate definitions for all wanted enums
    defs = [
        gen_operand_kind_enum_attr(kind, capability_mapping)
        for kind in operand_kinds
        if kind["kind"] in wanted_kinds
    ]
    # Sort alphabetically according to enum name
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on
    # Put Capability's definition at the very beginning because capability
    # cases will be referenced later
    defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
        enum[1] for enum in defs if enum[0] != "Capability"
    ]

    # Substitute the old section
    content = (
        content[0]
        + AUTOGEN_ENUM_SECTION_MARKER
        + "\n\n"
        + "\n\n".join(defs)
        + "\n\n// End "
        + AUTOGEN_ENUM_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def snake_casify(name):
    """Turns the given name to follow snake_case convention.

    An underscore is inserted before every upper-case ASCII letter that is not
    the first character, then the whole name is lower-cased.
    """
    word_boundary = re.compile(r"(?<!^)(?=[A-Z])")
    underscored = word_boundary.sub("_", name)
    return underscored.lower()
def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - operand: a dict containing the operand's kind and, optionally, its
        quantifier ('?' or '*') and name

    Returns:
      - A string containing both the type and name for the argument, in the
        form `<type>:$<name>`. The name falls back to the lower-cased kind if
        the operand has no name.
    """
    kind = operand["kind"]
    quantifier = operand.get("quantifier", "")

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'

    unimplemented_kinds = (
        "LiteralString",
        "LiteralContextDependentNumber",
        "LiteralExtInstInteger",
        "LiteralSpecConstantOpInteger",
        "PairLiteralIntegerIdRef",
        "PairIdRefLiteralInteger",
        "PairIdRefIdRef",
    )

    if kind == "IdRef":
        # Any quantifier other than none/optional is treated as variadic.
        by_quantifier = {"": "SPIRV_Type", "?": "Optional<SPIRV_Type>"}
        arg_type = by_quantifier.get(quantifier, "Variadic<SPIRV_Type>")
    elif kind in ("IdMemorySemantics", "IdScope"):
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == "", (
            "unexpected to have optional/variadic memory " "semantics or scope <id>"
        )
        # Drop the leading "Id" to get the enum attribute's name.
        arg_type = "SPIRV_" + kind[2:] + "Attr"
    elif kind == "LiteralInteger":
        by_quantifier = {"": "I32Attr", "?": "OptionalAttr<I32Attr>"}
        arg_type = by_quantifier.get(quantifier, "OptionalAttr<I32ArrayAttr>")
    elif kind in unimplemented_kinds:
        assert False, '"{}" kind unimplemented'.format(kind)
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != "*", "unexpected to have variadic enum attribute"
        arg_type = "SPIRV_{}Attr".format(kind)
        if quantifier == "?":
            arg_type = "OptionalAttr<{}>".format(arg_type)

    name = operand.get("name", "")
    if name:
        name = snake_casify(name)
    else:
        name = kind.lower()

    return "{}:${}".format(arg_type, name)
def get_description(text, appendix):
    """Generates the description for the given SPIR-V instruction.

    Arguments:
      - text: Textual description of the operation as string.
      - appendix: Additional contents to attach in description as string,
        including IR examples, and others.

    Returns:
      - A string that corresponds to the description of the Tablegen op: the
        text, an end-of-autogen marker comment, then the appendix.
    """
    template = "{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n "
    fields = {"text": text, "appendix": appendix}
    return template.format(**fields)
def get_op_definition(
    instruction, opname, doc, existing_info, capability_mapping, settings
):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - opname: the op's name as used for the TableGen definition
      - doc: the instruction's SPIR-V HTML doc
      - existing_info: a dict containing potential manually specified sections
        for this instruction
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td
      - settings: object whose `gen_cl_ops` attribute selects the OpenCL
        extended instruction set flavor of the definition header

    Returns:
      - A string containing the TableGen op definition
    """
    # Assemble the template piece by piece: header, summary/description, then
    # optionally arguments/results, and finally the extras.
    if settings.gen_cl_ops:
        header = (
            "def SPIRV_{opname}Op : "
            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
        )
    else:
        header = (
            "def SPIRV_{vendor_name}{opname_src}Op : "
            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
        )
    template_parts = [
        header,
        "{{\n let summary = {summary};\n\n let description = ",
        "[{{\n{description}}}];{availability}\n",
    ]

    vendor_name = ""
    inst_category = existing_info.get("inst_category", "Op")
    if inst_category == "Op":
        template_parts.append(
            "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n"
        )
    elif inst_category.endswith("VendorOp"):
        vendor_name = inst_category.split("VendorOp")[0].upper()
        assert len(vendor_name) != 0, "Invalid instruction category"

    template_parts.append("{extras}" "}}\n")
    fmt_str = "".join(template_parts)

    # The name used in the definition drops the "Op" prefix and, for vendor
    # ops, the trailing vendor name.
    opname_src = instruction["opname"]
    if opname.startswith("Op"):
        opname_src = opname_src[2:]
    if len(vendor_name) > 0:
        assert opname_src.endswith(
            vendor_name
        ), "op name does not match the instruction category"
        opname_src = opname_src[: -len(vendor_name)]

    category_args = existing_info.get("category_args", "")

    # The first line of the doc is the summary; the rest is the description.
    summary, _, text = doc.partition("\n")

    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=" ", subsequent_indent=" "
    )

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))

    # Wrap text: every non-empty line becomes its own paragraph.
    paragraphs = [wrapper.fill(line) for line in text.split("\n") if line]
    text = "\n\n".join(paragraphs)

    operands = instruction.get("operands", [])

    # Op availability
    avail = get_availability_spec(instruction, capability_mapping, True, False)
    if avail:
        avail = "\n\n {0}".format(avail)

    # Set op's result
    results = ""
    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
        results = "\n SPIRV_Type:$result\n "
        operands = operands[1:]
    if "results" in existing_info:
        results = existing_info["results"]

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
        operands = operands[1:]

    # Set op's argument; manually specified arguments are kept verbatim.
    arguments = existing_info.get("arguments", None)
    if arguments is None:
        arguments = ",\n ".join(
            map_spec_operand_to_ods_argument(operand) for operand in operands
        )
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = "\n {}\n ".format(arguments)

    # Likewise for the description; otherwise generate a placeholder.
    description = existing_info.get("description", None)
    if description is None:
        assembly = (
            "\n ```\n"
            " [TODO]\n"
            " ```\n\n"
            " #### Example:\n\n"
            " ```mlir\n"
            " [TODO]\n"
            " ```"
        )
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        opname_src=opname_src,
        opcode=instruction["opcode"],
        category_args=category_args,
        inst_category=inst_category,
        vendor_name=vendor_name,
        traits=existing_info.get("traits", ""),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get("extras", ""),
    )
def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    The first occurrence of `start` and the first occurrence of `end` after it
    delimit the substring, which is returned verbatim (nothing is trimmed).

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.

    Raises:
      - AssertionError if `start` is found but no `end` follows it.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        rest = split[1].split(end, 1)
        assert len(rest) == 2, (
            'cannot find end "{end}" while extracting substring '
            "starting with {start}".format(start=start, end=end)
        )
        # The split already excluded `end`. Do not rstrip(end) here: rstrip
        # takes a set of characters and would eat legitimate trailing
        # characters of the substring that happen to occur in `end`.
        return rest[0], rest[1]
    return "", split[0]
def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    Unlike get_string_between, occurrences of `start` inside the substring
    open a nested level that has to be closed by a matching `end`.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.

    Raises:
      - AssertionError if `start` is found but never closed by an `end`.
    """
    _, found, rest = base.partition(start)
    if not found:
        return "", base

    # Walk through the remainder tracking how many starts are still open.
    depth = 1
    pos = 0
    while pos < len(rest):
        if rest.startswith(end, pos):
            depth -= 1
            if depth == 0:
                break
            pos += len(end)
        elif rest.startswith(start, pos):
            depth += 1
            pos += len(start)
        else:
            pos += 1

    assert pos < len(rest), (
        'cannot find end "{end}" while extracting substring '
        'starting with "{start}"'.format(start=start, end=end)
    )
    return rest[:pos], rest[pos + len(end) :]
def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - op_def: a string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections, keyed by
        'opname', 'inst_category', 'category_args', 'traits', 'description',
        'arguments', 'results', and 'extras'
    """
    # Get opname
    prefix = "def SPIRV_"
    suffix = "Op"
    opname = [
        o[len(prefix) : -len(suffix)]
        for o in re.findall(prefix + r"\w+" + suffix, op_def)
    ]
    assert len(opname) == 1, "more than one ops in the same section!"
    opname = opname[0]

    # Get instruction category; it is searched for after the first colon,
    # i.e. in the part naming the base class.
    prefix = "SPIRV_"
    inst_category = [
        o[len(prefix) :]
        for o in re.findall(prefix + r"\w+Op", op_def.split(":", 1)[1])
    ]
    assert len(inst_category) <= 1, "more than one ops in the same section!"
    inst_category = inst_category[0] if len(inst_category) == 1 else "Op"

    # Get category_args: whatever sits between the quoted op name and the
    # trait list in the template parameters.
    op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
    _, rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = rest.split("[", 1)[0]

    # Get traits
    traits, _ = get_string_between_nested(rest, "[", "]")

    # Get description
    description, rest = get_string_between(op_def, "let description = [{\n", "}];\n")

    # Get arguments
    args, rest = get_string_between(rest, " let arguments = (ins", ");\n")

    # Get results
    results, rest = get_string_between(rest, " let results = (outs", ");\n")

    # Whatever is left, minus surrounding whitespace and braces, is kept as is.
    extras = rest.strip(" }\n")
    if extras:
        extras = "\n {}\n".format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        "opname": "Op{}".format(opname),
        "inst_category": inst_category,
        "category_args": category_args,
        "traits": traits,
        "description": description,
        "arguments": args,
        "results": results,
        "extras": extras,
    }
def update_td_op_definitions(
    path, instructions, docs, filter_list, inst_category, capability_mapping, settings
):
    """Updates SPIRVOps.td with newly generated op definition.

    The file is split on AUTOGEN_OP_DEF_SEPARATOR into a header, one chunk per
    op, and a footer; the op chunks are regenerated and the file is rewritten
    in place.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include
      - inst_category: the instruction category to use for newly added ops
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
      - settings: object whose `gen_cl_ops` attribute selects OpenCL extended
        instruction set handling
    """
    with open(path, "r") as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    sections = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    header, footer = sections[0], sections[-1]
    existing_ops = sections[1:-1]

    # For each existing op, extract the manually-written sections out to retain
    # them when re-generating the ops. Also append the existing ops to filter
    # list.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}
    for op_def in existing_ops:
        info = extract_td_op_info(op_def)
        existing_name = info["opname"]
        name_op_map[existing_name] = op_def
        op_info_dict[existing_name] = info
        filter_list.append(existing_name)
    filter_list = sorted(set(filter_list))

    # Maps an op's name to the name used in the grammar.
    def fix_opname(src):
        if settings.gen_cl_ops:
            return src.replace("CL", "").lower()
        return src

    op_defs = [header]
    for opname in filter_list:
        # Find the grammar spec for this op
        fixed_opname = fix_opname(opname)
        instruction = next(
            (inst for inst in instructions if inst["opname"] == fixed_opname), None
        )
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            op_defs.append(name_op_map[opname])
            continue

        existing_info = op_info_dict.get(opname, {"inst_category": inst_category})
        op_defs.append(
            get_op_definition(
                instruction,
                opname,
                docs[fixed_opname],
                existing_info,
                capability_mapping,
                settings,
            )
        )
    op_defs.append(footer)

    # Substitute the old op definitions
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, "w") as f:
        f.write(content)
if __name__ == "__main__":
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Update SPIR-V dialect definitions using SPIR-V spec"
    )
    cli_parser.add_argument(
        "--base-td-path",
        dest="base_td_path",
        type=str,
        default=None,
        help="Path to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--op-td-path",
        dest="op_td_path",
        type=str,
        default=None,
        help="Path to SPIRVOps.td",
    )
    cli_parser.add_argument(
        "--new-enum",
        dest="new_enum",
        type=str,
        default=None,
        help="SPIR-V enum to be added to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-opcodes",
        dest="new_opcodes",
        type=str,
        default=None,
        nargs="*",
        help="update SPIR-V opcodes in SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-inst",
        dest="new_inst",
        type=str,
        default=None,
        nargs="*",
        help="SPIR-V instruction to be added to ops file",
    )
    cli_parser.add_argument(
        "--inst-category",
        dest="inst_category",
        type=str,
        default="Op",
        help="SPIR-V instruction category used for choosing "
        "the TableGen base class to define this op",
    )
    cli_parser.add_argument(
        "--gen-cl-ops",
        dest="gen_cl_ops",
        help="Generate OpenCL Extended Instruction Set op",
        action="store_true",
    )
    cli_parser.set_defaults(gen_cl_ops=False)
    cli_parser.add_argument(
        "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true"
    )
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    # Validate the option combination up front, before any network access.
    # This is done with parser errors rather than asserts so that it still
    # happens under `python -O` and also covers --gen-inst-coverage.
    needs_base_td = (
        args.new_enum is not None
        or args.new_opcodes is not None
        or args.gen_inst_coverage
    )
    if needs_base_td and args.base_td_path is None:
        cli_parser.error(
            "--base-td-path is required by --new-enum, --new-opcodes, "
            "and --gen-inst-coverage"
        )
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error("--op-td-path is required by --new-inst")

    if args.gen_cl_ops:
        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
    else:
        ext_html_url = None
        ext_json_url = None

    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

    # Define new enum attr
    if args.new_enum is not None:
        # An empty name just regenerates the already existing enums.
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
        capability_mapping = get_capability_mapping(operand_kinds)
        update_td_op_definitions(
            args.op_td_path,
            instructions,
            docs,
            args.new_inst,
            args.inst_category,
            capability_mapping,
            args,
        )
        print("Done. Note that this script just generates a template; ", end="")
        print("please read the spec and update traits, arguments, and ", end="")
        print("results accordingly.")

    # Report instruction coverage
    if args.gen_inst_coverage:
        gen_instr_coverage_report(args.base_td_path, instructions)