cpython/Tools/cases_generator/stack.py

import re
from analyzer import StackItem, StackEffect, Instruction, Uop, PseudoInstruction
from collections import defaultdict
from dataclasses import dataclass
from cwriter import CWriter
from typing import Iterator, Tuple

# Stack-item names treated as placeholders: when popped, ``Stack.pop`` emits
# no C assignment for them and returns an empty code string.
UNUSED = {"unused"}


def maybe_parenthesize(sym: str) -> str:
    """Add parentheses around a string if it contains an operator
       and is not already parenthesized.

    An exception is made for '*' which is common and harmless
    in the context where the symbolic size is used.

    A string only counts as "already parenthesized" when the '(' at its
    start is closed by the ')' at its end. Strings such as "(a) + (b)"
    begin and end with parentheses but are not enclosed by a single pair,
    so they are wrapped like any other compound expression (otherwise
    negating them, e.g. " - (a) + (b)", would change their meaning).
    """
    if _encloses_whole(sym):
        return sym
    if re.match(r"^[\s\w*]+$", sym):
        return sym
    return f"({sym})"


def _encloses_whole(sym: str) -> bool:
    """Return True if the '(' opening *sym* is matched by its final ')'."""
    if not (sym.startswith("(") and sym.endswith(")")):
        return False
    depth = 0
    for index, char in enumerate(sym):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                # The leading '(' closes here; it wraps the whole string
                # only if this is the last character.
                return index == len(sym) - 1
    # Unbalanced (more '(' than ')'): not a single enclosing group.
    return False


def var_size(var: StackItem) -> str:
    if var.condition:
        # Special case simplifications
        if var.condition == "0":
            return "0"
        elif var.condition == "1":
            return var.get_size()
        elif var.condition == "oparg & 1" and not var.size:
            return f"({var.condition})"
        else:
            return f"(({var.condition}) ? {var.get_size()} : 0)"
    elif var.size:
        return var.size
    else:
        return "1"


@dataclass
class Local:
    """A stack item together with what the generator knows about its value.

    ``in_memory`` records whether the value is stored in its stack slot and
    ``defined`` whether it is held in the C variable named after the item.
    ``cached`` is set by the constructors below and compared by ``__eq__``;
    nothing else in this class reads it.
    """

    item: StackItem
    cached: bool
    in_memory: bool
    defined: bool

    def __repr__(self) -> str:
        state = f"mem={self.in_memory}, defined={self.defined}, array={self.is_array()}"
        return f"Local('{self.item.name}', {state})"

    def compact_str(self) -> str:
        """Short form for comments: the quoted name followed by M/D/A flags."""
        flags = (("M", self.in_memory), ("D", self.defined), ("A", self.is_array()))
        tags = "".join(tag for tag, enabled in flags if enabled)
        return f"'{self.item.name}'{tags}"

    @staticmethod
    def unused(defn: StackItem) -> "Local":
        """A Local for a value that will never be assigned to a variable."""
        return Local(item=defn, cached=False, in_memory=defn.is_array(), defined=False)

    @staticmethod
    def undefined(defn: StackItem) -> "Local":
        """A Local for an output that has not been assigned yet."""
        is_arr = defn.is_array()
        return Local(item=defn, cached=not is_arr, in_memory=is_arr, defined=False)

    @staticmethod
    def redefinition(var: StackItem, prev: "Local") -> "Local":
        """Rebind *prev*'s storage state to item *var*, marking it defined."""
        assert var.is_array() == prev.is_array()
        return Local(item=var, cached=prev.cached, in_memory=prev.in_memory, defined=True)

    @staticmethod
    def from_memory(defn: StackItem) -> "Local":
        """A Local just loaded from its stack slot into its variable."""
        return Local(item=defn, cached=True, in_memory=True, defined=True)

    def copy(self) -> "Local":
        """Shallow copy: the ``item`` object is shared, the flags are not."""
        return Local(
            item=self.item,
            cached=self.cached,
            in_memory=self.in_memory,
            defined=self.defined,
        )

    @property
    def size(self) -> str:
        """The declared size expression of the underlying item."""
        return self.item.size

    @property
    def name(self) -> str:
        """The underlying item's name."""
        return self.item.name

    @property
    def condition(self) -> str | None:
        """The underlying item's condition, if any."""
        return self.item.condition

    def is_array(self) -> bool:
        """Delegates to the underlying item's ``is_array``."""
        return self.item.is_array()

    def __eq__(self, other: object) -> bool:
        # Equal when wrapping the very same item object with identical flags.
        if not isinstance(other, Local):
            return NotImplemented
        if self.item is not other.item:
            return False
        pairs = (
            (self.cached, other.cached),
            (self.in_memory, other.in_memory),
            (self.defined, other.defined),
        )
        return all(mine is theirs for mine, theirs in pairs)


@dataclass
class StackOffset:
    "The stack offset of the virtual base of the stack from the physical stack pointer"

    # Each entry is a C expression (as returned by ``var_size``) counting
    # stack slots. The offset's value is ``sum(pushed) - sum(popped)``; the
    # order of entries within either list carries no meaning.
    popped: list[str]
    pushed: list[str]

    @staticmethod
    def empty() -> "StackOffset":
        """Return a zero offset."""
        return StackOffset([], [])

    def copy(self) -> "StackOffset":
        """Return a copy that shares no lists with ``self``."""
        return StackOffset(self.popped[:], self.pushed[:])

    def pop(self, item: StackItem) -> None:
        """Decrease the offset by the size of *item*."""
        self.popped.append(var_size(item))

    def push(self, item: StackItem) -> None:
        """Increase the offset by the size of *item*."""
        self.pushed.append(var_size(item))

    def __sub__(self, other: "StackOffset") -> "StackOffset":
        # (pushed - popped) - (other.pushed - other.popped): each side of
        # ``other`` moves to the opposite list. ``+`` builds fresh lists.
        return StackOffset(self.popped + other.pushed, self.pushed + other.popped)

    def __neg__(self) -> "StackOffset":
        # Copy the lists: handing out ``self``'s own lists would let a later
        # push/pop on the result (or the in-place sort in ``simplify``)
        # modify ``self``'s lists as well.
        return StackOffset(self.pushed[:], self.popped[:])

    def simplify(self) -> None:
        "Remove matching values from both the popped and pushed list"
        if not self.popped:
            self.pushed.sort()
            return
        if not self.pushed:
            self.popped.sort()
            return
        # Sort the list so the lexically largest element is last.
        popped = sorted(self.popped)
        pushed = sorted(self.pushed)
        self.popped = []
        self.pushed = []
        # Merge-style walk from the largest elements down, cancelling equal
        # pairs and keeping whichever side is larger.
        while popped and pushed:
            pop = popped.pop()
            push = pushed.pop()
            if pop == push:
                pass
            elif pop > push:
                # if pop > push, there can be no element in pushed matching pop.
                self.popped.append(pop)
                pushed.append(push)
            else:
                self.pushed.append(push)
                popped.append(pop)
        self.popped.extend(popped)
        self.pushed.extend(pushed)
        self.pushed.sort()
        self.popped.sort()

    def to_c(self) -> str:
        """Render the offset as a C expression, e.g. ``-1 + oparg``.

        Integer terms are folded into one constant; symbolic terms are
        parenthesized where needed. Simplifies ``self`` in place first.
        Returns "0" for a zero offset.
        """
        self.simplify()
        int_offset = 0
        symbol_offset = ""
        for item in self.popped:
            try:
                int_offset -= int(item)
            except ValueError:
                symbol_offset += f" - {maybe_parenthesize(item)}"
        for item in self.pushed:
            try:
                int_offset += int(item)
            except ValueError:
                symbol_offset += f" + {maybe_parenthesize(item)}"
        if symbol_offset and not int_offset:
            res = symbol_offset
        else:
            res = f"{int_offset}{symbol_offset}"
        # Tidy a leading binary operator left over when there is no constant.
        if res.startswith(" + "):
            res = res[3:]
        if res.startswith(" - "):
            res = "-" + res[3:]
        return res

    def as_int(self) -> int | None:
        """Return the offset as an int, or None if any term is symbolic.

        Simplifies ``self`` in place first.
        """
        self.simplify()
        int_offset = 0
        for item in self.popped:
            try:
                int_offset -= int(item)
            except ValueError:
                return None
        for item in self.pushed:
            try:
                int_offset += int(item)
            except ValueError:
                return None
        return int_offset

    def clear(self) -> None:
        """Reset to a zero offset."""
        self.popped = []
        self.pushed = []

    def __bool__(self) -> bool:
        # True if any term survives simplification (simplifies in place).
        self.simplify()
        return bool(self.popped) or bool(self.pushed)

    def __eq__(self, other: object) -> bool:
        # Offsets are equal when they render to the same C expression.
        if not isinstance(other, StackOffset):
            return NotImplemented
        return self.to_c() == other.to_c()


class StackError(Exception):
    """Raised when stack modelling finds inconsistent or invalid stack usage."""

def array_or_scalar(var: StackItem | Local) -> str:
    """Describe *var*'s shape for error messages: "array" or "scalar"."""
    if var.is_array():
        return "array"
    return "scalar"

class Stack:
    """Symbolic model of the evaluation stack used while generating C code.

    ``variables`` holds the Locals currently tracked on the stack, bottom
    first. ``top_offset`` and ``base_offset`` are StackOffsets used to form
    ``stack_pointer[...]`` indices (see ``pop`` and ``flush``). ``defined``
    collects the names of variables this stack has assigned or had pushed.
    """
    def __init__(self) -> None:
        self.top_offset = StackOffset.empty()
        self.base_offset = StackOffset.empty()
        self.variables: list[Local] = []
        self.defined: set[str] = set()

    def pop(self, var: StackItem, extract_bits: bool = True) -> tuple[str, Local]:
        """Pop *var* off the modelled stack.

        Returns ``(code, local)``: ``code`` is C source to emit (possibly
        empty) that binds the popped value to ``var.name``, and ``local``
        describes the value. If a Local is tracked in ``variables`` it is
        taken from there; otherwise the value is loaded from memory at the
        lowered ``base_offset``.

        Raises StackError on an array/scalar or size mismatch with the
        tracked Local, or when an ``unused`` input would discard a value
        that is already defined under another name.
        """
        self.top_offset.pop(var)
        # Arrays are bound to the slot's address, scalars to its value.
        indirect = "&" if var.is_array() else ""
        if self.variables:
            popped = self.variables.pop()
            # XOR: exactly one of the two is an array.
            if var.is_array() ^ popped.is_array():
                raise StackError(
                    f"Array mismatch when popping '{popped.name}' from stack to assign to '{var.name}'. "
                    f"Expected {array_or_scalar(var)} got {array_or_scalar(popped)}"
                )
            if popped.size != var.size:
                raise StackError(
                    f"Size mismatch when popping '{popped.name}' from stack to assign to '{var.name}'. "
                    f"Expected {var_size(var)} got {var_size(popped.item)}"
                )
            if var.name in UNUSED:
                if popped.name not in UNUSED and popped.name in self.defined:
                    raise StackError(
                        f"Value is declared unused, but is already cached by prior operation as '{popped.name}'"
                    )
                return "", popped
            if not var.used:
                return "", popped
            self.defined.add(var.name)
            if popped.defined:
                # Value already held in a C variable: reuse it, copying only
                # when the names differ.
                if popped.name == var.name:
                    return "", popped
                else:
                    defn = f"{var.name} = {popped.name};\n"
            else:
                # Value not held in a variable: load it from its slot at the
                # (already lowered) top_offset.
                if var.is_array():
                    defn = f"{var.name} = &stack_pointer[{self.top_offset.to_c()}];\n"
                else:
                    defn = f"{var.name} = stack_pointer[{self.top_offset.to_c()}];\n"
                    popped.in_memory = True
            return defn, Local.redefinition(var, popped)

        # Nothing tracked for this slot: lower base_offset past it and read
        # the value directly from memory.
        self.base_offset.pop(var)
        if var.name in UNUSED or not var.used:
            return "", Local.unused(var)
        self.defined.add(var.name)
        # Typed scalars are cast; with extract_bits the slot's ``.bits`` is read.
        cast = f"({var.type})" if (not indirect and var.type) else ""
        bits = ".bits" if cast and extract_bits else ""
        assign = f"{var.name} = {cast}{indirect}stack_pointer[{self.base_offset.to_c()}]{bits};"
        # Condition "1": always load; "0": never present, so emit nothing;
        # anything else: guard the load with the condition.
        if var.condition:
            if var.condition == "1":
                assign = f"{assign}\n"
            elif var.condition == "0":
                return "", Local.unused(var)
            else:
                assign = f"if ({var.condition}) {{ {assign} }}\n"
        else:
            assign = f"{assign}\n"
        return assign, Local.from_memory(var)

    def push(self, var: Local) -> None:
        """Track *var* as the new top of stack; no code is emitted."""
        assert(var not in self.variables)
        self.variables.append(var)
        self.top_offset.push(var.item)
        if var.item.used:
            self.defined.add(var.name)

    @staticmethod
    def _do_emit(
        out: CWriter,
        var: StackItem,
        base_offset: StackOffset,
        cast_type: str = "uintptr_t",
        extract_bits: bool = True,
    ) -> None:
        """Emit a store of *var* into ``stack_pointer[base_offset]``.

        Nothing is emitted for condition "0"; a condition other than "1" is
        emitted as an ``if`` guard on the store.
        """
        cast = f"({cast_type})" if var.type else ""
        bits = ".bits" if cast and extract_bits else ""
        if var.condition == "0":
            return
        if var.condition and var.condition != "1":
            out.emit(f"if ({var.condition}) ")
        out.emit(f"stack_pointer[{base_offset.to_c()}]{bits} = {cast}{var.name};\n")

    def _adjust_stack_pointer(self, out: CWriter, number: str) -> None:
        """Emit ``stack_pointer += number`` plus a bounds assertion, unless *number* is "0"."""
        if number != "0":
            out.start_line()
            out.emit(f"stack_pointer += {number};\n")
            out.emit("assert(WITHIN_STACK_BOUNDS());\n")

    def flush(
        self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True
    ) -> None:
        """Write pending values to memory and move ``stack_pointer`` to the top.

        Every Local that is defined but not yet in memory is stored to its
        slot and marked ``in_memory``. ``stack_pointer`` is then advanced by
        ``top_offset``, ``base_offset`` is rebased accordingly and
        ``top_offset`` is reset. Note that ``variables`` itself is not cleared.
        """
        out.start_line()
        # Walk slot offsets upward from base_offset, one Local at a time.
        var_offset = self.base_offset.copy()
        for var in self.variables:
            if (
                var.defined and
                not var.in_memory
            ):
                Stack._do_emit(out, var.item, var_offset, cast_type, extract_bits)
                var.in_memory = True
            var_offset.push(var.item)
        number = self.top_offset.to_c()
        self._adjust_stack_pointer(out, number)
        self.base_offset -= self.top_offset
        self.top_offset.clear()
        out.start_line()

    def is_flushed(self) -> bool:
        """True when no Locals are tracked and both offsets are zero."""
        return not self.variables and not self.base_offset and not self.top_offset

    def peek_offset(self) -> str:
        """The current ``top_offset`` as a C expression."""
        return self.top_offset.to_c()

    def as_comment(self) -> str:
        """A one-line C comment describing the tracked Locals and offsets."""
        variables = ", ".join([v.compact_str() for v in self.variables])
        return (
            f"/* Variables: {variables}. base: {self.base_offset.to_c()}. top: {self.top_offset.to_c()} */"
        )

    def copy(self) -> "Stack":
        """Return an independent copy (offsets, Locals and ``defined`` are copied)."""
        other = Stack()
        other.top_offset = self.top_offset.copy()
        other.base_offset = self.base_offset.copy()
        other.variables = [var.copy() for var in self.variables]
        other.defined = set(self.defined)
        return other

    def __eq__(self, other: object) -> bool:
        # Compares offsets and tracked Locals; the ``defined`` set is ignored.
        if not isinstance(other, Stack):
            return NotImplemented
        return (
            self.top_offset == other.top_offset
            and self.base_offset == other.base_offset
            and self.variables == other.variables
        )

    def align(self, other: "Stack", out: CWriter) -> None:
        """Make ``top_offset`` match *other*'s by emitting a stack-pointer adjustment.

        Raises StackError if the two stacks track different numbers of Locals.
        """
        if len(self.variables) != len(other.variables):
            raise StackError("Cannot align stacks: differing variables")
        if self.top_offset == other.top_offset:
            return
        diff = self.top_offset - other.top_offset
        # NOTE(review): StackOffset.to_c catches its own ValueErrors, so this
        # except clause looks unreachable unless CWriter.emit raises one.
        try:
            self.top_offset -= diff
            self.base_offset -= diff
            self._adjust_stack_pointer(out, diff.to_c())
        except ValueError:
            raise StackError("Cannot align stacks: cannot adjust stack pointer")

    def merge(self, other: "Stack", out: CWriter) -> None:
        """Combine state with *other*, then align to it.

        Each tracked Local stays ``defined``/``in_memory`` only if it is so in
        both stacks. Raises StackError if the Locals differ in count or name.
        """
        if len(self.variables) != len(other.variables):
            raise StackError("Cannot merge stacks: differing variables")
        for self_var, other_var in zip(self.variables, other.variables):
            if self_var.name != other_var.name:
                raise StackError(f"Mismatched variables on stack: {self_var.name} and {other_var.name}")
            self_var.defined = self_var.defined and other_var.defined
            self_var.in_memory = self_var.in_memory and other_var.in_memory
        self.align(other, out)


def stacks(inst: Instruction | PseudoInstruction) -> Iterator[StackEffect]:
    """Yield the stack effects that make up *inst*.

    For an Instruction this is the ``stack`` of each Uop among its parts
    (parts of any other kind are skipped). A PseudoInstruction contributes
    its single ``stack``.
    """
    if not isinstance(inst, Instruction):
        assert isinstance(inst, PseudoInstruction)
        yield inst.stack
        return
    yield from (part.stack for part in inst.parts if isinstance(part, Uop))


def apply_stack_effect(stack: Stack, effect: StackEffect) -> None:
    """Simulate *effect* on *stack*: pop its inputs, then push its outputs.

    An output whose name matches a popped, non-placeholder input re-pushes
    the Local obtained for that input. Every other output is pushed as a
    fresh ``Local.unused`` entry. The C code strings returned by
    ``Stack.pop`` are discarded.
    """
    # Named ``popped_by_name`` rather than ``locals`` to avoid shadowing the
    # builtin of that name.
    popped_by_name: dict[str, Local] = {}
    for var in reversed(effect.inputs):
        _, local = stack.pop(var)
        if var.name not in UNUSED:
            popped_by_name[local.name] = local
    for var in effect.outputs:
        local = popped_by_name.get(var.name)
        if local is None:
            local = Local.unused(var)
        stack.push(local)


def get_stack_effect(inst: Instruction | PseudoInstruction) -> Stack:
    """Return the Stack left after applying every stack effect of *inst* in order."""
    final = Stack()
    effects = stacks(inst)
    for effect in effects:
        apply_stack_effect(final, effect)
    return final


def get_stack_effects(inst: Instruction | PseudoInstruction) -> list[Stack]:
    """Returns a list of stack effects after each uop

    Entry ``i`` is an independent copy of the Stack after the first
    ``i + 1`` effects of *inst* have been applied.
    """
    current = Stack()
    snapshots: list[Stack] = []
    for effect in stacks(inst):
        apply_stack_effect(current, effect)
        snapshots.append(current.copy())
    return snapshots


@dataclass
class Storage:
    """Where a micro-op's values live while its body is being generated.

    Fields:
      stack   -- the modelled Stack the uop operates on.
      inputs  -- input Locals not yet cleared. ``for_uop`` pushes them back
                 onto ``stack``; ``clear_inputs``/``clear_dead_inputs``
                 pop them off again.
      outputs -- output Locals not yet pushed onto ``stack``.
      peeks   -- inputs declared as peeks; ``for_uop`` pushes them back
                 onto ``stack`` and they are not part of ``inputs``.
      spilled -- how many ``save`` calls are not yet matched by ``reload``.
    """

    stack: Stack
    inputs: list[Local]
    outputs: list[Local]
    peeks: list[Local]
    spilled: int = 0

    @staticmethod
    def needs_defining(var: Local) -> bool:
        """True if *var* is an undefined scalar that is not the ``unused`` placeholder."""
        return (
            not var.defined and
            not var.is_array() and
            var.name != "unused"
        )

    @staticmethod
    def is_live(var: Local) -> bool:
        """True if *var* is defined and not the ``unused`` placeholder."""
        return (
            var.defined and
            var.name != "unused"
        )

    def first_input_not_cleared(self) -> str:
        """Name of the first (lowest) input that is still defined, or ""."""
        for var in self.inputs:
            if var.defined:
                return var.name
        return ""

    def clear_inputs(self, reason: str) -> None:
        """Pop every input off the stack, top first.

        Raises StackError if a live scalar input is found; *reason* is
        appended to the message.
        """
        while self.inputs:
            tos = self.inputs.pop()
            if self.is_live(tos) and not tos.is_array():
                raise StackError(
                    f"Input '{tos.name}' is still live {reason}"
                )
            self.stack.pop(tos.item)

    def clear_dead_inputs(self) -> None:
        """Pop inputs from the top down until a live one is reached.

        Raises StackError if an input below the first live one still needs
        defining.
        """
        live = ""
        while self.inputs:
            tos = self.inputs[-1]
            if self.is_live(tos):
                live = tos.name
                break
            self.inputs.pop()
            self.stack.pop(tos.item)
        for var in self.inputs:
            if self.needs_defining(var):
                raise StackError(
                    f"Input '{var.name}' is not live, but '{live}' is"
                )

    def _push_defined_outputs(self) -> None:
        """Move leading defined outputs from ``outputs`` onto the stack.

        Does nothing unless some output is defined but not yet in memory.
        Otherwise every input is cleared first (raising StackError if one is
        still live), then outputs are pushed from the front of ``outputs``
        until one that still needs defining is reached.
        """
        defined_output = ""
        for output in self.outputs:
            if output.defined and not output.in_memory:
                defined_output = output.name
        if not defined_output:
            return
        self.clear_inputs(f"when output '{defined_output}' is defined")
        # NOTE(review): this ordering check never fires -- the message below
        # is built and then discarded instead of being raised. Simply adding
        # ``raise StackError(...)`` would presumably also trip on array or
        # "unused" outputs (which Local.undefined leaves undefined) placed
        # before a defined output, so the check needs refining before it is
        # enabled. Behaviour is deliberately left unchanged here.
        undefined = ""
        for out in self.outputs:
            if out.defined:
                if undefined:
                    f"Locals not defined in stack order. "
                    f"Expected '{undefined}' to be defined before '{out.name}'"
            else:
                undefined = out.name
        while self.outputs and not self.needs_defining(self.outputs[0]):
            out = self.outputs.pop(0)
            self.stack.push(out)

    def locals_cached(self) -> bool:
        """True if any output still held in ``outputs`` is defined."""
        return any(out.defined for out in self.outputs)

    def flush(self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True) -> None:
        """Drop dead inputs, push defined outputs, then flush the stack to memory."""
        self.clear_dead_inputs()
        self._push_defined_outputs()
        self.stack.flush(out, cast_type, extract_bits)

    def save(self, out: CWriter) -> None:
        """Spill the stack: on the outermost call, flush and store ``stack_pointer`` in the frame.

        Nested calls only increase ``spilled``.
        """
        assert self.spilled >= 0
        if self.spilled == 0:
            self.flush(out)
            out.start_line()
            out.emit("_PyFrame_SetStackPointer(frame, stack_pointer);\n")
        self.spilled += 1

    def reload(self, out: CWriter) -> None:
        """Undo one ``save``; the outermost one reloads ``stack_pointer`` from the frame.

        Raises StackError if there is no matching ``save``.
        """
        if self.spilled == 0:
            raise StackError("Cannot reload stack as it hasn't been saved")
        assert self.spilled > 0
        self.spilled -= 1
        if self.spilled == 0:
            out.start_line()
            out.emit("stack_pointer = _PyFrame_GetStackPointer(frame);\n")

    @staticmethod
    def for_uop(stack: Stack, uop: Uop, extract_bits: bool = True) -> tuple[list[str], "Storage"]:
        """Set up a Storage for *uop* operating on *stack*.

        Pops the uop's inputs (collecting the C code needed to load them),
        pushes the peeks and then the remaining inputs back, and binds each
        used, non-peek array output to its stack slot. Returns the collected
        code lines and a Storage whose outputs are all undefined.
        """
        code_list: list[str] = []
        inputs: list[Local] = []
        peeks: list[Local] = []
        for item in reversed(uop.stack.inputs):
            code, local = stack.pop(item, extract_bits)
            code_list.append(code)
            if item.peek:
                peeks.append(local)
            else:
                inputs.append(local)
        inputs.reverse()
        peeks.reverse()
        for peek in peeks:
            stack.push(peek)
        # Walk a scratch copy of the top offset across the outputs so array
        # outputs can be bound to their slot addresses up front.
        top_offset = stack.top_offset.copy()
        for output in uop.stack.outputs:
            if output.is_array() and output.used and not output.peek:
                c_offset = top_offset.to_c()
                top_offset.push(output)
                code_list.append(f"{output.name} = &stack_pointer[{c_offset}];\n")
            else:
                top_offset.push(output)
        for var in inputs:
            stack.push(var)
        outputs = [ Local.undefined(var) for var in uop.stack.outputs if not var.peek ]
        return code_list, Storage(stack, inputs, outputs, peeks)

    @staticmethod
    def copy_list(arg: list[Local]) -> list[Local]:
        """Return a list of copies of the Locals in *arg*."""
        return [ l.copy() for l in arg ]

    def copy(self) -> "Storage":
        """Return an independent copy of this Storage.

        Inputs in the copy are the same Local objects as the copied stack's
        variables (matched by name), so changes to an input are visible
        through the stack. The ``save``/``reload`` depth is carried over too;
        dropping it would make ``reload``/``push_outputs`` on the copy
        misjudge whether the stack is spilled.
        """
        new_stack = self.stack.copy()
        variables = { var.name: var for var in new_stack.variables }
        inputs = [ variables[var.name] for var in self.inputs]
        assert [v.name for v in inputs] == [v.name for v in self.inputs], (inputs, self.inputs)
        return Storage(
            new_stack, inputs,
            self.copy_list(self.outputs), self.copy_list(self.peeks),
            self.spilled,
        )

    def sanity_check(self) -> None:
        """Raise StackError if names repeat within inputs, outputs or stack variables."""
        for group in (self.inputs, self.outputs, self.stack.variables):
            names: set[str] = set()
            for var in group:
                if var.name in names:
                    raise StackError(f"Duplicate name {var.name}")
                names.add(var.name)

    def is_flushed(self) -> bool:
        """True if no defined output awaits storing and the stack is flushed."""
        for var in self.outputs:
            if var.defined and not var.in_memory:
                return False
        return self.stack.is_flushed()

    def merge(self, other: "Storage", out: CWriter) -> None:
        """Reconcile this Storage with *other*, then merge their stacks.

        Dead inputs are cleared and defined outputs pushed on both sides
        when their lengths differ. Raises StackError if inputs or outputs
        still disagree afterwards.
        """
        self.sanity_check()
        if len(self.inputs) != len(other.inputs):
            self.clear_dead_inputs()
            other.clear_dead_inputs()
        if len(self.inputs) != len(other.inputs):
            diff = self.inputs[-1] if len(self.inputs) > len(other.inputs) else other.inputs[-1]
            raise StackError(f"Unmergeable inputs. Differing state of '{diff.name}'")
        for var, other_var in zip(self.inputs, other.inputs):
            if var.defined != other_var.defined:
                raise StackError(f"'{var.name}' is cleared on some paths, but not all")
        if len(self.outputs) != len(other.outputs):
            self._push_defined_outputs()
            other._push_defined_outputs()
        if len(self.outputs) != len(other.outputs):
            var = self.outputs[0] if len(self.outputs) > len(other.outputs) else other.outputs[0]
            raise StackError(f"'{var.name}' is set on some paths, but not all")
        self.stack.merge(other.stack, out)
        self.sanity_check()

    def push_outputs(self) -> None:
        """Finish the uop: clear all inputs and push every output onto the stack.

        Raises StackError if the stack is still spilled, an input is still
        live, or an output that needs defining was never defined.
        """
        if self.spilled:
            raise StackError("Unbalanced stack spills")
        self.clear_inputs("at the end of the micro-op")
        if self.inputs:
            raise StackError(f"Input variable '{self.inputs[-1].name}' is still live")
        self._push_defined_outputs()
        if self.outputs:
            for out in self.outputs:
                if self.needs_defining(out):
                    # Report the offending output, not just the first one.
                    raise StackError(f"Output variable '{out.name}' is not defined")
                self.stack.push(out)
            self.outputs = []

    def as_comment(self) -> str:
        """A multi-line C comment describing the stack, inputs, outputs and peeks."""
        stack_comment = self.stack.as_comment()
        next_line = "\n               "
        inputs = ", ".join([var.compact_str() for var in self.inputs])
        outputs = ", ".join([var.compact_str() for var in self.outputs])
        peeks = ", ".join([var.name for var in self.peeks])
        return f"{stack_comment[:-2]}{next_line}inputs: {inputs}{next_line}outputs: {outputs}{next_line}peeks: {peeks} */"