# llvm/llvm/lib/Analysis/models/log_reader.py

"""Reader for training log.

See lib/Analysis/TrainingLogger.cpp for a description of the format.
"""
import ctypes
import dataclasses
import io
import json
import math
import sys
from typing import List, Optional

# Maps the element type names accepted in a tensor spec's "type" field to the
# ctypes type used to size and decode the raw bytes of that tensor.
_element_types = {
    "float": ctypes.c_float,
    "double": ctypes.c_double,
    "int8_t": ctypes.c_int8,
    "uint8_t": ctypes.c_uint8,
    "int16_t": ctypes.c_int16,
    "uint16_t": ctypes.c_uint16,
    "int32_t": ctypes.c_int32,
    "uint32_t": ctypes.c_uint32,
    "int64_t": ctypes.c_int64,
    "uint64_t": ctypes.c_uint64,
}


@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Immutable description of one tensor stored in a training log.

    Attributes:
        name: The tensor's name.
        port: The tensor's port.
        shape: The tensor's dimensions.
        element_type: The ctypes type of a single element.
    """

    name: str
    port: int
    shape: List[int]
    element_type: type

    @staticmethod
    def from_dict(d: dict):
        """Build a TensorSpec from a decoded JSON dictionary.

        Args:
            d: A dict with "name", "port", "shape" and "type" entries. Each
                "shape" entry is converted with int().

        Returns:
            The corresponding TensorSpec.

        Raises:
            KeyError: If one of the required entries is missing.
            ValueError: If the "type" entry names an unsupported element type.
        """
        tensor_name, tensor_port = d["name"], d["port"]
        dims = [int(dim) for dim in d["shape"]]
        type_name = d["type"]
        # Every value in the table is a ctypes class, so None can only mean
        # that the name is not in the table.
        ctype = _element_types.get(type_name)
        if ctype is None:
            raise ValueError(f"uknown type: {type_name}")
        return TensorSpec(
            name=tensor_name, port=tensor_port, shape=dims, element_type=ctype
        )


class TensorValue:
    """Read-only, flat view over the raw bytes of one tensor.

    Elements are exposed in storage order through len() and integer indexing.
    Iterating the object visits indices 0 .. len() - 1.
    """

    def __init__(self, spec: TensorSpec, buffer: bytes):
        """Wrap `buffer` as the tensor described by `spec`.

        Args:
            spec: Gives the element type and shape of the tensor.
            buffer: The raw bytes holding the tensor's elements.

        Raises:
            ValueError: If `buffer` holds fewer bytes than `spec` requires.
        """
        self._spec = spec
        self._buffer = buffer
        self._len = math.prod(spec.shape)
        required = self._len * ctypes.sizeof(spec.element_type)
        # The view below is a raw pointer into `buffer`: indexing it never
        # checks the real size of the buffer. Reject short buffers here
        # instead of reading past their end later.
        if len(buffer) < required:
            raise ValueError(
                f"Buffer for tensor {spec.name!r} holds {len(buffer)} bytes, "
                f"but its spec requires {required}"
            )
        self._view = ctypes.cast(buffer, ctypes.POINTER(spec.element_type))

    def spec(self) -> TensorSpec:
        """Return the spec this value was created with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements, i.e. the product of the shape."""
        return self._len

    def __getitem__(self, index):
        """Return the element at flat position `index`.

        Only indices in [0, len(self)) are accepted; negative indices are
        rejected rather than counted from the end.

        Raises:
            IndexError: If `index` is outside [0, len(self)).
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]


def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
    """Read the raw bytes of one tensor from `fs`.

    Args:
        fs: Binary stream positioned at the first byte of the tensor.
        ts: Spec giving the tensor's shape and element type.

    Returns:
        A TensorValue wrapping the bytes that were read.

    Raises:
        EOFError: If the stream ends before the whole tensor has been read.
    """
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    # read() returns fewer bytes than requested when the stream ends early.
    # TensorValue indexes its buffer through a raw ctypes pointer, so a short
    # buffer must be rejected here rather than passed on.
    if len(data) != size:
        raise EOFError(
            f"Truncated log: tensor {ts.name!r} needs {size} bytes, "
            f"only {len(data)} available"
        )
    return TensorValue(ts, data)


def pretty_print_tensor_value(tv: TensorValue):
    """Print a tensor as `<name>: v0,v1,...` on a single line.

    Args:
        tv: The tensor to print. Its name comes from `tv.spec()` and its
            elements from iterating over `tv`.
    """
    label = tv.spec().name
    rendered = ",".join(str(element) for element in tv)
    print(f"{label}: {rendered}")


def read_header(f: io.BufferedReader):
    """Parse the JSON header line at the current position of the log.

    Args:
        f: The log stream. Exactly one line is consumed from it.

    Returns:
        A tuple `(tensor_specs, score_spec, advice_spec)`: the list of
        feature specs from the "features" entry, followed by the specs from
        the "score" and "advice" entries, each None when the header has no
        such entry.
    """
    header = json.loads(f.readline())

    def optional_spec(key):
        # "score" and "advice" may be absent from the header.
        if key not in header:
            return None
        return TensorSpec.from_dict(header[key])

    feature_specs = [TensorSpec.from_dict(entry) for entry in header["features"]]
    return feature_specs, optional_spec("score"), optional_spec("advice")


def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List[TensorSpec],
    score_spec: Optional[TensorSpec],
):
    """Read one observation, and its score if the log has scores.

    Args:
        context: The context in effect before this event. It is returned
            unchanged unless the event carries a "context" entry.
        event_str: The already-read JSON line that starts the event.
        f: The log stream, positioned right after `event_str`.
        tensor_specs: Specs of the feature tensors, in stored order.
        score_spec: Spec of the score tensor, or None if no score follows
            the features.

    Returns:
        A tuple `(context, observation_id, features, score)`. `features` is a
        list with one tensor per entry of `tensor_specs`; `score` is None
        when `score_spec` is None.

    Raises:
        ValueError: If the "outcome" id of the score record differs from the
            id of the observation it follows.
    """
    event = json.loads(event_str)
    if "context" in event:
        # A context record only switches context; the observation header is
        # on the line after it.
        context = event["context"]
        event = json.loads(f.readline())
    observation_id = int(event["observation"])
    features = [read_tensor(f, ts) for ts in tensor_specs]
    # Consume the rest of the line that follows the tensor bytes.
    f.readline()
    score = None
    if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header["outcome"])
        # Checked explicitly instead of with `assert`: assertions are removed
        # under `python -O`, which would let a mismatched score through.
        if outcome_id != observation_id:
            raise ValueError(
                f"Score record is for outcome {outcome_id}, "
                f"expected observation {observation_id}"
            )
        score = read_tensor(f, score_spec)
        f.readline()
    return context, observation_id, features, score


def read_stream(fname: str):
    """Lazily yield every observation stored in the log file `fname`.

    The file is opened when the first item is requested and closed once the
    generator finishes or is closed.

    Args:
        fname: Path of the training log to read.

    Yields:
        Tuples `(context, observation_id, features, score)`, one per
        observation, in file order.
    """
    raw_file = io.FileIO(fname, "rb")
    with io.BufferedReader(raw_file) as stream:
        feature_specs, score_spec, _advice_spec = read_header(stream)
        current_context = None
        # readline() returns b"" only at end of file, which ends the loop.
        for event_line in iter(stream.readline, b""):
            current_context, observation_id, features, score = read_one_observation(
                current_context, event_line, stream, feature_specs, score_spec
            )
            yield current_context, observation_id, features, score


def main(args):
    """Print every observation of a training log in readable form.

    A `context: ...` line is printed whenever the context differs from that
    of the previous observation. Each observation is then printed as an
    `observation: <id>` line, one line per feature tensor and, if present,
    one line for the score.

    Args:
        args: Command-line arguments; `args[1]` is the path of the log file.
    """
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # Compare against None explicitly: the truth value of a TensorValue
        # comes from its length, so a present but empty score would be
        # skipped by a plain truthiness test.
        if score is not None:
            pretty_print_tensor_value(score)


# Script entry point: print the log file named by the first command-line
# argument.
if __name__ == "__main__":
    main(sys.argv)