llvm/llvm/test/CodeGen/WebAssembly/multivalue-stackify.py

#!/usr/bin/env python3

"""A test case generator for register stackification.

This script exhaustively generates small linear SSA programs, then filters them
based on heuristics designed to keep interesting multivalue test cases and
prints them as LLVM IR functions in a FileCheck test file.

The output of this script is meant to be used in conjunction with
update_llc_test_checks.py.

  ```
  ./multivalue-stackify.py > multivalue-stackify.ll
  ../../../utils/update_llc_test_checks.py multivalue-stackify.ll
  ```

Programs are represented internally as lists of operations, where each operation
is a pair of tuples, the first of which specifies the operation's uses and the
second of which specifies its defs.

TODO: Before embarking on a rewrite of the register stackifier, an abstract
interpreter should be written to automatically check that the test assertions
generated by update_llc_test_checks.py have the same semantics as the functions
generated by this script. Once that is done, exhaustive testing can be done by
making `is_interesting` return True.
"""


from itertools import product
from collections import deque


# Upper bound on the number of operations in a generated program.
MAX_PROGRAM_OPS = 4
# Upper bound on the total number of values defined by all operations of a
# program; also the size of the per-value bookkeeping lists used by the filters.
MAX_PROGRAM_DEFS = 3
# Upper bound on the number of values a single operation may use.
MAX_OP_USES = 2


def get_num_defs(program):
    """Return the total number of values defined by all ops in `program`.

    Each op is a `(uses, defs)` pair; only the length of `defs` is counted.
    """
    return sum(len(op_defs) for _, op_defs in program)


def possible_ops(program):
    """Yield every `(uses, defs)` op that can be appended to `program`.

    `uses` is a tuple of indices of values already defined by `program`
    (repetition allowed, at most MAX_OP_USES entries). `defs` is a tuple of the
    next consecutive value indices, sized so the program's total number of
    defs does not exceed MAX_PROGRAM_DEFS. Ops with neither uses nor defs are
    not produced.
    """
    defined = get_num_defs(program)
    remaining_defs = MAX_PROGRAM_DEFS - defined
    for def_count in range(remaining_defs + 1):
        # New values are numbered consecutively after the existing ones.
        new_defs = tuple(range(defined, defined + def_count))
        for use_count in range(MAX_OP_USES + 1):
            if not (def_count or use_count):
                continue
            for uses in product(range(defined), repeat=use_count):
                yield uses, new_defs


def generate_programs():
    """Yield `(program_id, program)` for every program of up to MAX_PROGRAM_OPS ops.

    Programs are enumerated breadth-first starting from the empty program:
    each dequeued program is extended by every op from `possible_ops`, and each
    extension is yielded with a fresh id. Ids start at 1 and increase by one
    per yielded program.
    """
    queue = deque([[]])
    program_id = 0
    # Loop on the queue itself: if no program can be extended the queue drains
    # and generation ends cleanly instead of popleft() raising IndexError.
    while queue:
        program = queue.popleft()
        # Only reachable for a non-positive MAX_PROGRAM_OPS, since full-length
        # programs are never enqueued below.
        if len(program) >= MAX_PROGRAM_OPS:
            break
        for op in possible_ops(program):
            program_id += 1
            new_program = program + [op]
            # Full-length programs are never extended, so there is no need to
            # keep them around in the queue.
            if len(new_program) < MAX_PROGRAM_OPS:
                queue.append(new_program)
            yield program_id, new_program


def get_num_terminal_ops(program):
    """Return how many ops in `program` define no values."""
    return sum(1 for _, op_defs in program if len(op_defs) == 0)


def get_max_uses(program):
    """Return the highest number of times any single value is used in `program`.

    Uses are tallied per value index in a list of MAX_PROGRAM_DEFS counters.
    """
    tally = [0 for _ in range(MAX_PROGRAM_DEFS)]
    all_uses = (value for op_uses, _ in program for value in op_uses)
    for value in all_uses:
        tally[value] += 1
    return max(tally)


def has_unused_op(program):
    """Return True if some value-defining op has none of its defs used later.

    The program is walked from the last op to the first, so when an op is
    examined, `seen_use` already records every value used by the ops after it.
    Ops that define nothing are never reported.
    """
    seen_use = [False] * MAX_PROGRAM_DEFS
    for op_uses, op_defs in reversed(program):
        if op_defs and not any(seen_use[value] for value in op_defs):
            return True
        for value in op_uses:
            seen_use[value] = True
    return False


def has_multivalue_use(program):
    """Return True if any op uses a value defined by an op with two or more defs."""
    from_multi = [False] * MAX_PROGRAM_DEFS
    for op_uses, op_defs in program:
        for value in op_uses:
            if from_multi[value]:
                return True
        # Only results of ops with at least two defs count as multivalue.
        if len(op_defs) < 2:
            continue
        for value in op_defs:
            from_multi[value] = True
    return False


def has_mvp_use(program):
    """Return True if `program` contains a use pattern built only on "MVP" values.

    A value is tracked as "MVP" here when the op defining it has at most one
    def. The function reports True when either:
      * an op has at least one use and every one of its uses is an MVP value, or
      * an op with at most one def uses any MVP value.
    """
    mvp_value = [False] * MAX_PROGRAM_DEFS
    for op_uses, op_defs in program:
        if op_uses and all(mvp_value[value] for value in op_uses):
            return True
        # Ops with two or more defs neither trigger the second check nor
        # produce MVP values.
        if len(op_defs) > 1:
            continue
        if any(mvp_value[value] for value in op_uses):
            return True
        for value in op_defs:
            mvp_value[value] = True
    return False


def is_interesting(program):
    """Return True if `program` should be emitted as a test function.

    A single-op program is kept only if its op defines more than one value.
    Longer programs are kept only if they pass every rejection filter below
    and contain at least one use of a multivalue result.
    """
    op_count = len(program)

    # A lone op is only worth testing when it produces multiple values.
    if op_count == 1:
        only_op = program[0]
        return len(only_op[1]) > 1

    # Reject programs whose last two ops have identical uses.
    if op_count >= 2:
        last_op, previous_op = program[-1], program[-2]
        if last_op[0] == previous_op[0]:
            return False

    rejected = (
        # Too many ops that don't produce values
        get_num_terminal_ops(program) > 2
        # The third use of a value is no more interesting than the second
        or get_max_uses(program) >= 3
        # Nontrivial programs should not contain unused instructions
        or has_unused_op(program)
        # Boring MVP uses of MVP defs
        or has_mvp_use(program)
    )
    if rejected:
        return False

    # What remains is interesting exactly when it has a multivalue use.
    return has_multivalue_use(program)


def make_llvm_type(num_defs):
    """Return the LLVM return type text for an op defining `num_defs` values.

    Zero defs gives "void"; otherwise a brace-wrapped, comma-separated list of
    `num_defs` "i32" entries (a single def still gets the braces).
    """
    if num_defs == 0:
        return "void"
    members = ", ".join("i32" for _ in range(num_defs))
    return "{" + members + "}"


def make_llvm_op_name(num_uses, num_defs):
    """Return the name of the function standing in for an op of this shape."""
    return "op_{}_to_{}".format(num_uses, num_defs)


def make_llvm_args(first_use, num_uses):
    """Return a comma-separated list of `num_uses` i32 arguments.

    The arguments are the consecutive temporaries `%t<first_use>`,
    `%t<first_use + 1>`, and so on; the result is empty when `num_uses` is 0.
    """
    operands = (f"i32 %t{first_use + offset}" for offset in range(num_uses))
    return ", ".join(operands)


def print_llvm_program(program, name):
    """Print `program` to stdout as an LLVM IR function `@name` returning void.

    Every op becomes a call to the function named by `make_llvm_op_name`. Each
    use is first materialised with an `extractvalue` from the result of the
    call that defined the value, and the extracted temporaries are passed as
    the call's arguments.
    """
    next_tmp = 0
    # One entry per program value, indexed by value number:
    # (type of the defining call's result, temporary holding it, field index).
    value_sources = []
    print(f"define void @{name}() {{")
    for op_uses, op_defs in program:
        first_operand = next_tmp
        # Pull each used value out of the result that contains it.
        for value in op_uses:
            agg_type, agg_tmp, field = value_sources[value]
            print(f"  %t{next_tmp} = extractvalue {agg_type} %t{agg_tmp}, {field}")
            next_tmp += 1
        # Emit the call; only ops that define values get a result temporary.
        if op_defs:
            result_tmp = next_tmp
            next_tmp += 1
            lhs = f"%t{result_tmp} = "
        else:
            result_tmp = None
            lhs = ""
        call_type = make_llvm_type(len(op_defs))
        callee = make_llvm_op_name(len(op_uses), len(op_defs))
        operands = make_llvm_args(first_operand, len(op_uses))
        print(f"  {lhs}call {call_type} @{callee}({operands})")
        # Record where each newly defined value can be extracted from.
        value_sources.extend(
            (call_type, result_tmp, idx) for idx in range(len(op_defs))
        )
    print("  ret void")
    print("}")


def print_header():
    """Print the test file preamble and the declarations of all op functions.

    The preamble is the generator note, the RUN line, a description and the
    target triple, each followed by a blank line. A `declare` line is then
    printed for every combination of 0..MAX_OP_USES uses and
    0..MAX_PROGRAM_DEFS defs except the 0-use/0-def one, followed by a final
    blank line.
    """
    # Each entry holds the arguments of one print() call.
    header_lines = (
        ("; NOTE: Test functions have been generated by multivalue-stackify.py.",),
        ("; RUN: llc < %s -verify-machineinstrs -mattr=+multivalue", "| FileCheck %s"),
        ("; Test that the multivalue stackification works",),
        ('target triple = "wasm32-unknown-unknown"',),
    )
    for parts in header_lines:
        print(*parts)
        print()

    shapes = product(range(MAX_OP_USES + 1), range(MAX_PROGRAM_DEFS + 1))
    for num_uses, num_defs in shapes:
        if not (num_uses or num_defs):
            continue
        ret_type = make_llvm_type(num_defs)
        op_name = make_llvm_op_name(num_uses, num_defs)
        args = make_llvm_args(0, num_uses)
        print(f"declare {ret_type} @{op_name}({args})")
    print()


if __name__ == "__main__":
    # Emit the header, then one function per interesting program, each
    # followed by a blank line. Functions are named after the program's id.
    print_header()
    for program_id, candidate in generate_programs():
        if not is_interesting(candidate):
            continue
        print_llvm_program(candidate, f"f{program_id}")
        print()