chromium/chrome/browser/segmentation_platform/client_util/tab_data_collection_util_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/segmentation_platform/client_util/tab_data_collection_util.h"

#include <memory>

#include "base/metrics/user_metrics.h"
#include "base/metrics/user_metrics_action.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/gmock_callback_support.h"
#include "chrome/browser/ui/android/tab_model/tab_model.h"
#include "chrome/browser/ui/android/tab_model/tab_model_list.h"
#include "chrome/browser/ui/android/tab_model/tab_model_observer.h"
#include "chrome/browser/ui/android/tab_model/tab_model_test_helper.h"
#include "chrome/test/base/testing_profile.h"
#include "components/segmentation_platform/embedder/input_delegate/tab_rank_dispatcher.h"
#include "components/segmentation_platform/embedder/tab_fetcher.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/segmentation_platform/public/testing/mock_segmentation_platform_service.h"
#include "components/segmentation_platform/public/trigger.h"
#include "components/sessions/core/session_id.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace segmentation_platform {
namespace {

using ::base::test::RunOnceCallback;
using ::testing::_;
using ::testing::Return;

// gMock stand-in for TabRankDispatcher. The constructor arguments are handed
// straight to the real base class; GetTopRankedTabs() is mocked so each test
// can script which ranked tabs are reported back.
class MockRankDispatcher : public TabRankDispatcher {
 public:
  MockRankDispatcher(SegmentationPlatformService* platform_service,
                     sync_sessions::SessionSyncService* sync_service,
                     std::unique_ptr<TabFetcher> tab_fetcher)
      : TabRankDispatcher(platform_service,
                          sync_service,
                          std::move(tab_fetcher)) {}

  // Expectations on this method drive the ranked-tab results seen by the code
  // under test.
  MOCK_METHOD(void,
              GetTopRankedTabs,
              (const std::string& key,
               const TabFilter& filter,
               RankedTabsCallback ranked_tabs_callback),
              (override));
};

// gMock implementation of sync_sessions::SessionSyncService. In this file it is
// only used as an instance handed to TabFetcher / TabRankDispatcher; no
// expectations are set on it here.
class MockSessionSyncService : public sync_sessions::SessionSyncService {
 public:
  MockSessionSyncService() = default;
  ~MockSessionSyncService() override = default;

  MOCK_METHOD(syncer::GlobalIdMapper*,
              GetGlobalIdMapper,
              (),
              (const, override));
  MOCK_METHOD(sync_sessions::OpenTabsUIDelegate*,
              GetOpenTabsUIDelegate,
              (),
              (override));
  MOCK_METHOD(base::CallbackListSubscription,
              SubscribeToForeignSessionsChanged,
              (const base::RepeatingClosure& cb),
              (override));
  // Marked `override` like the other mocked methods so that a signature change
  // in the base class fails to compile instead of silently declaring a new,
  // unrelated method.
  MOCK_METHOD(base::WeakPtr<syncer::DataTypeControllerDelegate>,
              GetControllerDelegate,
              (),
              (override));
};

// gMock TabFetcher whose local-tab lookups are scripted by the tests.
// NOTE(review): "Fecther" is a misspelling of "Fetcher"; the name is kept
// because the test fixture refers to it.
class MockTabFecther : public TabFetcher {
 public:
  explicit MockTabFecther(MockSessionSyncService* sync_service)
      : TabFetcher(sync_service) {}

  MOCK_METHOD(bool,
              FillAllLocalTabsFromTabModel,
              (std::vector<TabEntry>& tab_entries),
              (override));
  MOCK_METHOD(Tab,
              FindLocalTab,
              (const TabEntry& tab_entry),
              (override));
};

}  // namespace

class TabDataCollectionUtilTest : public testing::Test {
 public:
  TabDataCollectionUtilTest() = default;
  ~TabDataCollectionUtilTest() override = default;

  void SetUp() override {
    Test::SetUp();
    base::SetRecordActionTaskRunner(
        base::SingleThreadTaskRunner::GetCurrentDefault());
    tab_model_ = std::make_unique<TestTabModel>(&profile_);
    segmentation_service_ = std::make_unique<MockSegmentationPlatformService>();
    auto tab_fetcher = std::make_unique<MockTabFecther>(&session_sync_service_);
    tab_fetcher_ = tab_fetcher.get();
    rank_dispatcher_ = std::make_unique<MockRankDispatcher>(
        segmentation_service_.get(), &session_sync_service_,
        std::move(tab_fetcher));
    collection_util_ = std::make_unique<TabDataCollectionUtil>(
        segmentation_service_.get(), rank_dispatcher_.get());
  }

  void TearDown() override { Test::TearDown(); }

 protected:
  content::BrowserTaskEnvironment task_environment_;
  TestingProfile profile_;
  std::unique_ptr<TestTabModel> tab_model_;
  MockSessionSyncService session_sync_service_;
  raw_ptr<MockTabFecther> tab_fetcher_;
  std::unique_ptr<MockSegmentationPlatformService> segmentation_service_;
  std::unique_ptr<MockRankDispatcher> rank_dispatcher_;
  std::unique_ptr<TabDataCollectionUtil> collection_util_;
};

// Verifies that TabDataCollectionUtil installs an observer on every TabModel
// added to TabModelList and removes it when the model is removed.
//
// Non-fatal EXPECT_* checks are used on purpose: none of them guard a
// dereference, and a fatal ASSERT_* would return before the RemoveTabModel()
// calls. That would leave pointers to destroyed TestTabModels in the global
// TabModelList and corrupt later tests.
TEST_F(TabDataCollectionUtilTest, AddRemoveTabModel) {
  EXPECT_FALSE(tab_model_->GetObserver());
  TabModelList::AddTabModel(tab_model_.get());
  EXPECT_TRUE(tab_model_->GetObserver());
  TabModelList::RemoveTabModel(tab_model_.get());
  EXPECT_FALSE(tab_model_->GetObserver());

  // Adding a second model must not attach an observer to the removed one.
  auto other_tab_model = std::make_unique<TestTabModel>(&profile_);
  TabModelList::AddTabModel(other_tab_model.get());
  EXPECT_FALSE(tab_model_->GetObserver());
  EXPECT_TRUE(other_tab_model->GetObserver());

  // Both models are observed while registered together.
  TabModelList::AddTabModel(tab_model_.get());
  EXPECT_TRUE(tab_model_->GetObserver());
  EXPECT_TRUE(other_tab_model->GetObserver());

  TabModelList::RemoveTabModel(tab_model_.get());
  TabModelList::RemoveTabModel(other_tab_model.get());
  EXPECT_FALSE(tab_model_->GetObserver());
  EXPECT_FALSE(other_tab_model->GetObserver());
}

// Verifies that after the "MobileToolbarShowStackView" user action, a pending
// tab closure leads to CollectTrainingData() for TAB_RESUMPTION_CLASSIFIER with
// the request id attached to the ranked tab.
TEST_F(TabDataCollectionUtilTest, RecordTrainingData) {
  TabModelList::AddTabModel(tab_model_.get());
  // The observer is dereferenced below, so its absence is fatal. Unregister the
  // model before bailing out. A bare ASSERT would return with `tab_model_` still
  // in the global TabModelList, leaving it a dangling pointer once the fixture
  // destroys the model.
  if (!tab_model_->GetObserver()) {
    TabModelList::RemoveTabModel(tab_model_.get());
    FAIL() << "TabDataCollectionUtil did not observe the added TabModel";
  }

  // Placeholder pointer used only as a tab identity token; presumably never
  // dereferenced by the code under test (TODO confirm).
  TabAndroid* fake_tab_ptr = reinterpret_cast<TabAndroid*>(1);
  EXPECT_CALL(*tab_fetcher_, FindLocalTab(_))
      .WillOnce(Return(TabFetcher::Tab{.tab_android = fake_tab_ptr}));

  // The dispatcher reports a single ranked tab tagged with request id 5.
  TrainingRequestId id1 = TrainingRequestId::FromUnsafeValue(5);
  TabFetcher::TabEntry entry(SessionID::NewUnique(), nullptr, fake_tab_ptr);
  TabRankDispatcher::RankedTab tab1{
      .tab = entry, .model_score = 0.5, .request_id = id1};
  std::multiset<TabRankDispatcher::RankedTab> tabs{tab1};
  EXPECT_CALL(*rank_dispatcher_,
              GetTopRankedTabs(kTabResumptionClassifierKey, _, _))
      .WillOnce(RunOnceCallback<2>(true, tabs));

  // Training data collection happens asynchronously; block until it occurs.
  base::RunLoop wait_for_collection;
  EXPECT_CALL(*segmentation_service_,
              CollectTrainingData(proto::SegmentId::TAB_RESUMPTION_CLASSIFIER,
                                  id1, _, _))
      .WillOnce([&wait_for_collection](
                    proto::SegmentId segment_id, TrainingRequestId request_id,
                    const TrainingLabels& param,
                    SegmentationPlatformService::SuccessCallback callback) {
        wait_for_collection.QuitClosure().Run();
      });
  base::RecordAction(base::UserMetricsAction("MobileToolbarShowStackView"));
  tab_model_->GetObserver()->TabPendingClosure(fake_tab_ptr);
  wait_for_collection.Run();

  TabModelList::RemoveTabModel(tab_model_.get());
}

}  // namespace segmentation_platform