chromium/components/optimization_guide/tools/gen_on_device_proto_descriptors.py

#!/usr/bin/env python
# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Code generator for proto descriptors used for on-device model execution.

This script generates a C++ source file containing the proto descriptors.
"""
from __future__ import annotations

import dataclasses
import functools
from io import StringIO
import optparse
import os
import collections
import re
import sys

# Put the protobuf Python runtime from third_party/protobuf at the front of
# sys.path so that `descriptor_pb2` (imported below) can be found. _SRC_PATH
# is three directories above the directory containing this script.
_HERE_PATH = os.path.dirname(__file__)
_SRC_PATH = os.path.normpath(os.path.join(_HERE_PATH, '..', '..', '..'))
sys.path.insert(0, os.path.join(_SRC_PATH, 'third_party', 'protobuf',
                                'python'))

from google.protobuf import descriptor_pb2


class Error(Exception):
    """Raised by this generator when it meets input it cannot translate."""


class Type:
    """Integer aliases for the FieldDescriptorProto::Type enum values.

    The names are listed in enum order, so they take the consecutive values
    TYPE_DOUBLE (1) through TYPE_SINT64 (18).
    """
    (DOUBLE, FLOAT, INT64, UINT64, INT32, FIXED64, FIXED32, BOOL, STRING,
     GROUP, MESSAGE, BYTES, UINT32, ENUM, SFIXED32, SFIXED64, SINT32,
     SINT64) = range(1, 19)


@dataclasses.dataclass(frozen=True, eq=True)
class BaseValueType:
    """How generated C++ reads one kind of base::Value.

    Attributes:
        cpptype: C++ type of the value returned by the accessor, e.g.
            `std::optional<int>` or `std::string*`.
        getIfFn: Accessor suffix; the generated code calls
            `GetIf<getIfFn>()` on a base::Value.
    """
    cpptype: str
    getIfFn: str


class VType:
    """Base::Value types.

    Each entry pairs the C++ type yielded by a base::Value `GetIf*()`
    accessor with that accessor's name suffix.
    """
    DOUBLE = BaseValueType(cpptype="std::optional<double>", getIfFn="Double")
    BOOL = BaseValueType(cpptype="std::optional<bool>", getIfFn="Bool")
    INT = BaseValueType(cpptype="std::optional<int>", getIfFn="Int")
    STRING = BaseValueType(cpptype="std::string*", getIfFn="String")
    BLOB = BaseValueType(cpptype="BlobStorage*", getIfFn="Blob")
    DICT = BaseValueType(cpptype="Dict*", getIfFn="Dict")
    LIST = BaseValueType(cpptype="List*", getIfFn="List")


# Maps each FieldDescriptorProto type to the base::Value accessor that
# _ConvertValue uses when reading a non-message proto field out of a
# base::Value. Every type has an entry, but some are never looked up.
BASE_VALUE_TYPES = {
    Type.DOUBLE: VType.DOUBLE,
    Type.FLOAT: VType.DOUBLE,
    Type.INT64: VType.INT,
    Type.UINT64: VType.INT,
    Type.INT32: VType.INT,
    Type.FIXED64: VType.INT,
    Type.FIXED32: VType.INT,
    Type.BOOL: VType.BOOL,
    Type.STRING: VType.STRING,
    Type.GROUP: VType.STRING,  # Unused: _ConvertValue skips group fields.
    Type.MESSAGE: VType.DICT,  # Unused: message fields call Convert_*().
    Type.BYTES: VType.STRING,  # Bytes are read as a base::Value string.
    Type.UINT32: VType.INT,
    Type.ENUM: VType.INT,  # Unused: _ConvertValue skips enum fields.
    Type.SFIXED32: VType.INT,
    Type.SFIXED64: VType.INT,
    Type.SINT32: VType.INT,
    Type.SINT64: VType.INT,
}


@dataclasses.dataclass(frozen=True)
class Message:
    """A proto message type plus the naming context it is declared in.

    Attributes:
        desc: The message's DescriptorProto.
        package: Proto package of the declaring file; may be empty.
        parent_names: Names of the enclosing messages, outermost first, for
            nested message types.
    """
    desc: descriptor_pb2.DescriptorProto
    package: str
    parent_names: tuple[str, ...] = ()

    @functools.cached_property
    def type_name(self) -> str:
        """Returns the value returned for MessageLite::GetTypeName().

        This is the dotted full name without a leading '.', e.g.
        `pkg.Outer.Inner`, or just `Outer.Inner` when there is no package.
        """
        # An empty package must not contribute a leading '.'; otherwise the
        # name would not match GetTypeName() nor field type references.
        prefix = (self.package,) if self.package else ()
        return '.'.join((*prefix, *self.parent_names, self.desc.name))

    @functools.cached_property
    def cpp_name(self) -> str:
        """Returns the fully qualified c++ type name.

        Nested messages are flattened with '_' (e.g. `pkg::Outer_Inner`).
        """
        namespace = self.package.replace('.', '::')
        classname = '_'.join((*self.parent_names, self.desc.name))
        return f'{namespace}::{classname}'

    @functools.cached_property
    def iname(self) -> str:
        """Returns the identifier piece for generated function names."""
        return '_' + self.type_name.replace('.', '_')

    @functools.cached_property
    def fields(self) -> tuple[Field, ...]:
        """Returns a Field wrapper for each field of the message, in order."""
        return tuple(Field(fdesc) for fdesc in self.desc.field)


@dataclasses.dataclass(frozen=True)
class Field:
    """Read-only convenience view over one FieldDescriptorProto."""
    desc: descriptor_pb2.FieldDescriptorProto

    # FieldDescriptorProto::Label value that marks a repeated field.
    _LABEL_REPEATED = 3

    @property
    def tag_number(self):
        """The field number declared for this field."""
        return self.desc.number

    @property
    def name(self):
        """The field's declared name."""
        return self.desc.name

    @property
    def type(self):
        """The field's FieldDescriptorProto::Type value (see `Type`)."""
        return self.desc.type

    @property
    def is_repeated(self):
        """True when the field carries the repeated label."""
        return self.desc.label == self._LABEL_REPEATED

    @property
    def typename(self):
        """The referenced type name with each '.' replaced by '_'.

        For instance a reference to `.pkg.Msg` yields `_pkg_Msg`.
        """
        return '_'.join(self.desc.type_name.split('.'))


@dataclasses.dataclass()
class KnownMessages:
    """Registry of every message type found in the input descriptor sets.

    Messages are keyed by their full name with a leading '.', the same form
    used by FieldDescriptorProto.type_name, so field types can be looked up
    directly.
    """
    _known: dict[str, Message] = dataclasses.field(default_factory=dict)

    def _AddMessage(self, msg: Message) -> None:
        """Registers `msg` and, recursively, all of its nested types."""
        self._known['.' + msg.type_name] = msg
        for nested_type in msg.desc.nested_type:
            self._AddMessage(
                Message(desc=nested_type,
                        package=msg.package,
                        parent_names=(*msg.parent_names, msg.desc.name)))

    def AddFileDescriptorSet(self,
                             fds: descriptor_pb2.FileDescriptorSet) -> None:
        """Registers every message (including nested ones) in `fds`."""
        for f in fds.file:
            for m in f.message_type:
                self._AddMessage(Message(desc=m, package=f.package))

    def GetMessages(self, message_types: set[str]) -> list[Message]:
        """Returns the messages named in `message_types`, sorted by name.

        Raises:
            KeyError: if a name is not a registered message.
        """
        return [self._known[t] for t in sorted(message_types)]

    def GetAllTransitiveDeps(self, message_types: set[str]) -> list[Message]:
        """Returns `message_types` plus every message type reachable from them.

        Reachability follows message-typed fields, singular or repeated. The
        result is sorted by name. `message_types` itself is not modified.

        Raises:
            KeyError: if any visited name is not a registered message.
        """
        # Copy: the original aliased the caller's set and grew it in place.
        seen = set(message_types)
        stack = list(message_types)
        while stack:
            msg = self._known[stack.pop()]
            field_types = {
                field.desc.type_name
                for field in msg.fields if field.type == Type.MESSAGE
            }
            stack.extend(field_types - seen)
            seen.update(field_types)
        return self.GetMessages(seen)


def GenerateProtoDescriptors(out, includes: set[str], messages: KnownMessages,
                             requests: set[str], responses: set[str]):
    """Generate the on_device_model_execution_proto_descriptors.cc content.

    Args:
        out: Text stream that receives the generated C++ source.
        includes: Extra `#include` targets, each already quoted exactly as it
            should appear after `#include`. Not modified.
        messages: Registry of every known message type.
        requests: '.'-prefixed full names of request messages. They and their
            transitive message dependencies get the read-side helpers
            (GetProtoValue, GetProtoRepeated, GetProtoFromAny,
            NestedMessageIterator::Get). Not modified.
        responses: '.'-prefixed full names of response messages. They and
            their transitive message dependencies get the write-side helpers
            (SetProtoValue, ConvertToAnyWrappedProto). Not modified.
    """
    # Work on copies so none of the caller's sets are mutated.
    readable_messages = messages.GetAllTransitiveDeps(set(requests))
    writable_messages = messages.GetAllTransitiveDeps(set(responses))
    all_includes = set(includes)
    all_includes.add('"base/values.h"')

    out.write(
        '// DO NOT MODIFY. GENERATED BY gen_on_device_proto_descriptors.py\n')
    out.write('\n')

    out.write(
        '#include "components/optimization_guide/core/model_execution/on_device_model_execution_proto_descriptors.h"\n'  # pylint: disable=line-too-long
        '#include "components/optimization_guide/core/optimization_guide_util.h"\n'  # pylint: disable=line-too-long
    )
    out.write('\n')

    for include in sorted(all_includes):
        out.write(f'#include {include}\n')
    out.write('\n')

    out.write('namespace optimization_guide {\n')
    out.write('\n')
    out.write('namespace {\n')
    _GetProtoValue.GenPrivate(out, readable_messages)
    _GetProtoRepeated.GenPrivate(out, readable_messages)
    _SetProtoValue.GenPrivate(out, writable_messages)
    _ConvertValue.GenPrivate(out, writable_messages)
    out.write('}  // namespace\n\n')
    _GetProtoValue.GenPublic(out)
    _GetProtoRepeated.GenPublic(out)
    _GetProtoFromAny.GenPublic(out, readable_messages)
    _SetProtoValue.GenPublic(out)
    _NestedMessageIteratorGet.GenPublic(out, readable_messages)
    _ConvertValue.GenPublic(out, writable_messages)
    out.write("""\
      NestedMessageIterator::NestedMessageIterator(
            const google::protobuf::MessageLite* parent,
            int32_t tag_number,
            int32_t field_size,
            int32_t offset) :
          parent_(parent),
          tag_number_(tag_number),
          field_size_(field_size),
          offset_(offset) {}
      """)
    out.write('}  // namespace optimization_guide\n')
    out.write('\n')
    out.write('\n')


class _GetProtoValue:
    """Namespace class for GetProtoValue method builders."""

    @classmethod
    def GenPublic(cls, out):
        out.write("""
          std::optional<proto::Value> GetProtoValue(
              const google::protobuf::MessageLite& msg,
              const proto::ProtoField& proto_field) {
            return GetProtoValue(msg, proto_field, /*index=*/0);
          }
        """)

    @classmethod
    def GenPrivate(cls, out, messages: list[Message]):
        out.write("""
          std::optional<proto::Value> GetProtoValue(
              const google::protobuf::MessageLite& msg,
              const proto::ProtoField& proto_field, int32_t index) {
            if (index >= proto_field.proto_descriptors_size()) {
              return std::nullopt;
            }
            int32_t tag_number =
                proto_field.proto_descriptors(index).tag_number();
        """)

        for msg in messages:
            cls._IfMsg(out, msg)
        out.write('return std::nullopt;\n')
        out.write('}\n\n')  # End function

    @classmethod
    def _IfMsg(cls, out, msg: Message):
        if all(field.is_repeated for field in msg.fields):
            # Omit the empty case to avoid unused variable warnings.
            return
        out.write(f'if (msg.GetTypeName() == "{msg.type_name}") {{\n')
        out.write(f'const {msg.cpp_name}& casted_msg = ')
        out.write(f'  static_cast<const {msg.cpp_name}&>(msg);\n')
        out.write('switch (tag_number) {\n')
        for field in msg.fields:
            if field.is_repeated:
                continue
            cls._FieldCase(out, field)
        out.write('}\n')  # End switch
        out.write('}\n\n')  # End if statement

    @classmethod
    def _FieldCase(cls, out, field: Field):
        out.write(f'case {field.tag_number}: {{\n')
        name = f'casted_msg.{field.name}()'
        if field.type == Type.MESSAGE:
            out.write(f'return GetProtoValue({name}, proto_field, index+1);\n')
        else:
            out.write('proto::Value value;\n')
            if field.type in {Type.DOUBLE, Type.FLOAT}:
                out.write(
                    f'value.set_float_value(static_cast<double>({name}));\n')
            elif field.type in {Type.INT64, Type.UINT64}:
                out.write(
                    f'value.set_int64_value(static_cast<int64_t>({name}));\n')
            elif field.type in {Type.INT32, Type.UINT32, Type.ENUM}:
                out.write(
                    f'value.set_int32_value(static_cast<int32_t>({name}));\n')
            elif field.type in {Type.BOOL}:
                out.write(f'value.set_boolean_value({name});\n')
            elif field.type in {Type.STRING}:
                out.write(f'value.set_string_value({name});\n')
            else:
                raise Error()
            out.write('return value;\n')
        out.write('}\n')  # End case


class _GetProtoFromAny:
    """Namespace class for GetProtoFromAny method builders."""

    @classmethod
    def GenPublic(cls, out, messages: list[Message]):
        out.write("""
          std::unique_ptr<google::protobuf::MessageLite> GetProtoFromAny(
              const proto::Any& msg) {
        """)

        for msg in messages:
            cls._IfMsg(out, msg)
        out.write('return nullptr;\n')
        out.write('}\n\n')  # End function

    @classmethod
    def _IfMsg(cls, out, msg: Message):
        out.write(f"""if (msg.type_url() ==
                    "type.googleapis.com/{msg.type_name}") {{
            """)
        out.write(
            f'auto casted_msg = ParsedAnyMetadata<{msg.cpp_name}>(msg);\n')
        out.write("""
            std::unique_ptr<google::protobuf::MessageLite> copy(
                casted_msg->New());\n
        """)
        out.write('copy->CheckTypeAndMergeFrom(*casted_msg);\n')
        out.write('return copy;\n')
        out.write('}\n\n')  # End if statement


class _NestedMessageIteratorGet:
    """Namespace class for NestedMessageIterator::Get method builders."""

    @classmethod
    def GenPublic(cls, out, messages: list[Message]):
        out.write('const google::protobuf::MessageLite* '
                  'NestedMessageIterator::Get() const {\n')
        for msg in messages:
            cls._IfMsg(out, msg)
        out.write('  NOTREACHED_IN_MIGRATION();\n')
        out.write('  return nullptr;\n')
        out.write('}\n')

    @classmethod
    def _IfMsg(cls, out, msg: Message):
        out.write(f'if (parent_->GetTypeName() == "{msg.type_name}") {{\n')
        out.write('switch (tag_number_) {\n')
        for field in msg.fields:
            if field.type == Type.MESSAGE and field.is_repeated:
                cls._FieldCase(out, msg, field)
        out.write('}\n')  # End switch
        out.write('}\n\n')  # End if statement

    @classmethod
    def _FieldCase(cls, out, msg: Message, field: Field):
        cast_msg = f'static_cast<const {msg.cpp_name}*>(parent_)'
        out.write(f'case {field.tag_number}: {{\n')
        out.write(f'return &{cast_msg}->{field.name}(offset_);\n')
        out.write('}\n')  # End case


class _GetProtoRepeated:
    """Namespace class for GetProtoRepeated method builders."""

    @classmethod
    def GenPublic(cls, out):
        out.write("""
          std::optional<NestedMessageIterator> GetProtoRepeated(
              const google::protobuf::MessageLite* msg,
              const proto::ProtoField& proto_field) {
            return GetProtoRepeated(msg, proto_field, /*index=*/0);
          }
          """)

    @classmethod
    def GenPrivate(cls, out, messages: list[Message]):
        out.write("""\
          std::optional<NestedMessageIterator> GetProtoRepeated(
              const google::protobuf::MessageLite* msg,
              const proto::ProtoField& proto_field,
              int32_t index) {
            if (index >= proto_field.proto_descriptors_size()) {
              return std::nullopt;
            }
            int32_t tag_number =
                proto_field.proto_descriptors(index).tag_number();
          """)

        for msg in messages:
            cls._IfMsg(out, msg)
        out.write('return std::nullopt;\n')
        out.write('}\n\n')  # End function

    @classmethod
    def _IfMsg(cls, out, msg: Message):
        out.write(f'if (msg->GetTypeName() == "{msg.type_name}") {{\n')
        out.write('switch (tag_number) {\n')
        for field in msg.fields:
            if field.type == Type.MESSAGE:
                cls._FieldCase(out, msg, field)
        out.write('}\n')  # End switch
        out.write('}\n\n')  # End if statement

    @classmethod
    def _FieldCase(cls, out, msg: Message, field: Field):
        field_expr = f'static_cast<const {msg.cpp_name}*>(msg)->{field.name}()'
        out.write(f'case {field.tag_number}: {{\n')
        if field.is_repeated:
            out.write(f'return NestedMessageIterator('
                      f'msg, tag_number, {field_expr}.size(), 0);\n')
        else:
            out.write(f'return GetProtoRepeated('
                      f'&{field_expr}, proto_field, index+1);\n')
        out.write('}\n')  # End case


class _SetProtoValue:
    """Namespace class for SetProtoValue method builders."""

    @classmethod
    def GenPublic(cls, out):
        out.write("""
      std::optional<proto::Any> SetProtoValue(
          const std::string& proto_name,
          const proto::ProtoField& proto_field,
          const std::string& value) {
        return SetProtoValue(proto_name, proto_field, value, /*index=*/0);
      }
    """)

    @classmethod
    def GenPrivate(cls, out, messages: list[Message]):
        out.write("""
      std::optional<proto::Any> SetProtoValue(
          const std::string& proto_name,
          const proto::ProtoField& proto_field,
          const std::string& value,
          int32_t index) {
        if (index >= proto_field.proto_descriptors_size()) {
          return std::nullopt;
        }
    """)
        for msg in messages:
            cls._IfMsg(out, msg)
        out.write("""
        return std::nullopt;
      }
    """)

    @classmethod
    def _IfMsg(cls, out, msg: Message):
        out.write(f'if (proto_name == "{msg.type_name}") {{\n')
        out.write(
            'switch(proto_field.proto_descriptors(index).tag_number()) {\n')
        for field in msg.fields:
            cls._FieldCase(out, msg, field)
        out.write("""
      default:
        return std::nullopt;\n
      """)
        out.write('}')
        out.write('}\n')  # End if statement

    @classmethod
    def _FieldCase(cls, out, msg: Message, field: Field):
        if field.type == Type.STRING:
            out.write(f'case {field.tag_number}: {{\n')
            out.write('proto::Any any;\n')
            out.write(
                f'any.set_type_url("type.googleapis.com/{msg.type_name}");\n')
            out.write(f'{msg.cpp_name} response_value;\n')
            out.write(f'response_value.set_{field.name}(value);')
            out.write('response_value.SerializeToString(any.mutable_value());')
            out.write('return any;')
            out.write('}\n')


class _ConvertValue:
    """Namespace class for base::Value->Message method builders."""

    @classmethod
    def GenPublic(cls, out, messages: list[Message]):
        out.write(f"""
          std::optional<proto::Any> ConvertToAnyWrappedProto(
              const base::Value& object, const std::string& type_name) {{
            proto::Any any;
            any.set_type_url("type.googleapis.com/" + type_name);
        """)
        for msg in messages:
            out.write(f"""
            if (type_name == "{msg.type_name}") {{
              {msg.cpp_name} msg;
              if (Convert{msg.iname}(object, msg)) {{
                msg.SerializeToString(any.mutable_value());
                return any;
              }}
            }}
          """)

        out.write(f"""
            return std::nullopt;
          }}
        """)

    @classmethod
    def GenPrivate(cls, out, messages: list[Message]):
        for msg in messages:
            out.write(f"""
            bool Convert{msg.iname}(
                const base::Value& object, {msg.cpp_name}& proto);
          """)
        for msg in messages:
            cls._DefineConvert(out, msg)

    @classmethod
    def _DefineConvert(cls, out, msg: Message):
        out.write(f"""
          bool Convert{msg.iname}(
              const base::Value& object, {msg.cpp_name}& proto) {{
            const base::Value::Dict* asdict = object.GetIfDict();
            if (!asdict) {{
              return false;
            }}
        """)
        for field in msg.fields:
            if field.type == Type.GROUP:
                continue
            if field.type == Type.ENUM:
                continue
            out.write('if (const base::Value* field_value =\n')
            out.write(f'     asdict->Find("{field.desc.json_name}")) {{')
            cls._FieldCase(out, msg, field)
            out.write(f'}}')
        out.write(f"""
            return true;
          }}
        """)

    @classmethod
    def _FieldCase(cls, out, msg: Message, field: Field):
        if field.is_repeated:
            out.write(f"""
              const auto* lst = field_value->GetIfList();
              if (!lst) {{
                return false;
              }}
              for (const base::Value& entry_value : *lst) {{
            """)
            if field.type == Type.MESSAGE:
                out.write(f"""
                  if (!Convert{field.typename}(
                      entry_value, *proto.add_{field.name}())) {{
                    return false;
                  }}
                """)
            else:
                vtype = BASE_VALUE_TYPES[field.type]
                out.write(f"""
                  const {vtype.cpptype} v = entry_value.GetIf{vtype.getIfFn}();
                  if (!v) {{
                    return false;
                  }}
                  proto.add_{field.name}(*v);
                """)
            out.write("}")  # end for loop
        else:
            if field.type == Type.MESSAGE:
                out.write(f"""
                  if (!Convert{field.typename}(
                      *field_value, *proto.mutable_{field.name}())) {{
                    return false;
                  }}
                """)
                return
            else:
                vtype = BASE_VALUE_TYPES[field.type]
                out.write(f"""
                  const {vtype.cpptype} v = field_value->GetIf{vtype.getIfFn}();
                  if (!v) {{
                    return false;
                  }}
                  proto.set_{field.name}(*v);
                """)


def main(argv):
    """Command-line entry point.

    Options:
      --input_file: Serialized FileDescriptorSet to load; repeatable.
      --output_cc: Path of the .cc file to write; stdout when omitted.
      --include: Extra #include target for the generated file; repeatable.
      --request: '.'-prefixed full name of a request message; repeatable.
      --response: '.'-prefixed full name of a response message; repeatable.

    Returns:
        0 on success; errors propagate as exceptions.
    """
    parser = optparse.OptionParser()
    parser.add_option('--input_file', action='append', default=[])
    parser.add_option('--output_cc')
    parser.add_option('--include', action='append', default=[])
    parser.add_option('--request', action='append', default=[])
    parser.add_option('--response', action='append', default=[])
    options, _ = parser.parse_args(argv)

    messages = KnownMessages()
    for input_file in options.input_file:
        fds = descriptor_pb2.FileDescriptorSet()
        with open(input_file, 'rb') as fp:
            fds.ParseFromString(fp.read())
        messages.AddFileDescriptorSet(fds)

    out_cc_str = StringIO()
    GenerateProtoDescriptors(out_cc_str, set(options.include), messages,
                             set(options.request), set(options.response))
    content = out_cc_str.getvalue().encode('utf-8')

    # Touch the output only after generation succeeded, so a failure never
    # leaves a truncated file behind; `with` guarantees it gets closed.
    if options.output_cc:
        with open(options.output_cc, 'wb') as out_cc:
            out_cc.write(content)
    else:
        getattr(sys.stdout, 'buffer', sys.stdout).write(content)

    return 0


if __name__ == '__main__':
    # Pass everything after the program name to main() and use its return
    # value as the process exit status.
    sys.exit(main(argv=sys.argv[1:]))