chromium/chrome/browser/segmentation_platform/segmentation_platform_service_factory_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/segmentation_platform/segmentation_platform_service_factory.h"

#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/test/scoped_command_line.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "chrome/browser/commerce/shopping_service_factory.h"
#include "chrome/browser/segmentation_platform/ukm_data_manager_test_utils.h"
#include "chrome/test/base/testing_profile.h"
#include "components/commerce/core/mock_shopping_service.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/prefs/pref_change_registrar.h"
#include "components/prefs/pref_observer.h"
#include "components/prefs/pref_service.h"
#include "components/segmentation_platform/embedder/default_model/contextual_page_actions_model.h"
#include "components/segmentation_platform/embedder/default_model/metrics_clustering.h"
#include "components/segmentation_platform/embedder/default_model/most_visited_tiles_user.h"
#include "components/segmentation_platform/embedder/home_modules/ephemeral_home_module_backend.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/client_result_prefs.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/segmentation_platform/public/service_proxy.h"
#include "components/ukm/test_ukm_recorder.h"
#include "components/visited_url_ranking/public/test_support.h"
#include "components/visited_url_ranking/public/url_visit_schema.h"
#include "components/visited_url_ranking/public/url_visit_util.h"
#include "content/public/test/browser_task_environment.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"

namespace segmentation_platform {
namespace {

Segmentation_ModelExecutionUkmRecorder;

// Observer that waits for service_ initialization.
// NOTE(review): the class body is empty here, so it declares no members and
// no ServiceProxy::Observer overrides; the waiting behavior described above
// is not implemented in this body.
class WaitServiceInitializedObserver : public ServiceProxy::Observer {};

// Test fixture shared by every TEST_F in this file.
// NOTE(review): the fixture body is empty in this revision, yet the
// Android-only tests in this file call InitService(),
// InitServiceAndCacheResults(), ExpectGetClassificationResult(), clock() and
// WaitAndCheckUkmRecord() on it; those helpers must be declared here for the
// Android build to compile.
class SegmentationPlatformServiceFactoryTest : public testing::Test {};

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestPasswordManagerUserSegment) {}

// Segmentation Ukm Engine is disabled on CrOS.
#if !BUILDFLAG(IS_CHROMEOS)
// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestSearchUserModel) {}
#endif  // !BUILDFLAG(IS_CHROMEOS)

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestShoppingUserModel) {}

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestResumeHeavyUserModel) {}

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestLowUserEngagementModel) {}

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestCrossDeviceModel) {}

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TestDeviceSwitcherModel) {}

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, URLVisitResumptionRanker) {}

// Segmentation Ukm Engine is disabled on CrOS.
#if !BUILDFLAG(IS_CHROMEOS)
// NOTE(review): "Resupmtion" in the test name is a typo for "Resumption";
// renaming is deferred because gtest filters may reference the current name.
// The body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, TabResupmtionRanker) {}
#endif  // !BUILDFLAG(IS_CHROMEOS)

// NOTE(review): the body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, MetricsClustering) {}

#if BUILDFLAG(IS_ANDROID)
// Tests for models that are exercised only on Android builds.
TEST_F(SegmentationPlatformServiceFactoryTest, TestDeviceTierSegment) {
  // Cache a result for the device tier key, then expect the classification
  // lookup to report kSucceeded. No expected labels are supplied
  // (std::nullopt).
  InitServiceAndCacheResults(kDeviceTierKey);

  PredictionOptions options;
  ExpectGetClassificationResult(
      kDeviceTierKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/std::nullopt);
}

TEST_F(SegmentationPlatformServiceFactoryTest,
       TestTabletProductivityUserModel) {
  InitServiceAndCacheResults(kTabletProductivityUserKey);

  // The cached result is expected to carry exactly one label:
  // kTabletProductivityUserModelLabelNone.
  const std::vector<std::string> expected_labels{
      kTabletProductivityUserModelLabelNone};
  PredictionOptions options;
  ExpectGetClassificationResult(
      kTabletProductivityUserKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/expected_labels);
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestContextualPageActionsShare) {
  InitService();

  // Run the contextual page actions model on demand with only the price
  // insights input set to 1; the other provided inputs are 0.
  PredictionOptions options;
  options.on_demand_execution = true;

  auto input_context = base::MakeRefCounted<InputContext>();
  auto add_signal = [&input_context](const auto& name, float value) {
    input_context->metadata_args.emplace(
        name, processing::ProcessedValue::FromFloat(value));
  };
  add_signal(kContextualPageActionModelInputPriceInsights, 1);
  add_signal(kContextualPageActionModelInputPriceTracking, 0);
  add_signal(kContextualPageActionModelInputReaderMode, 0);

  ExpectGetClassificationResult(
      kContextualPageActionsKey, options, input_context,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/
      std::vector<std::string>{kContextualPageActionModelLabelPriceInsights});

  // Advance the test clock by the share output collection delay before
  // checking the recorded UKM entry.
  const auto collection_delay = base::Seconds(
      ContextualPageActionsModel::kShareOutputCollectionDelayInSec);
  clock()->Advance(collection_delay);

  WaitAndCheckUkmRecord(
      proto::OPTIMIZATION_TARGET_CONTEXTUAL_PAGE_ACTION_PRICE_TRACKING,
      /*inputs=*/
      {SegmentationUkmHelper::FloatToInt64(1.f), 0, 0, 0, 0, 0, 0, 0},
      /*outputs=*/{0, 0, 0, 0, 0, 0});
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestFrequentFeatureModel) {
  InitServiceAndCacheResults(kFrequentFeatureUserKey);

  // The cached classification is expected to be the single legacy negative
  // label.
  PredictionOptions options;
  ExpectGetClassificationResult(
      kFrequentFeatureUserKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/std::vector<std::string>(1, kLegacyNegativeLabel));
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestIntentionalUserModel) {
  InitServiceAndCacheResults(kIntentionalUserKey);

  // Expect a successful lookup whose only label is the legacy negative label.
  const std::vector<std::string> expected_labels{kLegacyNegativeLabel};
  PredictionOptions options;
  ExpectGetClassificationResult(
      kIntentionalUserKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/expected_labels);
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestPowerUserSegment) {
  InitServiceAndCacheResults(kPowerUserKey);

  // The cached power user result is expected to be the single label "None".
  PredictionOptions options;
  ExpectGetClassificationResult(
      kPowerUserKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/std::vector<std::string>(1, "None"));
}

TEST_F(SegmentationPlatformServiceFactoryTest, MostVisitedTilesUser) {
  // The model class shares its name with this test, so keep it fully
  // qualified for readability.
  InitServiceAndCacheResults(
      segmentation_platform::MostVisitedTilesUser::kMostVisitedTilesUserKey);

  // Expect a successful lookup labeled "None".
  PredictionOptions options;
  ExpectGetClassificationResult(
      segmentation_platform::MostVisitedTilesUser::kMostVisitedTilesUserKey,
      options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/std::vector<std::string>{"None"});
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestFeedUserModel) {
  InitServiceAndCacheResults(kFeedUserSegmentationKey);

  // The cached feed user result is expected to be the legacy negative label.
  PredictionOptions options;
  ExpectGetClassificationResult(
      kFeedUserSegmentationKey, options, nullptr,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/std::vector<std::string>{kLegacyNegativeLabel});
}

TEST_F(SegmentationPlatformServiceFactoryTest, TestAndroidHomeModuleRanker) {
  InitService();

  PredictionOptions options;
  options.on_demand_execution = true;

  // Every module's freshness input is set to -1.
  auto input_context = base::MakeRefCounted<InputContext>();
  auto set_freshness = [&input_context](const auto& signal, float freshness) {
    input_context->metadata_args.emplace(
        signal, processing::ProcessedValue::FromFloat(freshness));
  };
  set_freshness(kSingleTabFreshness, -1);
  set_freshness(kPriceChangeFreshness, -1);
  set_freshness(kTabResumptionForAndroidHomeFreshness, -1);
  set_freshness(kSafetyHubFreshness, -1);

  // Module order the ranker is expected to return for these inputs.
  const std::vector<std::string> expected_order = {
      kPriceChange, kSingleTab, kTabResumptionForAndroidHome, kSafetyHub};
  ExpectGetClassificationResult(
      kAndroidHomeModuleRankerKey, options, input_context,
      /*expected_status=*/PredictionStatus::kSucceeded,
      /*expected_labels=*/expected_order);
}

#endif  // BUILDFLAG(IS_ANDROID)

// NOTE(review): "Mdoule" in the test name is a typo for "Module"; renaming is
// deferred because gtest filters may reference the current name.
// The body is empty, so this test currently makes no assertions.
TEST_F(SegmentationPlatformServiceFactoryTest, EphemeralHomeMdouleBackend) {}

}  // namespace
}  // namespace segmentation_platform