# chromium/third_party/jni_zero/codegen/proxy_impl_java.py

# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import java_types
import common


class _Context:
  """Holds the generator inputs plus values derived from them.

  An instance is passed as `ctx` to the helper functions in this file.
  """

  def __init__(self, jni_obj, gen_jni_class, script_name, per_file_natives):
    # Inputs, stored exactly as given.
    self.jni_obj = jni_obj
    self.gen_jni_class = gen_jni_class
    self.script_name = script_name
    self.per_file_natives = per_file_natives

    # Dotted name of the proxy interface the generated class implements.
    self.interface_name = jni_obj.proxy_interface.name_with_dots

    # The generated class is named after jni_obj's Java class with a "Jni"
    # suffix appended.
    proxy_class_path = f'{jni_obj.java_class.full_name_with_slashes}Jni'
    self.proxy_class = java_types.JavaClass(proxy_class_path)

    # Types are resolved relative to the generated class, using the imports
    # reported by jni_obj.
    resolver = java_types.TypeResolver(self.proxy_class)
    resolver.imports = jni_obj.GetClassesToBeImported()
    self.type_resolver = resolver


def _implicit_array_class_param(native, type_resolver):
  """Returns the Java `.class` literal for the element type of the native's
  return type.

  The element type is obtained via `to_array_element_type()` and rendered
  with `to_java(type_resolver)`.
  """
  element_type = native.return_type.to_array_element_type()
  return f'{element_type.to_java(type_resolver)}.class'


def _proxy_method(sb, ctx, native, method_fqn):
  """Emits the @Override Java method that forwards `native` to `method_fqn`.

  Args:
    sb: String builder that receives the generated Java text.
    ctx: Generation context; only its `type_resolver` is used here.
    native: Description of the native method being proxied.
    method_fqn: Name of the Java method the generated body calls.
  """
  resolver = ctx.type_resolver
  java_return = native.return_type.to_java(resolver)
  declared_params = native.params.to_java_declaration(resolver)

  sb(f"""
@Override
public {java_return} {native.name}({declared_params})""")
  with sb.block():
    # When the native declares a C++ type for its first parameter, the
    # generated code asserts that this parameter is non-zero.
    if native.first_param_cpp_type:
      sb(f'assert {native.params[0].name} != 0;\n')

    # Every parameter whose Java type is marked non-nullable gets a null
    # assertion.
    for param in native.params:
      if param.java_type.nullable:
        continue
      sb(f'assert {param.name} != null;\n')

    with sb.statement():
      returns_value = not native.return_type.is_void()
      if returns_value:
        sb(f'return ({java_return}) ')
      sb(method_fqn)
      with sb.param_list() as call_args:
        call_args.extend([param.name for param in native.params])
        if native.needs_implicit_array_element_class_param:
          call_args.append(_implicit_array_class_param(native, resolver))


def _native_method(sb, native, method_fqn):
  """Emits a `private static native` Java declaration named `method_fqn`.

  The parameter list and return type come from the native's proxy
  signature (`proxy_params` / `proxy_return_type`).
  """
  declared_params = native.proxy_params.to_java_declaration()
  java_return = native.proxy_return_type.to_java()
  signature = f'{java_return} {method_fqn}({declared_params})'
  sb(f'private static native {signature};\n')


def _get_method(sb, ctx):
  """Emits the static `get()` accessor of the generated proxy class.

  Unless per-file natives are in use, the accessor first consults the
  test hooks on the GEN_JNI class. It always finishes by checking that the
  native library is loaded and returning a new proxy instance.
  """
  interface = ctx.interface_name
  sb(f'\npublic static {interface} get()')
  with sb.block():
    if not ctx.per_file_natives:
      # gen_jni_class is only read in this branch, matching the fact that
      # the per-file mode never references it.
      gen_jni = ctx.gen_jni_class.name
      sb(f"""\
if ({gen_jni}.TESTING_ENABLED) {{
  if (testInstance != null) {{
    return testInstance;
  }}
  if ({gen_jni}.REQUIRE_MOCK) {{
    throw new UnsupportedOperationException(
        "No mock found for the native implementation of {interface}. "
        + "The current configuration requires implementations be mocked.");
  }}
}}
""")
    proxy_name = ctx.proxy_class.name
    sb(f"""\
NativeLibraryLoadedStatus.checkLoaded();
return new {proxy_name}();
""")


def _class_body(sb, ctx):
  """Emits everything inside the braces of the generated proxy class.

  In order: the `testInstance` field and `TEST_HOOKS` mocker, one proxy
  method per native (plus its `native` declaration when per-file natives
  are in use), and finally the static `get()` accessor.
  """
  interface = ctx.interface_name
  sb(f"""\
private static {interface} testInstance;

public static final JniStaticTestMocker<{interface}> TEST_HOOKS =
    new JniStaticTestMocker<{interface}>()""")
  # The anonymous class expression is terminated with "};".
  with sb.block(end='};\n'):
    sb(f"""\
@Override
public void setInstanceForTesting({interface} instance)""")
    with sb.block():
      if not ctx.per_file_natives:
        # gen_jni_class is only read when per-file natives are off.
        gen_jni = ctx.gen_jni_class.name
        sb(f"""\
if (!{gen_jni}.TESTING_ENABLED) {{
  throw new RuntimeException(
      "Tried to set a JNI mock when mocks aren't enabled!");
}}
""")
      sb('testInstance = instance;\n')

  per_file = ctx.per_file_natives
  for native in ctx.jni_obj.proxy_natives:
    # Per-file mode calls a native method declared on this class; otherwise
    # the call is routed through the GEN_JNI class.
    target = (f'native{common.capitalize(native.name)}' if per_file else
              f'{ctx.gen_jni_class.name}.{native.proxy_name}')

    _proxy_method(sb, ctx, native, target)
    if per_file:
      _native_method(sb, native, target)

  _get_method(sb, ctx)


def _imports(sb, ctx):
  """Emits the sorted, de-duplicated Java import lines for the proxy class."""
  names = set((
      'org.jni_zero.CheckDiscard',
      'org.jni_zero.JniStaticTestMocker',
      'org.jni_zero.NativeLibraryLoadedStatus',
  ))
  if not ctx.per_file_natives:
    names.add(ctx.gen_jni_class.full_name_with_dots)

  for imported in ctx.type_resolver.imports:
    # Since this is pure Java, the class generated here will go through jarjar
    # and thus we want to avoid prefixes.
    unprefixed = imported.class_without_prefix
    # Nested classes are always referred to as OuterClass.InnerClass to reduce
    # the risk of naming collisions, so only the outer class is imported.
    top_level = (unprefixed.get_outer_class()
                 if unprefixed.is_nested else unprefixed)
    names.add(top_level.full_name_with_dots)

  for name in sorted(names):
    sb(f'import {name};\n')


def Generate(jni_obj, *, gen_jni_class, script_name, per_file_natives=False):
  """Returns the Java source text of the proxy implementation for jni_obj.

  Args:
    jni_obj: Parsed description of the Java class and its proxy natives.
    gen_jni_class: The GEN_JNI class the proxy routes through; it is only
      read when per_file_natives is False.
    script_name: Generator name written into the file header comment.
    per_file_natives: When True, native methods are declared on the
      generated class itself and the @CheckDiscard annotation is omitted.

  Returns:
    The generated file contents as produced by the string builder.
  """
  ctx = _Context(jni_obj, gen_jni_class, script_name, per_file_natives)
  package = jni_obj.java_class.class_without_prefix.package_with_dots

  sb = common.StringBuilder()
  sb(f"""\
//
// This file was generated by {script_name}
//
package {package};

""")
  _imports(sb, ctx)
  sb('\n')

  if not per_file_natives:
    sb('@CheckDiscard("crbug.com/993421")\n')

  # Anything other than an explicitly public proxy is emitted with no
  # visibility modifier.
  if jni_obj.proxy_visibility == 'public':
    modifier = 'public '
  else:
    modifier = ''
  sb(f'{modifier}class {ctx.proxy_class.name} implements {ctx.interface_name}')
  with sb.block():
    _class_body(sb, ctx)
  return sb.to_string()