chromium/mojo/public/tools/bindings/generators/mojom_rust_generator.py

# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import json
import pathlib

import mojom.generate.generator as generator
import mojom.generate.module as mojom
import mojom.generate.pack as pack
from mojom.generate.template_expander import UseJinja, UseJinjaForImportedTemplate

GENERATOR_PREFIX = 'rust'

_kind_to_rust_type = {
    mojom.BOOL: "bool",
    mojom.INT8: "i8",
    mojom.INT16: "i16",
    mojom.INT32: "i32",
    mojom.INT64: "i64",
    mojom.UINT8: "u8",
    mojom.UINT16: "u16",
    mojom.UINT32: "u32",
    mojom.UINT64: "u64",
    mojom.FLOAT: "f32",
    mojom.DOUBLE: "f64",
}


class Generator(generator.Generator):

  def __init__(self, *args, **kwargs):
    super(Generator, self).__init__(*args, **kwargs)

    # Lists other generated bindings this module depends on. Unlike C++
    # bindings, the generated Rust code must reference which GN target the
    # bindings come from. Each entry has the GN target (e.g. foo_mojom_rust)
    # and an arbitrary name to import it under.
    self._imports = []

    # Maps each mojom source file to the Rust name its GN target was imported
    # under.
    self._path_import_map = {}

  def _GetNameForKind(self, kind, is_data=False):
    ''' Get the full Rust type name to refer to a generated binding.

    Args:
      is_data: If true, get the name of the wire-format type.
    '''
    type_name = kind.name

    # If the type is defined in another type (e.g. an Enum defined in a
    # Struct), prepend the enclosing type name. Rust does not support nested
    # type definitions, unlike C++.
    if kind.parent_kind:
      type_name = f"{kind.parent_kind.name}_{kind.name}"

    if is_data:
      type_name += "_Data"

    # Use the name as is if it's defined in the current module.
    if kind.module is self.module:
      return type_name

    # Construct the fully qualified path to the type if it's defined in another
    # mojom module.
    source_name = pathlib.Path(kind.module.path).stem
    if kind.module.path in self._path_import_map:
      imp_name = self._path_import_map[kind.module.path]
      return f"{imp_name}::{source_name}::{type_name}"

    return f"crate::{source_name}::{type_name}"

  def _GetRustFieldType(self, kind):
    if mojom.IsNullableKind(kind):
      return f"Option<{self._GetRustFieldType(kind.MakeUnnullableKind())}>"
    if mojom.IsEnumKind(kind):
      return self._GetNameForKind(kind)
    if mojom.IsStructKind(kind) or mojom.IsUnionKind(kind):
      return f"Box<{self._GetNameForKind(kind)}>"
    if mojom.IsArrayKind(kind):
      return f"Vec<{self._GetRustFieldType(kind.kind)}>"
    if mojom.IsMapKind(kind):
      return (f"std::collections::HashMap<"
              f"{self._GetRustFieldType(kind.key_kind)}, "
              f"{self._GetRustFieldType(kind.value_kind)}>")
    if mojom.IsStringKind(kind):
      return "String"
    if mojom.IsAnyHandleKind(kind):
      return "::mojo::UntypedHandle"
    if mojom.IsReferenceKind(kind):
      return "usize"
    if kind not in _kind_to_rust_type:
      return "()"
    return _kind_to_rust_type[kind]

  def _GetRustDataFieldType(self, kind):
    if mojom.IsEnumKind(kind):
      return self._GetNameForKind(kind, is_data=True)
    if mojom.IsStructKind(kind):
      return (f"bindings::data::Pointer<"
              f"{self._GetNameForKind(kind, is_data=True)}>")
    if mojom.IsArrayKind(kind):
      return (f"bindings::data::Pointer<bindings::data::Array<"
              f"{self._GetRustDataFieldType(kind.kind)}>>")
    if mojom.IsStringKind(kind):
      return "bindings::data::Pointer<bindings::data::Array<u8>>"
    if mojom.IsMapKind(kind):
      return (f"bindings::data::Pointer<bindings::data::Map<"
              f"{self._GetRustDataFieldType(kind.key_kind)}, "
              f"{self._GetRustDataFieldType(kind.value_kind)}>>")
    if mojom.IsUnionKind(kind):
      return self._GetNameForKind(kind, is_data=True)
    if mojom.IsInterfaceKind(kind) or mojom.IsPendingRemoteKind(kind):
      return "bindings::data::InterfaceData"
    if mojom.IsPendingReceiverKind(kind):
      return "bindings::data::HandleRef"
    if mojom.IsPendingAssociatedRemoteKind(kind):
      return "bindings::data::InterfaceData"
    if mojom.IsPendingAssociatedReceiverKind(kind):
      return "bindings::data::HandleRef"
    if mojom.IsAnyHandleKind(kind):
      return "bindings::data::HandleRef"
    if kind not in _kind_to_rust_type:
      return "()"
    return _kind_to_rust_type[kind]

  def _GetRustUnionFieldType(self, kind):
    if kind == mojom.BOOL:
      return "u8"
    if mojom.IsUnionKind(kind):
      return (
          f"bindings::data::Pointer<{self._GetNameForKind(kind, is_data=True)}>"
      )
    return self._GetRustDataFieldType(kind)

  def _GetRustReferentDataType(self, kind):
    if mojom.IsStructKind(kind):
      return self._GetNameForKind(kind, is_data=True)
    else:
      print(kind.Repr())
      raise Exception("Not implemented")

  def _ToUpperSnakeCase(self, ident):
    return generator.ToUpperSnakeCase(ident)

  def _ToLowerSnakeCase(self, ident):
    return generator.ToLowerSnakeCase(ident)

  def _GetRustDataFields(self, packed_struct):
    ''' Map pack.PackedStruct to a list of Rust fields.

    Adjacent bitfield members are packed into u8 fields, and explicit padding is
    added so that the resulting type has no possibly uninitialized bits.
    '''
    rust_fields = []
    for i in range(len(packed_struct.packed_fields)):
      packed_field = packed_struct.packed_fields[i]

      # Compute padding needed, if any, between current and previous field.
      if i > 0:
        prev_pf = packed_struct.packed_fields[i - 1]
        pad_start = prev_pf.offset + prev_pf.size
        pad_size = packed_field.offset - pad_start
        if pad_size > 0:
          rust_fields.append({
              "name": f"_pad_{pad_start}",
              "type": f"[u8; {pad_size}]"
          })

      # Bitfields are packed together. Since Rust doesn't have C-style bitfields
      # this must be done manually. On the first bool bit at a given offset,
      # declare a unique u8 member. `packed_struct`'s fields are ordered by byte
      # and bit offset so skip fields other than the first bit.
      if packed_field.field.kind == mojom.BOOL:
        if packed_field.bit == 0:
          rust_fields.append({
              "name": f"_packed_bits_{packed_field.offset}",
              "type": "u8"
          })
        continue

      # Pick the field name. If `pf.original_field` is present, this is a
      # nullable primitive kind so its value component has a name which includes
      # a special character; use the name from the original field. Otherwise use
      # the name as-is.
      if packed_field.original_field:
        name = packed_field.original_field.name
      else:
        name = packed_field.field.name

      rust_fields.append({
          "name":
          name,
          "type":
          self._GetRustDataFieldType(packed_field.field.kind)
      })

    # Create end padding, if needed.
    if len(packed_struct.packed_fields) > 0:
      last_pf = packed_struct.packed_fields[-1]
      pad = pack.GetPad(last_pf.offset + last_pf.size, 8)
      if pad > 0:
        rust_fields.append({"name": "_pad_end", "type": f"[u8; {pad}]"})

    return rust_fields

  def _GetPackedBoolLocation(self, packed_field):
    return {
        "field_name": f"_packed_bits_{packed_field.offset}",
        "bit_offset": f"{packed_field.bit}"
    }

  @staticmethod
  def GetTemplatePrefix():
    return "rust_templates"

  def GetFilters(self):
    rust_filters = {
        "get_packed_bool_location": self._GetPackedBoolLocation,
        "get_pad": pack.GetPad,
        "get_rust_data_fields": self._GetRustDataFields,
        "is_enum_kind": mojom.IsEnumKind,
        "is_pointer_kind": mojom.IsPointerKind,
        "is_nullable_kind": mojom.IsNullableKind,
        "is_struct_kind": mojom.IsStructKind,
        "rust_field_type": self._GetRustFieldType,
        "rust_union_field_type": self._GetRustUnionFieldType,
        "rust_referent_data_type": self._GetRustReferentDataType,
        "to_upper_snake_case": self._ToUpperSnakeCase,
        "to_lower_snake_case": self._ToLowerSnakeCase,
    }
    return rust_filters

  @UseJinja("module.tmpl")
  def _GenerateModule(self):
    return {"module": self.module, "imports": self._imports}

  def GenerateFiles(self, unparsed_args):
    parser = argparse.ArgumentParser()
    parser.add_argument('--rust_dep_info')
    args = parser.parse_args(unparsed_args)

    # Load the list of mojom dependencies: each GN target and its list of
    # mojom source files.
    dep_info = []
    with open(args.rust_dep_info) as dep_info_json:
      dep_info = json.loads(dep_info_json.read())

    import_ndx = 1
    for dep in dep_info:
      use_name = f"dep{import_ndx}"
      self._imports.append({
          "target": dep["target_name"],
          "use_name": use_name,
      })

      for src in dep["mojom_sources"]:
        self._path_import_map[src] = use_name

      import_ndx += 1

    self.module.Stylize(generator.Stylizer())

    self.WriteWithComment(self._GenerateModule(), f"{self.module.path}.rs")