chromium/ios/chrome/browser/segmentation_platform/model/segmentation_platform_service_factory_unittest.mm

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#import "ios/chrome/browser/segmentation_platform/model/segmentation_platform_service_factory.h"

#import "base/functional/bind.h"
#import "base/memory/scoped_refptr.h"
#import "base/test/scoped_command_line.h"
#import "base/test/scoped_feature_list.h"
#import "base/test/task_environment.h"
#import "components/optimization_guide/core/optimization_guide_features.h"
#import "components/prefs/pref_change_registrar.h"
#import "components/prefs/pref_service.h"
#import "components/segmentation_platform/internal/constants.h"
#import "components/segmentation_platform/internal/database/client_result_prefs.h"
#import "components/segmentation_platform/public/constants.h"
#import "components/segmentation_platform/public/features.h"
#import "components/segmentation_platform/public/prediction_options.h"
#import "components/segmentation_platform/public/result.h"
#import "components/segmentation_platform/public/segment_selection_result.h"
#import "components/segmentation_platform/public/segmentation_platform_service.h"
#import "components/segmentation_platform/public/service_proxy.h"
#import "components/ukm/test_ukm_recorder.h"
#import "ios/chrome/browser/segmentation_platform/model/ukm_data_manager_test_utils.h"
#import "ios/chrome/browser/shared/model/profile/test/test_profile_ios.h"
#import "ios/web/public/test/web_task_environment.h"
#import "testing/gtest/include/gtest/gtest.h"
#import "testing/platform_test.h"

namespace segmentation_platform {
namespace {

// ServiceProxy observer that runs the closure passed to its constructor the
// first time the service reports that it is initialized.
class WaitServiceInitializedObserver : public ServiceProxy::Observer {
 public:
  explicit WaitServiceInitializedObserver(base::OnceClosure closure)
      : closure_(std::move(closure)) {}

  // ServiceProxy::Observer:
  void OnServiceStatusChanged(bool initialized, int status_flags) override {
    // `closure_` is consumed by the first "initialized" notification. The
    // observer may still be registered when further notifications arrive, so
    // skip those instead of running a null OnceClosure.
    if (initialized && closure_) {
      std::move(closure_).Run();
    }
  }

 private:
  base::OnceClosure closure_;
};

}  // namespace
// Fixture that builds a TestChromeBrowserState whose
// SegmentationPlatformService comes from the real factory's default factory,
// with the segmentation platform features enabled. It provides helpers to wait
// for service initialization, to carry the client result prefs across a
// recreated profile, and to check classification results.
class SegmentationPlatformServiceFactoryTest : public PlatformTest {
 public:
  SegmentationPlatformServiceFactoryTest()
      : test_utils_(std::make_unique<UkmDataManagerTestUtils>(&ukm_recorder_)) {
    // TODO(b/293500507): Create a base class for testing default models.
    scoped_feature_list_.InitWithFeaturesAndParameters(
        {{optimization_guide::features::kOptimizationTargetPrediction, {}},
         {features::kSegmentationPlatformFeature, {}},
         {features::kSegmentationPlatformUkmEngine, {}},
         {features::kContextualPageActionShareModel, {}}},
        {});
    scoped_command_line_.GetProcessCommandLine()->AppendSwitch(
        kSegmentationPlatformRefreshResultsSwitch);
    scoped_command_line_.GetProcessCommandLine()->AppendSwitch(
        kSegmentationPlatformDisableModelExecutionDelaySwitch);
  }
  ~SegmentationPlatformServiceFactoryTest() override = default;

  // Creates the regular profile, waits for its service to initialize, and
  // asserts that no service is returned for an off-the-record browser state.
  void SetUp() override {
    PlatformTest::SetUp();

    test_utils_->PreProfileInit({});
    profile_ = std::make_unique<ProfileData>(test_utils_.get(), "");
    WaitForServiceInit();

    // Even with the default factory installed as a testing factory, the OTR
    // browser state must not get a SegmentationPlatformService.
    ChromeBrowserState* otr_browser_state =
        profile_->browser_state
            ->CreateOffTheRecordBrowserStateWithTestingFactories(
                {TestChromeBrowserState::TestingFactory{
                    SegmentationPlatformServiceFactory::GetInstance(),
                    SegmentationPlatformServiceFactory::GetDefaultFactory()}});
    ASSERT_FALSE(SegmentationPlatformServiceFactory::GetForBrowserState(
        otr_browser_state));
  }

  // Flushes pending tasks, destroys the profile and the UKM test utilities,
  // then lets the base class finish its own teardown.
  void TearDown() override {
    web_task_env_.RunUntilIdle();
    profile_.reset();
    test_utils_.reset();
    PlatformTest::TearDown();
  }

  // Waits until the service has written a client result for
  // `segmentation_key` to prefs. It then recreates the profile, seeding the new
  // one with the client result pref value saved from the old one, and waits
  // for the new service to initialize.
  void InitServiceAndCacheResults(const std::string& segmentation_key) {
    WaitForServiceInit();
    WaitForClientResultPrefUpdate(segmentation_key);
    const std::string output = profile_->browser_state->GetPrefs()->GetString(
        kSegmentationClientResultPrefs);

    // TODO(b/297091996): Remove this when leak is fixed.
    web_task_env_.RunUntilIdle();

    profile_.reset();

    // Create the profile again and initialize the segmentation service with
    // the client result prefs copied from the last session.
    profile_ = std::make_unique<ProfileData>(test_utils_.get(), output);
    WaitForServiceInit();
    // TODO(b/297091996): Remove this when leak is fixed.
    web_task_env_.RunUntilIdle();
  }

  // Returns true if the current profile's client result prefs hold a result
  // for `segmentation_key`.
  bool HasClientResultPref(const std::string& segmentation_key) {
    ClientResultPrefs result_prefs(profile_->browser_state->GetPrefs());
    return result_prefs.ReadClientResultFromPrefs(segmentation_key) != nullptr;
  }

  // Pref-change callback: runs the pending wait callback once a result for
  // `segmentation_key` is present.
  void OnClientResultPrefUpdated(const std::string& segmentation_key) {
    if (!wait_for_pref_callback_.is_null() &&
        HasClientResultPref(segmentation_key)) {
      std::move(wait_for_pref_callback_).Run();
    }
  }

  // Blocks until a client result for `segmentation_key` is stored in prefs.
  // Returns immediately if one is already present.
  void WaitForClientResultPrefUpdate(const std::string& segmentation_key) {
    if (HasClientResultPref(segmentation_key)) {
      return;
    }

    base::RunLoop wait_for_pref;
    wait_for_pref_callback_ = wait_for_pref.QuitClosure();
    pref_registrar_.Init(profile_->browser_state->GetPrefs());
    pref_registrar_.Add(
        kSegmentationClientResultPrefs,
        base::BindRepeating(
            &SegmentationPlatformServiceFactoryTest::OnClientResultPrefUpdated,
            base::Unretained(this), segmentation_key));
    wait_for_pref.Run();

    // Unregister so a later call can observe a different profile's prefs.
    pref_registrar_.RemoveAll();
  }

 protected:
  // Owns a TestChromeBrowserState whose SegmentationPlatformService is built by
  // the default factory. The constructor seeds the client result pref with
  // `result_pref` and registers the browser state with `test_utils`. The
  // destructor notifies `test_utils` before the browser state is destroyed.
  struct ProfileData {
    explicit ProfileData(UkmDataManagerTestUtils* test_utils,
                         const std::string& result_pref)
        : test_utils(test_utils) {
      TestChromeBrowserState::Builder builder;
      builder.AddTestingFactory(
          SegmentationPlatformServiceFactory::GetInstance(),
          SegmentationPlatformServiceFactory::GetDefaultFactory());
      browser_state = std::move(builder).Build();

      browser_state->GetPrefs()->SetString(kSegmentationClientResultPrefs,
                                           result_pref);
      test_utils->SetupForProfile(browser_state.get());
      service = SegmentationPlatformServiceFactory::GetForBrowserState(
          browser_state.get());
    }

    ~ProfileData() { test_utils->WillDestroyProfile(browser_state.get()); }

    ProfileData(const ProfileData&) = delete;
    ProfileData& operator=(const ProfileData&) = delete;

    const raw_ptr<UkmDataManagerTestUtils> test_utils;
    std::unique_ptr<TestChromeBrowserState> browser_state;
    raw_ptr<SegmentationPlatformService> service;
  };

  // Blocks until the current profile's service reports it is initialized.
  void WaitForServiceInit() {
    if (profile_->service->IsPlatformInitialized()) {
      return;
    }
    base::RunLoop wait_for_init;
    WaitServiceInitializedObserver wait_observer(wait_for_init.QuitClosure());
    profile_->service->GetServiceProxy()->AddObserver(&wait_observer);

    wait_for_init.Run();
    // Keep pumping tasks until IsPlatformInitialized() agrees with the
    // observer notification.
    while (!profile_->service->IsPlatformInitialized()) {
      base::RunLoop().RunUntilIdle();
    }

    profile_->service->GetServiceProxy()->RemoveObserver(&wait_observer);
  }

  // Requests a classification for `segmentation_key` and blocks until the
  // callback runs. Checks the returned status and, if `expected_labels` is
  // set, the ordered labels.
  void ExpectGetClassificationResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      PredictionStatus expected_status,
      std::optional<std::vector<std::string>> expected_labels) {
    base::RunLoop loop;
    profile_->service->GetClassificationResult(
        segmentation_key, prediction_options, std::move(input_context),
        base::BindOnce(
            &SegmentationPlatformServiceFactoryTest::OnGetClassificationResult,
            base::Unretained(this), loop.QuitClosure(), expected_status,
            std::move(expected_labels)));
    loop.Run();
  }

  // Classification callback: verifies `actual_result`, then runs `closure` to
  // quit the waiting run loop.
  void OnGetClassificationResult(
      base::RepeatingClosure closure,
      PredictionStatus expected_status,
      std::optional<std::vector<std::string>> expected_labels,
      const ClassificationResult& actual_result) {
    EXPECT_EQ(actual_result.status, expected_status);
    if (expected_labels.has_value()) {
      EXPECT_EQ(actual_result.ordered_labels, expected_labels.value());
    }
    std::move(closure).Run();
  }

  base::test::ScopedFeatureList scoped_feature_list_;
  web::WebTaskEnvironment web_task_env_;
  base::test::ScopedCommandLine scoped_command_line_;
  ukm::TestUkmRecorder ukm_recorder_;
  PrefChangeRegistrar pref_registrar_;
  base::OnceClosure wait_for_pref_callback_;

  std::unique_ptr<UkmDataManagerTestUtils> test_utils_;
  std::unique_ptr<ProfileData> profile_;
};

// Exercises only the fixture: SetUp() waits for the service to initialize for
// a regular browser state and asserts that none is created for an
// off-the-record browser state.
TEST_F(SegmentationPlatformServiceFactoryTest, Test) {
  // TODO(crbug.com/40227968): Add test for the API once the initialization is
  // fixed.
}

// Caches a search user result, recreates the profile with the cached prefs,
// and expects a successful classification labeled `kSearchUserModelLabelNone`.
TEST_F(SegmentationPlatformServiceFactoryTest, TestSearchUserModel) {
  InitServiceAndCacheResults(kSearchUserKey);

  const std::vector<std::string> expected_labels = {kSearchUserModelLabelNone};
  ExpectGetClassificationResult(kSearchUserKey, PredictionOptions(),
                                /*input_context=*/nullptr,
                                /*expected_status=*/PredictionStatus::kSucceeded,
                                expected_labels);
}

// Runs the iOS module ranker on demand, giving every module the same freshness
// impression count, and checks the resulting module order.
TEST_F(SegmentationPlatformServiceFactoryTest, TestIosModuleRankerModel) {
  PredictionOptions options;
  options.on_demand_execution = true;

  // Every module receives the same freshness impression count.
  // NOTE(review): -1 presumably means "no freshness data yet"; confirm.
  constexpr int kFreshnessImpressionCount = -1;
  const char* const freshness_signals[] = {
      kMostVisitedTilesFreshness, kShortcutsFreshness, kSafetyCheckFreshness,
      kTabResumptionFreshness, kParcelTrackingFreshness};

  auto context = base::MakeRefCounted<InputContext>();
  for (const char* signal : freshness_signals) {
    context->metadata_args.emplace(
        signal,
        processing::ProcessedValue::FromFloat(kFreshnessImpressionCount));
  }

  const std::vector<std::string> expected_ranking = {
      "MostVisitedTiles", "Shortcuts", "SafetyCheck", "TabResumption",
      "ParcelTracking"};
  ExpectGetClassificationResult(kIosModuleRankerKey, options, context,
                                PredictionStatus::kSucceeded, expected_ranking);
}

}  // namespace segmentation_platform