chromium/third_party/jni_zero/java_types.py

# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import dataclasses
from typing import Dict
from typing import Optional
from typing import Tuple

import java_lang_classes

# JNI descriptor character for each Java primitive type (including void).
_DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE = dict(
    zip(('boolean', 'byte', 'char', 'double', 'float', 'int', 'long', 'short',
         'void'), 'ZBCDFIJSV'))

# Inverse of _DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE, used when parsing descriptors.
_PRIMITIVE_TYPE_BY_DESCRIPTOR_CHAR = {
    char: primitive
    for primitive, char in _DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE.items()
}

# C++ (JNI) spelling for Java types that have a dedicated JNI typedef.
# Primitives map to j<name>, void stays void; other reference types are not
# listed here (JavaType.to_cpp() falls back to jobject for them).
_CPP_TYPE_BY_JAVA_TYPE = {
    **{
        primitive: f'j{primitive}'
        for primitive in _DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE
        if primitive != 'void'
    },
    'void': 'void',
    'java/lang/Class': 'jclass',
    'java/lang/String': 'jstring',
    'java/lang/Throwable': 'jthrowable',
}

# C++ literal returned by JavaType.to_cpp_default_value() for each primitive.
_DEFAULT_VALUE_BY_PRIMITIVE_TYPE = {
    **dict.fromkeys(_DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE, '0'),
    'boolean': 'false',
    'void': '',
}

# Names of all Java primitive types (plus void).
PRIMITIVES = frozenset(_DEFAULT_VALUE_BY_PRIMITIVE_TYPE.keys())


@dataclasses.dataclass(frozen=True, order=True)
class JavaClass:
  """Represents a reference type.

  |_fqn| uses JNI-style separators: '/' between packages and '$' between
  nested classes (e.g. org/chromium/Foo$Bar). Dots are rejected.
  """
  _fqn: str
  # Slash-form package prefix recorded by make_prefixed(). This is only
  # meaningful if make_prefixed() has been called on the original class;
  # None otherwise.
  _prefix: Optional[str] = None

  def __post_init__(self):
    assert '.' not in self._fqn, f'{self._fqn} should have / and $, but not .'

  def __str__(self):
    return self.full_name_with_slashes

  @property
  def name(self):
    """Class name without package, nested parts joined by '$'."""
    return self._fqn.rsplit('/', 1)[-1]

  @property
  def name_with_dots(self):
    """Class name without package, nested parts joined by '.'."""
    return self.name.replace('$', '.')

  @property
  def nested_name(self):
    """Innermost class name (the part after the last '$')."""
    return self.name.rsplit('$', 1)[-1]

  @property
  def package_with_slashes(self):
    """Package in slash form.

    NOTE: for a name without any '/', this returns the whole name.
    """
    return self._fqn.rsplit('/', 1)[0]

  @property
  def package_with_dots(self):
    return self.package_with_slashes.replace('/', '.')

  @property
  def full_name_with_slashes(self):
    return self._fqn

  @property
  def full_name_with_dots(self):
    """Java source spelling, e.g. org.chromium.Foo.Bar for Foo$Bar."""
    return self._fqn.replace('/', '.').replace('$', '.')

  @property
  def prefix_with_dots(self):
    """The make_prefixed() prefix in dotted form, or None if unprefixed."""
    return self._prefix.replace('/', '.') if self._prefix else self._prefix

  @property
  def class_without_prefix(self):
    """This class with the make_prefixed() prefix stripped (self if none)."""
    if not self._prefix:
      return self
    return JavaClass(self._fqn[len(self._prefix) + 1:])

  @property
  def outer_class_name(self):
    """Outermost class name (the part of |name| before the first '$')."""
    return self.name.split('$', 1)[0]

  def is_nested(self):
    return '$' in self.name

  def get_outer_class(self):
    """Returns the outermost enclosing class (same package, no prefix)."""
    return JavaClass(f'{self.package_with_slashes}/{self.outer_class_name}')

  def is_system_class(self):
    """True for classes under the android/ or java/ packages."""
    return self._fqn.startswith(('android/', 'java/'))

  def to_java(self, type_resolver=None):
    """Returns the shortest Java source spelling |type_resolver| allows.

    Without a resolver, the module's empty resolver is used, which is enough
    to shorten java.lang classes.
    """
    type_resolver = type_resolver or _EMPTY_TYPE_RESOLVER
    return type_resolver.contextualize(self)

  def as_type(self):
    return JavaType(java_class=self)

  def make_prefixed(self, prefix):
    """Returns a copy with |prefix| (dotted or slashed) before the package.

    Returns self when |prefix| is empty or None.
    """
    if not prefix:
      return self
    prefix = prefix.replace('.', '/')
    return JavaClass(f'{prefix}/{self._fqn}', prefix)

  def make_nested(self, name):
    """Returns the class |name| nested in this one (prefix is not kept)."""
    return JavaClass(f'{self._fqn}${name}')


@dataclasses.dataclass(frozen=True)
class JavaType:
  """A parameter or return type: a primitive or a class, plus array depth.

  Types produced by from_descriptor() set exactly one of |primitive_name| and
  |java_class|. |converted_type| does not take part in equality.
  """
  array_dimensions: int = 0
  primitive_name: Optional[str] = None
  java_class: Optional[JavaClass] = None
  converted_type: Optional[str] = dataclasses.field(default=None, compare=False)
  nullable: bool = True

  @staticmethod
  def from_descriptor(descriptor):
    """Parses a JNI field descriptor, e.g. 'I', '[[J' or '[Ljava/lang/Class;'."""
    element = descriptor.lstrip('[')
    dimensions = len(descriptor) - len(element)
    if element[0] != 'L':
      return JavaType(
          array_dimensions=dimensions,
          primitive_name=_PRIMITIVE_TYPE_BY_DESCRIPTOR_CHAR[element[0]])
    assert element[-1] == ';', 'invalid descriptor: ' + element
    return JavaType(array_dimensions=dimensions,
                    java_class=JavaClass(element[1:-1]))

  @property
  def non_array_full_name_with_slashes(self):
    """Primitive name, or the class's slash-separated name; ignores arrays."""
    if self.primitive_name:
      return self.primitive_name
    return self.java_class.full_name_with_slashes

  # dataclass(order=True) is not usable because primitive_name / java_class may
  # be None, so ordering is hand-written: primitives (and primitive arrays)
  # sort before class types, then by (array_dimensions, name-or-class).
  def __lt__(self, other):
    self_is_primitive = bool(self.primitive_name)
    if self_is_primitive != bool(other.primitive_name):
      return self_is_primitive
    lhs_key = (self.array_dimensions, self.primitive_name or self.java_class)
    rhs_key = (other.array_dimensions, other.primitive_name or other.java_class)
    return lhs_key < rhs_key

  def is_primitive(self):
    """True for a non-array primitive (including void)."""
    return self.array_dimensions == 0 and self.primitive_name is not None

  def is_array(self):
    return self.array_dimensions > 0

  def is_primitive_array(self):
    """True for an array of any depth whose element is a primitive."""
    return self.is_array() and self.primitive_name is not None

  def is_object_array(self):
    """True when the JNI array holds objects (nested arrays or class arrays)."""
    if self.array_dimensions > 1:
      return True
    return self.primitive_name is None and self.array_dimensions > 0

  def is_collection(self):
    """True for a non-array type listed in COLLECTION_CLASSES."""
    if self.is_array():
      return False
    return self.java_class in COLLECTION_CLASSES

  def is_void(self):
    return self.primitive_name == 'void'

  def to_array_element_type(self):
    """Returns the type with one array dimension removed."""
    assert self.is_array()
    return JavaType(array_dimensions=self.array_dimensions - 1,
                    primitive_name=self.primitive_name,
                    java_class=self.java_class)

  def to_descriptor(self):
    """Converts a Java type into a JNI signature type."""
    element = (_DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE[self.primitive_name]
               if self.primitive_name else
               f'L{self.java_class.full_name_with_slashes};')
    return '[' * self.array_dimensions + element

  def to_java(self, type_resolver=None):
    """Returns the Java source spelling, e.g. 'int[]' or 'String'."""
    base = (self.primitive_name
            if self.primitive_name else self.java_class.to_java(type_resolver))
    return base + '[]' * self.array_dimensions

  def to_cpp(self):
    """Returns a C datatype for the given java type."""
    # Multi-dimensional arrays and arrays of classes are all jobjectArray
    # (there is no jstringArray, for example).
    if self.is_object_array():
      return 'jobjectArray'
    element = _CPP_TYPE_BY_JAVA_TYPE.get(self.non_array_full_name_with_slashes,
                                         'jobject')
    return f'{element}Array' if self.array_dimensions else element

  def to_cpp_default_value(self):
    """Returns a valid C return value for the given java type."""
    if not self.is_primitive():
      return 'nullptr'
    return _DEFAULT_VALUE_BY_PRIMITIVE_TYPE[self.primitive_name]

  def to_proxy(self):
    """Converts to types used over JNI boundary (non-primitives -> OBJECT)."""
    if self.is_primitive():
      return self
    return OBJECT


@dataclasses.dataclass(frozen=True)
class JavaParam:
  """A single named parameter: its JavaType and its Java identifier."""
  java_type: JavaType
  name: str

  def to_proxy(self):
    """Returns a copy whose type is replaced by its JNI-boundary proxy type."""
    return JavaParam(java_type=self.java_type.to_proxy(), name=self.name)

  def cpp_name(self):
    """Returns the identifier to use in generated C++.

    'env' and 'jcaller' get a leading underscore (presumably to avoid clashing
    with identically named generated parameters).
    """
    name = self.name
    return '_' + name if name in ('env', 'jcaller') else name


class JavaParamList(tuple):
  """An immutable parameter list (a tuple of JavaParam)."""

  def to_proxy(self):
    """Returns a JavaParamList with each parameter in its JNI proxy form."""
    proxied = [param.to_proxy() for param in self]
    return JavaParamList(proxied)

  def to_java_declaration(self, type_resolver=None):
    """Renders Java formal parameters, e.g. 'int a, String b'."""
    declarations = [
        f'{param.java_type.to_java(type_resolver)} {param.name}'
        for param in self
    ]
    return ', '.join(declarations)

  def to_call_str(self):
    """Renders the parameter names as call arguments, e.g. 'a, b'."""
    return ', '.join([param.name for param in self])


@dataclasses.dataclass(frozen=True, order=True)
class JavaSignature:
  """Represents a method signature (return type + parameter types)."""
  return_type: JavaType
  # Variable-length tuple of parameter types (Tuple[X] would mean a 1-tuple).
  param_types: Tuple[JavaType, ...]
  # Signatures should be considered equal even if parameter names differ, so
  # exclude param_list from comparisons.
  param_list: JavaParamList = dataclasses.field(compare=False)

  @staticmethod
  def from_params(return_type, param_list):
    """Builds a signature whose param_types are taken from |param_list|."""
    return JavaSignature(return_type=return_type,
                         param_types=tuple(p.java_type for p in param_list),
                         param_list=param_list)

  @staticmethod
  def from_descriptor(descriptor):
    """Parses a JNI method descriptor.

    E.g.: (Ljava/lang/Object;Ljava/lang/Runnable;)Ljava/lang/Class;
    Parameters receive the synthetic names p0, p1, ...
    """
    assert descriptor[0] == '(', 'invalid method descriptor: ' + descriptor
    i = 1
    start_idx = i
    params = []
    while True:
      char = descriptor[i]
      if char == ')':
        break
      if char == '[':
        # Array prefix: keep scanning; start_idx stays on the first '[' so the
        # element descriptor includes all dimensions.
        i += 1
        continue
      if char == 'L':
        end_idx = descriptor.index(';', i) + 1
      else:
        end_idx = i + 1
      param_type = JavaType.from_descriptor(descriptor[start_idx:end_idx])
      params.append(JavaParam(param_type, f'p{len(params)}'))
      i = start_idx = end_idx

    return_type = JavaType.from_descriptor(descriptor[i + 1:])
    return JavaSignature.from_params(return_type, JavaParamList(params))

  def to_descriptor(self):
    """Returns the JNI signature."""
    params = ''.join(t.to_descriptor() for t in self.param_types)
    return f'({params}){self.return_type.to_descriptor()}'

  def to_proxy(self):
    """Converts to types used over JNI boundary."""
    return_type = self.return_type.to_proxy()
    param_list = self.param_list.to_proxy()
    return JavaSignature.from_params(return_type, param_list)

  def with_params_reordered(self):
    """Returns a copy with parameters sorted by their proxy JavaType.

    The sort is stable, so parameters whose proxy types compare equal keep
    their relative order.
    """
    return JavaSignature.from_params(
        self.return_type,
        JavaParamList(
            sorted(self.param_list, key=lambda x: x.java_type.to_proxy())))


class TypeResolver:
  """Converts type names to fully qualified names.

  Resolution is relative to |java_class| plus the classes registered through
  add_import() and add_nested_class().
  """
  def __init__(self, java_class):
    self.java_class = java_class
    self.imports = []
    self.nested_classes = []

  def add_import(self, java_class):
    """Registers an imported class for resolve()/contextualize()."""
    self.imports.append(java_class)

  def add_nested_class(self, java_class):
    """Registers a nested class that resolve() matches by name."""
    self.nested_classes.append(java_class)

  def contextualize(self, java_class):
    """Return the shortest string that resolves to the given class."""
    type_package = java_class.package_with_slashes
    if type_package in ('java/lang', self.java_class.package_with_slashes):
      return java_class.name_with_dots
    if java_class in self.imports:
      return java_class.name_with_dots

    return java_class.full_name_with_dots

  def resolve(self, name):
    """Return a JavaClass for the given type name.

    |name| may be slash-separated (from javap), dotted (pkg.Class,
    Outer.Inner, pkg.Outer.Inner), or a simple class name.
    """
    assert name not in PRIMITIVES, 'Name: ' + name
    assert ' ' not in name, 'Name: ' + name
    assert name != '', 'Cannot resolve empty string'

    if '/' in name:
      # Coming from javap, use the fully qualified name directly.
      return JavaClass(name)

    if self.java_class.name == name:
      return self.java_class

    for clazz in self.nested_classes:
      if name in (clazz.name, clazz.nested_name):
        return clazz

    # Is it from an import? (e.g. referencing Class from import pkg.Class).
    for clazz in self.imports:
      if name in (clazz.name, clazz.nested_name):
        return clazz

    if '.' in name:
      components = name.split('.')
      # Assume lowercase means it's a fully qualified name.
      if name[0].islower():
        return self._resolve_fully_qualified(components)
      # Is it an inner class from an outer class import? (e.g. referencing
      # Class.Inner or Class.Inner.Deeper from import pkg.Class).
      nested = self._resolve_nested_in_import(components)
      if nested is not None:
        return nested
      name = name.replace('.', '$')

    # java.lang classes always take priority over types from the same package.
    # To use a type from the same package that has the same name as a java.lang
    # type, it must be explicitly imported.
    if java_lang_classes.contains(name):
      return JavaClass(f'java/lang/{name}')

    # Type not found, falling back to same package as this class.
    # Set the same prefix with this class.
    ret = JavaClass(
        f'{self.java_class.class_without_prefix.package_with_slashes}/{name}')
    return ret.make_prefixed(self.java_class.prefix_with_dots)

  @staticmethod
  def _resolve_fully_qualified(components):
    """Builds a JavaClass from dotted components such as pkg.Outer.Inner.

    Components before the first one starting with an uppercase letter form the
    package; the rest are nested class names, which JNI joins with '$'.
    """
    for i, component in enumerate(components):
      if component[:1].isupper():
        package = '/'.join(components[:i])
        classes = '$'.join(components[i:])
        return JavaClass(f'{package}/{classes}' if package else classes)
    # Nothing looks like a class name; treat every component as a path part.
    return JavaClass('/'.join(components))

  def _resolve_nested_in_import(self, components):
    """Resolves Outer.Inner[.Deeper...] against imported outer classes.

    Returns None when no import's name matches a leading run of |components|.
    """
    for clazz in self.imports:
      for i in range(1, len(components)):
        # Imported nested classes use '$' in their names (e.g. Outer$Mid).
        if '$'.join(components[:i]) == clazz.name:
          ret = clazz
          for inner in components[i:]:
            ret = ret.make_nested(inner)
          return ret
    return None


CLASS_CLASS = JavaClass('java/lang/Class')
OBJECT_CLASS = JavaClass('java/lang/Object')
STRING_CLASS = JavaClass('java/lang/String')
_LIST_CLASS = JavaClass('java/util/List')

# Collection and types that extend it (for use with toArray()); consulted by
# JavaType.is_collection(). More can be added here if the need arises.
COLLECTION_CLASSES = (_LIST_CLASS,) + tuple(
    JavaClass(f'java/util/{simple_name}')
    for simple_name in ('Collection', 'Set'))

# Frequently used JavaType instances.
OBJECT, CLASS, LIST = (JavaType(java_class=clazz)
                       for clazz in (OBJECT_CLASS, CLASS_CLASS, _LIST_CLASS))
INT, VOID = (JavaType(primitive_name=primitive)
             for primitive in ('int', 'void'))

# Resolver with no imports, relative to java.lang.Object; JavaClass.to_java()
# falls back to it when no resolver is supplied.
_EMPTY_TYPE_RESOLVER = TypeResolver(OBJECT_CLASS)
EMPTY_PARAM_LIST = JavaParamList(())