chromium/components/segmentation_platform/internal/android/segmentation_platform_service_android.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/android/segmentation_platform_service_android.h"

#include <memory>
#include <utility>

#include "base/android/callback_android.h"
#include "base/android/jni_string.h"
#include "base/android/scoped_java_ref.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "components/segmentation_platform/public/android/input_context_android.h"
#include "components/segmentation_platform/public/android/prediction_options_android.h"
#include "components/segmentation_platform/public/android/segmentation_platform_conversion_bridge.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/segment_selection_result.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"

// Must come after all headers that specialize FromJniType() / ToJniType().
#include "components/segmentation_platform/internal/jni_headers/SegmentationPlatformServiceImpl_jni.h"

using base::android::AttachCurrentThread;
using base::android::ConvertJavaStringToUTF8;
using base::android::JavaParamRef;

namespace segmentation_platform {
namespace {

// User-data key under which the SegmentationPlatformServiceAndroid bridge is
// attached to a SegmentationPlatformService (see
// SegmentationPlatformService::GetJavaObject in this file).
constexpr char kSegmentationPlatformServiceBridgeKey[] =
    "segmentation_platform_service_bridge";

void RunGetSelectedSegmentCallback(const JavaRef<jobject>& j_callback,
                                   const SegmentSelectionResult& result) {
  JNIEnv* env = AttachCurrentThread();
  base::android::RunObjectCallbackAndroid(
      j_callback,
      SegmentationPlatformConversionBridge::CreateJavaSegmentSelectionResult(
          env, result));
}

void RunGetClassificationResultCallback(const JavaRef<jobject>& j_callback,
                                        const ClassificationResult& result) {
  JNIEnv* env = AttachCurrentThread();
  base::android::RunObjectCallbackAndroid(
      j_callback,
      SegmentationPlatformConversionBridge::CreateJavaClassificationResult(
          env, result));
}

}  // namespace

// static
// Declared in segmentation_platform_service.h and defined here, so this file
// must be linked into any binary that calls
// SegmentationPlatformService::GetJavaObject().
//
// Returns a new local reference to the Java object of the Android bridge for
// |service|. The bridge is created on the first call and handed to |service|
// as owned user data under kSegmentationPlatformServiceBridgeKey. Later calls
// reuse that same bridge.
ScopedJavaLocalRef<jobject> SegmentationPlatformService::GetJavaObject(
    SegmentationPlatformService* service) {
  DCHECK(service);

  auto* bridge = static_cast<SegmentationPlatformServiceAndroid*>(
      service->GetUserData(kSegmentationPlatformServiceBridgeKey));
  if (!bridge) {
    // Grab the raw pointer before ownership moves into |service|. This avoids
    // a second GetUserData() lookup and cast.
    auto new_bridge =
        std::make_unique<SegmentationPlatformServiceAndroid>(service);
    bridge = new_bridge.get();
    service->SetUserData(kSegmentationPlatformServiceBridgeKey,
                         std::move(new_bridge));
  }

  return bridge->GetJavaObject();
}

// Wraps a non-null |segmentation_platform_service| and creates the Java-side
// SegmentationPlatformServiceImpl peer. The address of |this| is passed to
// Java as an int64_t (presumably so Java can route calls back to this object;
// confirm against the Java class). A global reference to the peer is kept in
// |java_obj_|.
SegmentationPlatformServiceAndroid::SegmentationPlatformServiceAndroid(
    SegmentationPlatformService* segmentation_platform_service)
    : segmentation_platform_service_(segmentation_platform_service) {
  DCHECK(segmentation_platform_service_);

  JNIEnv* env = AttachCurrentThread();
  const int64_t native_ptr = reinterpret_cast<int64_t>(this);
  ScopedJavaLocalRef<jobject> java_peer =
      Java_SegmentationPlatformServiceImpl_create(env, native_ptr);
  java_obj_.Reset(env, java_peer.obj());
}

// Tells the Java peer (via clearNativePtr) to drop the native pointer it was
// given at construction, because this object is going away.
SegmentationPlatformServiceAndroid::~SegmentationPlatformServiceAndroid() {
  Java_SegmentationPlatformServiceImpl_clearNativePtr(AttachCurrentThread(),
                                                      java_obj_);
}

// JNI entry point: requests the selected segment for |j_segmentation_key| from
// the native service. |jcallback| is promoted to a global reference because
// the JNI parameter reference is only valid for the duration of this call.
// RunGetSelectedSegmentCallback converts the result and runs the callback.
// |jcaller| is unused.
void SegmentationPlatformServiceAndroid::GetSelectedSegment(
    JNIEnv* env,
    const JavaParamRef<jobject>& jcaller,
    const JavaParamRef<jstring>& j_segmentation_key,
    const JavaParamRef<jobject>& jcallback) {
  const std::string segmentation_key =
      ConvertJavaStringToUTF8(env, j_segmentation_key);
  segmentation_platform_service_->GetSelectedSegment(
      segmentation_key,
      base::BindOnce(&RunGetSelectedSegmentCallback,
                     ScopedJavaGlobalRef<jobject>(jcallback)));
}

// JNI entry point. It converts the Java prediction options and input context
// into their native forms, then forwards a classification request for
// |j_segmentation_key| to the native service.
//
// |j_callback| is promoted to a global reference because the JNI parameter
// reference is only valid for the duration of this call.
// RunGetClassificationResultCallback converts the result and runs the
// callback. |j_caller| is unused.
void SegmentationPlatformServiceAndroid::GetClassificationResult(
    JNIEnv* env,
    const JavaParamRef<jobject>& j_caller,
    const JavaParamRef<jstring>& j_segmentation_key,
    const JavaParamRef<jobject>& j_prediction_options,
    const JavaParamRef<jobject>& j_input_context,
    const JavaParamRef<jobject>& j_callback) {
  scoped_refptr<InputContext> native_input_context =
      InputContextAndroid::ToNativeInputContext(env, j_input_context);
  const PredictionOptions native_prediction_options =
      PredictionOptionsAndroid::ToNativePredictionOptions(env,
                                                          j_prediction_options);

  // |native_input_context| is not used after this call. Moving it transfers
  // the reference instead of paying for an extra AddRef()/Release() pair.
  segmentation_platform_service_->GetClassificationResult(
      ConvertJavaStringToUTF8(env, j_segmentation_key),
      native_prediction_options, std::move(native_input_context),
      base::BindOnce(&RunGetClassificationResultCallback,
                     ScopedJavaGlobalRef<jobject>(j_callback)));
}

// JNI entry point: synchronously reads the cached segment selection for
// |j_segmentation_key| from the native service and returns it converted to a
// Java object. |jcaller| is unused.
ScopedJavaLocalRef<jobject>
SegmentationPlatformServiceAndroid::GetCachedSegmentResult(
    JNIEnv* env,
    const JavaParamRef<jobject>& jcaller,
    const JavaParamRef<jstring>& j_segmentation_key) {
  const std::string segmentation_key =
      ConvertJavaStringToUTF8(env, j_segmentation_key);
  const SegmentSelectionResult& cached_result =
      segmentation_platform_service_->GetCachedSegmentResult(segmentation_key);
  return SegmentationPlatformConversionBridge::CreateJavaSegmentSelectionResult(
      env, cached_result);
}

// Returns a fresh local reference to the Java peer. The global reference held
// in |java_obj_| is left untouched.
ScopedJavaLocalRef<jobject>
SegmentationPlatformServiceAndroid::GetJavaObject() {
  ScopedJavaLocalRef<jobject> local_ref(java_obj_);
  return local_ref;
}

}  // namespace segmentation_platform