cpython/Tools/build/umarshal.py

# Implementation of marshal.loads() in pure Python

import ast

from typing import Any, Tuple


class Type:
    """Type codes of the marshal format (adapted from marshal.c).

    On the wire a type byte may additionally have FLAG_REF (0x80) set; the
    constants below are the bare codes with that bit cleared.
    """

    # Singletons and sentinels
    NULL = ord('0')
    NONE = ord('N')
    FALSE = ord('F')
    TRUE = ord('T')
    STOPITER = ord('S')
    ELLIPSIS = ord('.')

    # Numbers
    INT = ord('i')
    INT64 = ord('I')
    FLOAT = ord('f')
    BINARY_FLOAT = ord('g')
    COMPLEX = ord('x')
    BINARY_COMPLEX = ord('y')
    LONG = ord('l')

    # Byte strings and text
    STRING = ord('s')
    INTERNED = ord('t')
    UNICODE = ord('u')
    ASCII = ord('a')
    ASCII_INTERNED = ord('A')
    SHORT_ASCII = ord('z')
    SHORT_ASCII_INTERNED = ord('Z')

    # Back-reference to an object decoded earlier
    REF = ord('r')

    # Containers
    TUPLE = ord('(')
    SMALL_TUPLE = ord(')')
    LIST = ord('[')
    DICT = ord('{')
    SET = ord('<')
    FROZENSET = ord('>')

    # Code objects and the catch-all
    CODE = ord('c')
    UNKNOWN = ord('?')


# High bit of a type byte: when set, the decoded object is appended to the
# reader's reference table so that later REF entries can point back to it.
FLAG_REF = 1 << 7

# Unique sentinel returned for Type.NULL; a NULL key terminates a DICT.
# It must be distinct from None, which marshal encodes separately.
NULL = object()

# Kind bits stored per entry of a code object's co_localspluskinds, used to
# classify each name in co_localsplusnames.
CO_FAST_LOCAL = 1 << 5
CO_FAST_CELL = 1 << 6
CO_FAST_FREE = 1 << 7


class Code:
    """Attribute-bag stand-in for a code object decoded from marshal data.

    Every keyword argument passed to the constructor becomes an instance
    attribute.  ``co_varnames``, ``co_cellvars``, ``co_freevars`` and
    ``co_nlocals`` are derived on demand from ``co_localsplusnames`` and
    ``co_localspluskinds``.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Names of all locals, cells and free variables, in slot order.
    co_localsplusnames: Tuple[str, ...]
    # One kind byte per entry of co_localsplusnames (a bitmask of CO_FAST_*
    # flags).  CPython marshals this as a bytes object, so iterating it
    # yields ints.
    co_localspluskinds: bytes

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return, in order, the names whose kind shares a bit with *select_kind*."""
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names whose kind includes CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names whose kind includes CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names whose kind includes CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of names in co_varnames."""
        return len(self.co_varnames)


class Reader:
    """Pure-Python decoder for the marshal format.

    A fairly literal translation of the marshal reader.  ``r_object()``
    decodes one (possibly nested) object starting at ``pos`` and advances
    ``pos`` past it.  Objects whose type byte carries ``FLAG_REF`` are
    recorded in ``refs`` so that later ``Type.REF`` entries can point back
    to them by index.  Most malformed-input checks are assertions.
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0  # offset of the next unread byte
        self.refs: list[Any] = []  # back-reference table (see r_ref*)
        self.level: int = 0  # nesting depth; only used by the debug print

    # All multi-byte integers below are little-endian.

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* bytes of input."""
        assert 0 <= n <= self.end - self.pos
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Consume one byte and return it as an int in range(256)."""
        buf = self.r_string(1)
        return buf[0]

    def r_short(self) -> int:
        """Consume a signed 16-bit little-endian integer."""
        buf = self.r_string(2)
        x = buf[0]
        x |= buf[1] << 8
        x |= -(x & (1<<15))  # Sign-extend
        return x

    def r_long(self) -> int:
        """Consume a signed 32-bit little-endian integer."""
        buf = self.r_string(4)
        x = buf[0]
        x |= buf[1] << 8
        x |= buf[2] << 16
        x |= buf[3] << 24
        x |= -(x & (1<<31))  # Sign-extend
        return x

    def r_long64(self) -> int:
        """Consume a signed 64-bit little-endian integer."""
        buf = self.r_string(8)
        x = buf[0]
        x |= buf[1] << 8
        x |= buf[2] << 16
        x |= buf[3] << 24
        x |= buf[4] << 32
        x |= buf[5] << 40
        x |= buf[6] << 48
        x |= buf[7] << 56
        x |= -(x & (1<<63))  # Sign-extend
        return x

    def r_PyLong(self) -> int:
        """Consume a Type.LONG payload (an arbitrary-precision int).

        The payload is a signed 32-bit digit count, whose sign is the sign
        of the result, followed by that many 15-bit digits, least
        significant first, each stored in a 16-bit slot.
        """
        n = self.r_long()
        x = 0
        for i in range(abs(n)):
            digit = self.r_short()
            # A valid digit fits in 15 bits, so it is never negative after
            # r_short's sign extension; a negative one would corrupt x.
            assert digit >= 0, f"bad marshal long digit {digit}"
            x |= digit << (i * 15)
        if n < 0:
            x = -x
        return x

    def r_float_bin(self) -> float:
        """Consume an IEEE 754 double stored as 8 little-endian bytes."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # marshal always writes doubles little-endian; "<d" keeps the decode
        # independent of the host's native byte order.
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Consume a float stored as a 1-byte length plus ASCII text."""
        n = self.r_byte()
        buf = self.r_string(n)
        # float(), unlike ast.literal_eval(), always returns a float for
        # integral text such as "1", keeps the sign of "-0", and accepts
        # "inf"/"-inf"/"nan".
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """If *flag* is set, append a placeholder ref slot and return its index."""
        if flag:
            idx = len(self.refs)
            self.refs.append(None)
            return idx
        else:
            return 0

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """If *flag* is set, fill the slot reserved at *idx*; return *obj*."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to the ref table (*flag* must carry FLAG_REF)."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Decode and return the next object, restoring ``level`` afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        code = self.r_byte()
        flag = code & FLAG_REF
        typ = code & ~FLAG_REF
        # print("  "*self.level + f"{code} {flag} {typ} {chr(typ)!r}")
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Record obj in the reference table when FLAG_REF was set.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        retval: Any
        if typ == Type.NULL:
            return NULL
        elif typ == Type.NONE:
            return None
        elif typ == Type.STOPITER:
            return StopIteration
        elif typ == Type.ELLIPSIS:
            return Ellipsis
        elif typ == Type.FALSE:
            return False
        elif typ == Type.TRUE:
            return True
        elif typ == Type.INT:
            return R_REF(self.r_long())
        elif typ == Type.INT64:
            return R_REF(self.r_long64())
        elif typ == Type.LONG:
            return R_REF(self.r_PyLong())
        elif typ == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif typ == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif typ == Type.COMPLEX:
            real = self.r_float_str()
            imag = self.r_float_str()
            return R_REF(complex(real, imag))
        elif typ == Type.BINARY_COMPLEX:
            real = self.r_float_bin()
            imag = self.r_float_bin()
            return R_REF(complex(real, imag))
        elif typ == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif typ == Type.ASCII_INTERNED or typ == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.SHORT_ASCII_INTERNED or typ == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.INTERNED or typ == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif typ == Type.SMALL_TUPLE or typ == Type.TUPLE:
            n = self.r_byte() if typ == Type.SMALL_TUPLE else self.r_long()
            # Reserve the tuple's ref slot before reading its items, so the
            # tuple gets a lower ref index than anything nested inside it.
            idx = self.r_ref_reserve(flag)
            retval = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif typ == Type.DICT:
            # Key/value pairs follow until a NULL key ends the dict.
            retval = R_REF({})
            while True:
                key = self.r_object()
                if key is NULL:
                    break
                val = self.r_object()
                retval[key] = val
            return retval
        elif typ == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                retval.add(self.r_object())
            return retval
        elif typ == Type.FROZENSET:
            n = self.r_long()
            s: set[Any] = set()
            idx = self.r_ref_reserve(flag)
            for _ in range(n):
                s.add(self.r_object())
            retval = frozenset(s)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.CODE:
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif typ == Type.REF:
            n = self.r_long()
            # Reject out-of-range indices; a negative n would otherwise
            # silently pick an entry from the end of the table.
            assert 0 <= n < len(self.refs), f"bad marshal ref {n}"
            retval = self.refs[n]
            # None means the slot is reserved but its object is unfinished.
            assert retval is not None
            return retval
        else:
            raise AssertionError(f"Unknown type {typ} {chr(typ)!r}")


def loads(data: bytes) -> Any:
    """Decode a single object from *data*, which must be a bytes object.

    Pure-Python counterpart of ``marshal.loads()``; code objects are
    returned as ``Code`` instances rather than real code objects.
    """
    assert isinstance(data, bytes)
    reader = Reader(data)
    obj = reader.r_object()
    return obj


def main():
    """Smoke-test loads() against marshal.dumps() output.

    First round-trips a nested dict/set/tuple sample and compares it with
    the original.  Then decodes this function's own code object and
    pretty-prints its attributes.
    """
    import marshal
    import pprint

    plain = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(plain))
    assert decoded == plain, decoded

    code_obj = main.__code__
    decoded = loads(marshal.dumps(code_obj))
    assert isinstance(decoded, Code), decoded
    pprint.pprint(decoded.__dict__)


# Run the self-test only when executed as a script.
if __name__ == "__main__":
    main()