chromium/chrome/browser/ui/android/toolbar/adaptive_toolbar_bridge_unittest.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ui/android/toolbar/adaptive_toolbar_bridge.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "chrome/browser/segmentation_platform/segmentation_platform_service_factory.h"
#include "chrome/test/base/chrome_test_utils.h"
#include "chrome/test/base/testing_profile.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/testing/mock_segmentation_platform_service.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using segmentation_platform::MockSegmentationPlatformService;
using testing::_;

namespace adaptive_toolbar {

// Test fixture for the adaptive toolbar bridge.
//
// SetUp() builds a TestingProfile whose SegmentationPlatformService is produced
// by a testing factory returning a MockSegmentationPlatformService. Each test
// installs default actions on that mock, then calls
// GetRankedSessionVariantButtons() with BridgeCallback() as the completion
// callback. It then spins `run_loop_` until the callback has recorded its
// arguments.
class AdaptiveToolbarBridgeTest : public ::testing::Test {
 public:
  AdaptiveToolbarBridgeTest() = default;
  AdaptiveToolbarBridgeTest(const AdaptiveToolbarBridgeTest&) = delete;
  AdaptiveToolbarBridgeTest& operator=(const AdaptiveToolbarBridgeTest&) =
      delete;

  ~AdaptiveToolbarBridgeTest() override = default;

  // Creates `profile_` with the segmentation platform service factory
  // overridden so that the profile's service is a
  // MockSegmentationPlatformService.
  void SetUp() override {
    TestingProfile::Builder builder;
    builder.AddTestingFactory(
        segmentation_platform::SegmentationPlatformServiceFactory::
            GetInstance(),
        base::BindRepeating([](content::BrowserContext* context)
                                -> std::unique_ptr<KeyedService> {
          return std::make_unique<MockSegmentationPlatformService>();
        }));

    profile_ = builder.Build();
  }

  void TearDown() override {
    // Verify and clear the mock's expectations and the default actions that
    // each test installs with ON_CALL, so teardown does not run them.
    testing::Mock::VerifyAndClear(&GetSegmentationPlatformService());
  }

  // Returns the mock service for `profile_`. The static_cast is valid only
  // because the testing factory registered in SetUp() always creates a
  // MockSegmentationPlatformService.
  MockSegmentationPlatformService& GetSegmentationPlatformService() {
    return *static_cast<MockSegmentationPlatformService*>(
        segmentation_platform::SegmentationPlatformServiceFactory::
            GetForProfile(profile_.get()));
  }

  // Completion callback handed to GetRankedSessionVariantButtons(). Records
  // the arguments for the test body to inspect and quits `run_loop_`.
  // Kept public: tests form `&AdaptiveToolbarBridgeTest::BridgeCallback`, and
  // a pointer to a protected member cannot be formed through the base class
  // name from a derived class.
  void BridgeCallback(bool is_ready, std::vector<int> ranked_buttons) {
    callback_buttons_ = std::move(ranked_buttons);
    callback_is_ready_ = is_ready;
    run_loop_.Quit();
  }

 protected:
  // Needed for TestingProfile::Builder. Declared first so it outlives the
  // profile and the run loop.
  content::BrowserTaskEnvironment task_environment_;
  std::unique_ptr<TestingProfile> profile_;

  // Values captured by BridgeCallback(). `callback_is_ready_` is initialized so
  // a test never reads an indeterminate value if the callback did not run.
  std::vector<int> callback_buttons_;
  bool callback_is_ready_ = false;
  base::RunLoop run_loop_;
};

TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons) {
  namespace sp = segmentation_platform;

  // Reply asynchronously with a successful classification whose labels are
  // already ranked: Share, then AddToBookmarks, then Translate.
  auto reply_with_ordered_labels =
      [](sp::ClassificationResultCallback callback) {
        sp::ClassificationResult result(sp::PredictionStatus::kSucceeded);
        result.ordered_labels = {sp::kAdaptiveToolbarModelLabelShare,
                                 sp::kAdaptiveToolbarModelLabelAddToBookmarks,
                                 sp::kAdaptiveToolbarModelLabelTranslate};
        base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE, base::BindOnce(std::move(callback), result));
      };
  ON_CALL(GetSegmentationPlatformService(), GetClassificationResult(_, _, _, _))
      .WillByDefault(
          testing::WithArg<3>(testing::Invoke(reply_with_ordered_labels)));

  auto on_ranked_buttons = base::BindOnce(
      &AdaptiveToolbarBridgeTest::BridgeCallback, base::Unretained(this));
  adaptive_toolbar::GetRankedSessionVariantButtons(
      profile_.get(), /* use_raw_results= */ false,
      std::move(on_ranked_buttons));
  run_loop_.Run();

  EXPECT_TRUE(callback_is_ready_);
  // The buttons must come back in exactly the order of the segmentation labels.
  const std::vector<int> expected_buttons = {
      static_cast<int>(AdaptiveToolbarButtonVariant::kShare),
      static_cast<int>(AdaptiveToolbarButtonVariant::kAddToBookmarks),
      static_cast<int>(AdaptiveToolbarButtonVariant::kTranslate)};
  EXPECT_EQ(callback_buttons_, expected_buttons);
}

TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_NotReady) {
  namespace sp = segmentation_platform;

  // Reply asynchronously with a classification whose status is kNotReady,
  // without setting any labels.
  auto reply_not_ready = [](sp::ClassificationResultCallback callback) {
    sp::ClassificationResult result(sp::PredictionStatus::kNotReady);
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(std::move(callback), result));
  };
  ON_CALL(GetSegmentationPlatformService(), GetClassificationResult(_, _, _, _))
      .WillByDefault(testing::WithArg<3>(testing::Invoke(reply_not_ready)));

  auto on_ranked_buttons = base::BindOnce(
      &AdaptiveToolbarBridgeTest::BridgeCallback, base::Unretained(this));
  adaptive_toolbar::GetRankedSessionVariantButtons(
      profile_.get(), /* use_raw_results= */ false,
      std::move(on_ranked_buttons));
  run_loop_.Run();

  // A not-ready result is reported as not ready, with a single kUnknown entry.
  EXPECT_FALSE(callback_is_ready_);
  const std::vector<int> expected_buttons = {
      static_cast<int>(AdaptiveToolbarButtonVariant::kUnknown)};
  EXPECT_EQ(callback_buttons_, expected_buttons);
}

TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_RawResults) {
  Profile* profile = profile_.get();

  ON_CALL(GetSegmentationPlatformService(),
          GetAnnotatedNumericResult(_, _, _, _))
      .WillByDefault(testing::WithArg<3>(testing::Invoke(
          [](segmentation_platform::AnnotatedNumericResultCallback callback) {
            segmentation_platform::AnnotatedNumericResult result(
                segmentation_platform::PredictionStatus::kSucceeded);
            // Set segmentation to return an annotated numeric result. It
            // carries the class labels in the model's output config plus a
            // parallel list of scores: the i-th score belongs to the i-th
            // label. The scores are not monotonic in label order, so the
            // expected output below exercises sorting by score.
            auto* result_classifier = result.result.mutable_output_config()
                                          ->mutable_predictor()
                                          ->mutable_multi_class_classifier();

            result_classifier->add_class_labels(
                segmentation_platform::kAdaptiveToolbarModelLabelNewTab);
            result_classifier->add_class_labels(
                segmentation_platform::kAdaptiveToolbarModelLabelVoice);
            result_classifier->add_class_labels(
                segmentation_platform::kAdaptiveToolbarModelLabelShare);
            result_classifier->add_class_labels(
                segmentation_platform::kAdaptiveToolbarModelLabelTranslate);
            result_classifier->add_class_labels(
                segmentation_platform::
                    kAdaptiveToolbarModelLabelAddToBookmarks);

            result.result.add_result(0.2);  // NewTab
            result.result.add_result(0.3);  // Voice
            result.result.add_result(0.7);  // Share
            result.result.add_result(0.9);  // Translate
            result.result.add_result(0.4);  // AddToBookmarks

            // `result` is a local that is not used again, so move it (and its
            // proto) into the posted callback instead of copying it.
            base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
                FROM_HERE,
                base::BindOnce(std::move(callback), std::move(result)));
          })));

  adaptive_toolbar::GetRankedSessionVariantButtons(
      profile, /* use_raw_results= */ true,
      base::BindOnce(&AdaptiveToolbarBridgeTest::BridgeCallback,
                     base::Unretained(this)));
  run_loop_.Run();

  EXPECT_TRUE(callback_is_ready_);
  // The returned enum values should be sorted by descending score.
  std::vector<int> expected_buttons = {
      static_cast<int>(AdaptiveToolbarButtonVariant::kTranslate),
      static_cast<int>(AdaptiveToolbarButtonVariant::kShare),
      static_cast<int>(AdaptiveToolbarButtonVariant::kAddToBookmarks),
      static_cast<int>(AdaptiveToolbarButtonVariant::kVoice),
      static_cast<int>(AdaptiveToolbarButtonVariant::kNewTab),
  };
  EXPECT_EQ(callback_buttons_, expected_buttons);
}

TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_RawResults_NotReady) {
  namespace sp = segmentation_platform;

  // Reply asynchronously with an annotated numeric result whose status is
  // kNotReady, without filling in any labels or scores.
  auto reply_not_ready = [](sp::AnnotatedNumericResultCallback callback) {
    sp::AnnotatedNumericResult result(sp::PredictionStatus::kNotReady);
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(std::move(callback), result));
  };
  ON_CALL(GetSegmentationPlatformService(),
          GetAnnotatedNumericResult(_, _, _, _))
      .WillByDefault(testing::WithArg<3>(testing::Invoke(reply_not_ready)));

  auto on_ranked_buttons = base::BindOnce(
      &AdaptiveToolbarBridgeTest::BridgeCallback, base::Unretained(this));
  adaptive_toolbar::GetRankedSessionVariantButtons(
      profile_.get(), /* use_raw_results= */ true,
      std::move(on_ranked_buttons));
  run_loop_.Run();

  // A not-ready raw result is reported as not ready, with a single kUnknown.
  EXPECT_FALSE(callback_is_ready_);
  const std::vector<int> expected_buttons = {
      static_cast<int>(AdaptiveToolbarButtonVariant::kUnknown)};
  EXPECT_EQ(callback_buttons_, expected_buttons);
}

}  // namespace adaptive_toolbar