chromium/chrome/browser/segmentation_platform/client_util/tab_data_collection_util.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/segmentation_platform/client_util/tab_data_collection_util.h"

#include "base/memory/raw_ptr.h"
#include "base/metrics/user_metrics.h"
#include "base/scoped_observation.h"
#include "chrome/browser/android/tab_android.h"
#include "chrome/browser/ui/android/tab_model/tab_model_list.h"
#include "chrome/browser/ui/android/tab_model/tab_model_list_observer.h"
#include "chrome/browser/ui/android/tab_model/tab_model_observer.h"
#include "components/segmentation_platform/embedder/input_delegate/tab_rank_dispatcher.h"
#include "components/segmentation_platform/embedder/tab_fetcher.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"

namespace segmentation_platform {

class TabDataCollectionUtil::LocalTabModelObserver : public TabModelObserver {
 public:
  explicit LocalTabModelObserver(TabDataCollectionUtil* collection_util)
      : collection_util_(collection_util) {}
  ~LocalTabModelObserver() override = default;

  void DidSelectTab(TabAndroid* tab, TabModel::TabSelectionType type) override {
    collection_util_->OnTabAction(tab, TabAction::kTabSelected);
  }
  void TabPendingClosure(TabAndroid* tab) override {
    collection_util_->OnTabAction(tab, TabAction::kTabClose);
  }
  void AllTabsPendingClosure(
      const std::vector<raw_ptr<TabAndroid, VectorExperimental>>& tabs)
      override {
    for (const TabAndroid* tab : tabs) {
      collection_util_->OnTabAction(tab, TabAction::kTabClose);
    }
  }

 private:
  const raw_ptr<TabDataCollectionUtil> collection_util_;
};

class TabDataCollectionUtil::LocalTabModelListObserver
    : public TabModelListObserver {
 public:
  LocalTabModelListObserver(LocalTabModelObserver* tab_observer,
                            TabDataCollectionUtil* collection_util)
      : tab_observer_(tab_observer), collection_util_(collection_util) {
    ResetObservers();
  }
  ~LocalTabModelListObserver() override = default;

  void OnTabModelAdded() override { ResetObservers(); }
  void OnTabModelRemoved() override { ResetObservers(); }

  void ResetObservers() {
    observing_models_.clear();
    for (TabModel* model : TabModelList::models()) {
      observing_models_.emplace_back(
          std::make_unique<TabModelObservation>(tab_observer_.get()));
      observing_models_.back()->Observe(model);
    }
  }

 private:
  using TabModelObservation =
      base::ScopedObservation<TabModel, TabModelObserver>;
  const raw_ptr<LocalTabModelObserver> tab_observer_;
  const raw_ptr<TabDataCollectionUtil> collection_util_;
  std::vector<std::unique_ptr<TabModelObservation>> observing_models_;
};

TabDataCollectionUtil::TabDataCollectionUtil(
    SegmentationPlatformService* segmentation_service,
    TabRankDispatcher* tab_rank_dispatcher)
    : segmentation_service_(segmentation_service),
      tab_rank_dispatcher_(tab_rank_dispatcher),
      tab_observer_(std::make_unique<LocalTabModelObserver>(this)),
      tab_model_list_observer_(
          std::make_unique<LocalTabModelListObserver>(tab_observer_.get(),
                                                      this)) {
  TabModelList::AddObserver(tab_model_list_observer_.get());
  base::AddActionCallback(base::BindRepeating(
      &TabDataCollectionUtil::OnUserAction, weak_ptr_factory_.GetWeakPtr()));
}

TabDataCollectionUtil::~TabDataCollectionUtil() {
  TabModelList::RemoveObserver(tab_model_list_observer_.get());
}

void TabDataCollectionUtil::OnTabAction(const TabAndroid* tab,
                                        TabAction tab_action) {
  auto it = tab_requests_.find(tab);
  if (it != tab_requests_.end()) {
    TrainingLabels labels;
    labels.output_metric =
        std::make_pair("TabAction", static_cast<int>(tab_action));
    segmentation_service_->CollectTrainingData(
        proto::SegmentId::TAB_RESUMPTION_CLASSIFIER, it->second, labels,
        base::DoNothing());
  }
}

void TabDataCollectionUtil::OnUserAction(const std::string& action,
                                         base::TimeTicks time) {
  if (action != "MobileToolbarShowStackView") {
    return;
  }
  // Filter only local tabs, with TabAndroid object usable.
  auto filter = base::BindRepeating(
      [](const TabFetcher::Tab& tab) { return !!tab.tab_android; });
  tab_rank_dispatcher_->GetTopRankedTabs(
      kTabResumptionClassifierKey, TabRankDispatcher::TabFilter(),
      base::BindOnce(&TabDataCollectionUtil::OnGetRankedTabs,
                     weak_ptr_factory_.GetWeakPtr()));
}

void TabDataCollectionUtil::OnGetRankedTabs(
    bool success,
    std::multiset<TabRankDispatcher::RankedTab> ranked_tabs) {
  if (!success) {
    return;
  }
  tab_requests_.clear();
  auto* fetcher = tab_rank_dispatcher_->tab_fetcher();
  for (const auto& candidate : ranked_tabs) {
    auto tab = fetcher->FindTab(candidate.tab);
    if (tab.tab_android) {
      tab_requests_[tab.tab_android] = candidate.request_id;
    }
  }
}

}  // namespace segmentation_platform