# Implementation of marshal.loads() in pure Python
import ast
from typing import Any, Tuple
class Type:
    """One-byte type codes of the marshal format (adapted from marshal.c).

    A code byte read from the stream may additionally carry FLAG_REF in its
    high bit; these constants are the values with that bit cleared.
    """

    # Singletons and sentinels
    NULL = ord('0')
    NONE = ord('N')
    FALSE = ord('F')
    TRUE = ord('T')
    STOPITER = ord('S')
    ELLIPSIS = ord('.')

    # Numbers
    INT = ord('i')
    INT64 = ord('I')
    LONG = ord('l')
    FLOAT = ord('f')
    BINARY_FLOAT = ord('g')
    COMPLEX = ord('x')
    BINARY_COMPLEX = ord('y')

    # Text and bytes (STRING decodes to bytes, the others to str)
    STRING = ord('s')
    INTERNED = ord('t')
    UNICODE = ord('u')
    ASCII = ord('a')
    ASCII_INTERNED = ord('A')
    SHORT_ASCII = ord('z')
    SHORT_ASCII_INTERNED = ord('Z')

    # Containers
    TUPLE = ord('(')
    SMALL_TUPLE = ord(')')
    LIST = ord('[')
    DICT = ord('{')
    SET = ord('<')
    FROZENSET = ord('>')

    # Code objects, back-references, and the "unmarshallable" marker
    CODE = ord('c')
    REF = ord('r')
    UNKNOWN = ord('?')
# High bit of a type-code byte: when set, the decoded object is appended to
# the reader's reference table so later Type.REF codes can point back at it.
FLAG_REF = 1 << 7

# Sentinel returned for Type.NULL (used, e.g., as the end-of-dict marker);
# deliberately distinct from None, which marshal encodes as Type.NONE.
NULL = object()

# Cell kinds: per-slot bit flags, matched against a code object's
# co_localspluskinds to classify each entry of co_localsplusnames.
CO_FAST_LOCAL = 1 << 5
CO_FAST_CELL = 1 << 6
CO_FAST_FREE = 1 << 7
class Code:
    """Attribute bag standing in for a code object decoded from marshal data.

    Any keyword arguments become instance attributes.  The derived views
    (co_varnames, co_cellvars, co_freevars, co_nlocals) are computed on
    demand from co_localsplusnames and co_localspluskinds.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Parallel sequences: one name and one CO_FAST_* kind bitmask per slot.
    # CPython marshals the kinds as a bytes object (iterating it yields
    # ints); any sequence of ints works here.
    co_localsplusnames: Tuple[str, ...]
    co_localspluskinds: bytes

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return the names whose kind shares any bit with *select_kind*.

        Names are returned in slot order.
        """
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of CO_FAST_LOCAL slots."""
        return len(self.co_varnames)
class Reader:
    """Decode a marshal byte stream into Python objects.

    A fairly literal translation of the reader in CPython's marshal.c.
    Code objects are decoded into ``Code`` attribute bags (see the
    ``Type.CODE`` branch of ``_r_object``).
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        # Objects registered for back-references (Type.REF), indexed in the
        # order they were flagged with FLAG_REF.
        self.refs: list[Any] = []
        # Nesting depth of the object currently being read (debugging aid).
        self.level: int = 0

    def r_string(self, n: int) -> bytes:
        """Return the next *n* bytes and advance past them.

        Raises:
            ValueError: if *n* is negative (corrupt size field).
            EOFError: if fewer than *n* bytes remain.
        """
        # Explicit checks rather than ``assert`` so corrupt input is still
        # rejected under ``python -O``; a negative *n* would otherwise move
        # the read position backwards.
        if n < 0:
            raise ValueError("bad marshal data (negative size)")
        if n > self.end - self.pos:
            raise EOFError("marshal data too short")
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Read one unsigned byte."""
        return self.r_string(1)[0]

    def r_short(self) -> int:
        """Read a signed 16-bit little-endian integer."""
        return int.from_bytes(self.r_string(2), "little", signed=True)

    def r_long(self) -> int:
        """Read a signed 32-bit little-endian integer."""
        return int.from_bytes(self.r_string(4), "little", signed=True)

    def r_long64(self) -> int:
        """Read a signed 64-bit little-endian integer."""
        return int.from_bytes(self.r_string(8), "little", signed=True)

    def r_PyLong(self) -> int:
        """Read an arbitrary-precision int (Type.LONG).

        Layout: a signed 32-bit digit count whose sign is the sign of the
        result, followed by that many 15-bit digits, least significant
        first, each stored as a 16-bit short.

        Raises:
            ValueError: if a digit does not fit in 15 bits.
        """
        n = self.r_long()
        x = 0
        for i in range(abs(n)):
            digit = self.r_short()
            if not 0 <= digit < (1 << 15):
                raise ValueError("bad marshal data (digit out of range in long)")
            x |= digit << (i * 15)
        return -x if n < 0 else x

    def r_float_bin(self) -> float:
        """Read an 8-byte IEEE 754 double (Type.BINARY_FLOAT)."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # marshal stores doubles little-endian regardless of host byte order,
        # so the byte order must be explicit ("<"), not native.
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Read a float stored as length-prefixed ASCII text (Type.FLOAT)."""
        n = self.r_byte()
        buf = self.r_string(n)
        # float() (unlike ast.literal_eval) accepts "inf"/"nan" and always
        # yields a float, even for integral text such as "1".
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """If *flag* is set, reserve a slot in ``refs`` and return its index.

        Used for objects that are only complete after their contents have
        been read, so they keep the reference index assigned before those
        contents.  Returns 0 (ignored by r_ref_insert) when *flag* is clear.
        """
        if not flag:
            return 0
        self.refs.append(None)
        return len(self.refs) - 1

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Store *obj* in the slot reserved at *idx* (if *flag*); return *obj*."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to the reference table and return it."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Decode and return the next object, restoring the nesting level."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Decode one object, dispatching on its type-code byte."""
        code = self.r_byte()
        flag = code & FLAG_REF
        typ = code & ~FLAG_REF
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Register obj for later back-references when the writer flagged it.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if typ == Type.NULL:
            return NULL
        elif typ == Type.NONE:
            return None
        elif typ == Type.STOPITER:
            return StopIteration
        elif typ == Type.ELLIPSIS:
            return Ellipsis
        elif typ == Type.FALSE:
            return False
        elif typ == Type.TRUE:
            return True
        elif typ == Type.INT:
            return R_REF(self.r_long())
        elif typ == Type.INT64:
            return R_REF(self.r_long64())
        elif typ == Type.LONG:
            return R_REF(self.r_PyLong())
        elif typ == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif typ == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif typ == Type.COMPLEX:
            # Real part is stored first; arguments evaluate left to right.
            return R_REF(complex(self.r_float_str(), self.r_float_str()))
        elif typ == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(), self.r_float_bin()))
        elif typ == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif typ == Type.ASCII_INTERNED or typ == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.SHORT_ASCII_INTERNED or typ == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.INTERNED or typ == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif typ == Type.SMALL_TUPLE or typ == Type.TUPLE:
            n = self.r_byte() if typ == Type.SMALL_TUPLE else self.r_long()
            # A tuple can only be built after its items are read, so reserve
            # its reference slot first and fill it afterwards.
            idx = self.r_ref_reserve(flag)
            retval: Any = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif typ == Type.DICT:
            retval = R_REF({})
            # Key/value pairs are terminated by a NULL key.
            while True:
                key = self.r_object()
                if key is NULL:
                    break
                retval[key] = self.r_object()
            return retval
        elif typ == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                retval.add(self.r_object())
            return retval
        elif typ == Type.FROZENSET:
            n = self.r_long()
            idx = self.r_ref_reserve(flag)
            items: set[Any] = set()
            for _ in range(n):
                items.add(self.r_object())
            retval = frozenset(items)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.CODE:
            retval = R_REF(Code())
            # Field order must match the writer in marshal.c for the
            # interpreter version that produced the data.
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif typ == Type.REF:
            n = self.r_long()
            # Reject negative indices (Python indexing would silently count
            # from the end) and slots still reserved for an unfinished object.
            if not 0 <= n < len(self.refs) or self.refs[n] is None:
                raise ValueError("bad marshal data (invalid reference)")
            return self.refs[n]
        else:
            raise AssertionError(f"Unknown type {typ} {chr(typ)!r}")
def loads(data: bytes) -> Any:
    """Decode marshal-format *data* and return the first object it encodes.

    Pure-Python counterpart of ``marshal.loads``; code objects come back as
    ``Code`` instances.

    Args:
        data: a bytes-like object (``bytes``, ``bytearray``, ``memoryview``,
            ...) containing marshal data.

    Raises:
        TypeError: if *data* does not support the buffer protocol.
    """
    if not isinstance(data, bytes):
        # Accept any bytes-like object, as marshal.loads does.  memoryview()
        # raises TypeError for non-buffer objects such as str or int, so this
        # is a real check (unlike an assert, which vanishes under -O).
        data = memoryview(data).tobytes()
    return Reader(data).r_object()
def main():
    """Self-test: round-trip sample objects through marshal.dumps and loads()."""
    import marshal
    import pprint

    # Plain data: a dict holding a set of one tuple of int, str and float.
    plain = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(plain))
    assert decoded == plain, decoded

    # A real code object (this function's own) must decode into a Code bag.
    code_obj = main.__code__
    decoded = loads(marshal.dumps(code_obj))
    assert isinstance(decoded, Code), decoded
    pprint.pprint(decoded.__dict__)
# Run the self-test only when executed as a script, not on import.
if __name__ == "__main__":
    main()