// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/android/toolbar/adaptive_toolbar_bridge.h"
#include "base/android/callback_android.h"
#include "base/android/jni_array.h"
#include "base/android/scoped_java_ref.h"
#include "base/no_destructor.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/segmentation_platform/segmentation_platform_service_factory.h"
#include "chrome/browser/ui/android/toolbar/adaptive_toolbar_enums.h"
#include "components/segmentation_platform/public/android/segmentation_platform_conversion_bridge.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
// Must come after all headers that specialize FromJniType() / ToJniType().
#include "chrome/browser/ui/android/toolbar/jni_headers/AdaptiveToolbarBridge_jni.h"
using base::android::AttachCurrentThread;
using base::android::JavaParamRef;
using base::android::JavaRef;
using base::android::ScopedJavaLocalRef;
using segmentation_platform::InputContext;
namespace {
std::map<std::string, AdaptiveToolbarButtonVariant> GetEnumLabelMapping() {
static base::NoDestructor<std::map<std::string, AdaptiveToolbarButtonVariant>>
enum_label_mapping(
{{
segmentation_platform::kAdaptiveToolbarModelLabelNewTab,
AdaptiveToolbarButtonVariant::kNewTab,
},
{
segmentation_platform::kAdaptiveToolbarModelLabelShare,
AdaptiveToolbarButtonVariant::kShare,
},
{
segmentation_platform::kAdaptiveToolbarModelLabelVoice,
AdaptiveToolbarButtonVariant::kVoice,
},
{
segmentation_platform::kAdaptiveToolbarModelLabelTranslate,
AdaptiveToolbarButtonVariant::kTranslate,
},
{
segmentation_platform::kAdaptiveToolbarModelLabelAddToBookmarks,
AdaptiveToolbarButtonVariant::kAddToBookmarks,
},
{
segmentation_platform::kAdaptiveToolbarModelLabelReadAloud,
AdaptiveToolbarButtonVariant::kReadAloud,
}});
return *enum_label_mapping;
}
AdaptiveToolbarButtonVariant ActionLabelToAdaptiveToolbarButtonVariant(
const std::string& label) {
std::map<std::string, AdaptiveToolbarButtonVariant> label_enum_mapping =
GetEnumLabelMapping();
if (label_enum_mapping.contains(label)) {
return label_enum_mapping.at(label);
}
return AdaptiveToolbarButtonVariant::kUnknown;
}
void RunGetSelectedSegmentCallback(
const JavaRef<jobject>& j_callback,
const segmentation_platform::SegmentSelectionResult& result) {
AdaptiveToolbarButtonVariant button_variant =
AdaptiveToolbarButtonVariant::kUnknown;
if (result.segment.has_value()) {
switch (result.segment.value()) {
case segmentation_platform::proto::SegmentId::
OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
button_variant = AdaptiveToolbarButtonVariant::kNewTab;
break;
case segmentation_platform::proto::SegmentId::
OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
button_variant = AdaptiveToolbarButtonVariant::kShare;
break;
case segmentation_platform::proto::SegmentId::
OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
button_variant = AdaptiveToolbarButtonVariant::kVoice;
break;
case segmentation_platform::proto::SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
button_variant = AdaptiveToolbarButtonVariant::kUnknown;
break;
default:
NOTREACHED_IN_MIGRATION();
}
}
ScopedJavaLocalRef<jobject> j_result =
Java_AdaptiveToolbarBridge_createResult(
base::android::AttachCurrentThread(), result.is_ready,
static_cast<int32_t>(button_variant));
base::android::RunObjectCallbackAndroid(j_callback, j_result);
}
void RunGetClassificationSingleResultCallback(
const base::android::JavaRef<jobject>& j_callback,
const segmentation_platform::ClassificationResult& result) {
std::string button_to_show =
result.ordered_labels.empty() ? "" : result.ordered_labels[0];
bool is_ready =
result.status == segmentation_platform::PredictionStatus::kSucceeded;
int button_variant = static_cast<int32_t>(
ActionLabelToAdaptiveToolbarButtonVariant(button_to_show));
ScopedJavaLocalRef<jobject> j_result =
Java_AdaptiveToolbarBridge_createResult(
base::android::AttachCurrentThread(), is_ready, button_variant);
base::android::RunObjectCallbackAndroid(j_callback, j_result);
}
void RunGetClassificationMultipleResultCallback(
base::OnceCallback<void(bool, std::vector<int>)> callback,
const segmentation_platform::ClassificationResult& result) {
std::vector<int> ranked_buttons;
bool is_ready =
result.status == segmentation_platform::PredictionStatus::kSucceeded;
for (std::string label : result.ordered_labels) {
ranked_buttons.emplace_back(
static_cast<int32_t>(ActionLabelToAdaptiveToolbarButtonVariant(label)));
}
if (ranked_buttons.empty()) {
ranked_buttons.emplace_back(
static_cast<int32_t>(AdaptiveToolbarButtonVariant::kUnknown));
}
std::move(callback).Run(is_ready, ranked_buttons);
}
void RunGetAnnotatedNumericResultCallback(
base::OnceCallback<void(bool, std::vector<int>)> callback,
const segmentation_platform::AnnotatedNumericResult& result) {
bool is_ready =
result.status == segmentation_platform::PredictionStatus::kSucceeded;
std::map<std::string, AdaptiveToolbarButtonVariant> enum_label_mapping =
GetEnumLabelMapping();
// Map that sorts elements with largest first.
std::multimap<float, AdaptiveToolbarButtonVariant, std::greater<>>
sorted_button_scores;
for (std::pair<std::string, AdaptiveToolbarButtonVariant> button :
enum_label_mapping) {
std::optional<float> score_for_button =
result.GetResultForLabel(button.first);
if (score_for_button.has_value()) {
sorted_button_scores.emplace(score_for_button.value(), button.second);
}
}
std::vector<int> sorted_buttons;
for (std::pair<float, AdaptiveToolbarButtonVariant> score_button :
sorted_button_scores) {
sorted_buttons.emplace_back(static_cast<int32_t>(score_button.second));
}
if (sorted_buttons.empty()) {
sorted_buttons.emplace_back(
static_cast<int32_t>(AdaptiveToolbarButtonVariant::kUnknown));
}
std::move(callback).Run(is_ready, sorted_buttons);
}
void RunJavaCallbackWithRankedButtons(
const base::android::JavaRef<jobject>& j_callback,
bool is_ready,
std::vector<int> ranked_buttons) {
ScopedJavaLocalRef<jintArray> java_ranked_buttons =
base::android::ToJavaIntArray(base::android::AttachCurrentThread(),
ranked_buttons);
ScopedJavaLocalRef<jobject> j_result =
Java_AdaptiveToolbarBridge_createResultList(
base::android::AttachCurrentThread(), is_ready, java_ranked_buttons);
base::android::RunObjectCallbackAndroid(j_callback, j_result);
}
} // namespace
void JNI_AdaptiveToolbarBridge_GetRankedSessionVariantButtons(
JNIEnv* env,
Profile* profile,
jboolean j_use_raw_results,
const JavaParamRef<jobject>& j_callback) {
bool use_raw_results = static_cast<bool>(j_use_raw_results);
base::OnceCallback<void(bool, std::vector<int>)> wrapped_callback =
base::BindOnce(&RunJavaCallbackWithRankedButtons,
base::android::ScopedJavaGlobalRef<jobject>(j_callback));
adaptive_toolbar::GetRankedSessionVariantButtons(profile, use_raw_results,
std::move(wrapped_callback));
}
void JNI_AdaptiveToolbarBridge_GetSessionVariantButton(
JNIEnv* env,
Profile* profile,
const JavaParamRef<jobject>& j_callback) {
if (!profile) {
RunGetClassificationSingleResultCallback(
j_callback, segmentation_platform::ClassificationResult(
segmentation_platform::PredictionStatus::kFailed));
return;
}
segmentation_platform::SegmentationPlatformService*
segmentation_platform_service = segmentation_platform::
SegmentationPlatformServiceFactory::GetForProfile(profile);
if (!segmentation_platform_service) {
RunGetClassificationSingleResultCallback(
j_callback, segmentation_platform::ClassificationResult(
segmentation_platform::PredictionStatus::kFailed));
return;
}
bool use_multi_output = base::FeatureList::IsEnabled(
segmentation_platform::features::
kSegmentationPlatformAdaptiveToolbarV2Feature);
if (use_multi_output) {
segmentation_platform_service->GetClassificationResult(
segmentation_platform::kAdaptiveToolbarSegmentationKey,
segmentation_platform::PredictionOptions(),
base::MakeRefCounted<segmentation_platform::InputContext>(),
base::BindOnce(
&RunGetClassificationSingleResultCallback,
base::android::ScopedJavaGlobalRef<jobject>(j_callback)));
} else {
segmentation_platform_service->GetSelectedSegment(
segmentation_platform::kAdaptiveToolbarSegmentationKey,
base::BindOnce(
&RunGetSelectedSegmentCallback,
base::android::ScopedJavaGlobalRef<jobject>(j_callback)));
}
}
namespace adaptive_toolbar {
// This method retrieves a list of toolbar buttons ranked by priority, only one
// button can be shown, but we return a list so we can try other options in case
// the top one is not available in the current UI (e.g. tablets already have a
// bookmark button, so we don't show it here).
// This list is retrieved from segmentation platform, which:
// 1) Runs an ML model which returns a score for each button.
// 2) Applies thresholds to filter out buttons with low scores.
// 3) Returns a list sorted by score.
// The current model's low score thresholds are set to use new tab as the
// default option, so the other ones get often filtered out, this is a problem
// in tablets because that button is not supported. The |use_raw_results| option
// uses a segmentation platform API that skips step 2, so we can use the
// unfiltered scores on tablets.
void GetRankedSessionVariantButtons(
Profile* profile,
bool use_raw_results,
base::OnceCallback<void(bool, std::vector<int>)> callback) {
if (!profile) {
std::move(callback).Run(false, std::vector<int>());
return;
}
segmentation_platform::SegmentationPlatformService*
segmentation_platform_service = segmentation_platform::
SegmentationPlatformServiceFactory::GetForProfile(profile);
if (!segmentation_platform_service) {
std::move(callback).Run(false, std::vector<int>());
return;
}
if (use_raw_results) {
segmentation_platform_service->GetAnnotatedNumericResult(
segmentation_platform::kAdaptiveToolbarSegmentationKey,
segmentation_platform::PredictionOptions(),
base::MakeRefCounted<segmentation_platform::InputContext>(),
base::BindOnce(&RunGetAnnotatedNumericResultCallback,
std::move(callback)));
} else {
segmentation_platform_service->GetClassificationResult(
segmentation_platform::kAdaptiveToolbarSegmentationKey,
segmentation_platform::PredictionOptions(),
base::MakeRefCounted<segmentation_platform::InputContext>(),
base::BindOnce(&RunGetClassificationMultipleResultCallback,
std::move(callback)));
}
}
} // namespace adaptive_toolbar