chromium/chrome/browser/ash/app_list/search/ranking/mrfu_ranker.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/app_list/search/ranking/mrfu_ranker.h"

#include "chrome/browser/ash/app_list/search/chrome_search_result.h"
#include "chrome/browser/ash/app_list/search/ranking/util.h"
#include "chrome/browser/ash/app_list/search/types.h"
#include "chrome/browser/ash/app_list/search/util/ftrl_optimizer.h"

namespace app_list {

// MrfuResultRanker -----------------------------------------------------------

// Creates the result ranker. |proto| is moved into a newly constructed
// MrfuCache, which is configured with |params|; the ranker owns that cache
// through |mrfu_|.
MrfuResultRanker::MrfuResultRanker(MrfuCache::Params params,
                                   MrfuCache::Proto proto)
    : mrfu_(std::make_unique<MrfuCache>(std::move(proto), params)) {}

MrfuResultRanker::~MrfuResultRanker() = default;

// Returns one MRFU score per result produced by |provider|, in the same order
// as those results appear in |results|. Each score is read from the MRFU cache
// using the result's id. Returns an empty vector when |results| has no entry
// for |provider|.
std::vector<double> MrfuResultRanker::GetResultRanks(const ResultsMap& results,
                                                     ProviderType provider) {
  const auto it = results.find(provider);
  if (it == results.end())
    return {};

  const auto& provider_results = it->second;
  std::vector<double> scores;
  // Exactly one score is produced per result, so allocate once up front
  // instead of letting push_back() grow the buffer repeatedly.
  scores.reserve(provider_results.size());
  for (const auto& result : provider_results)
    scores.push_back(mrfu_->Get(result->id()));
  return scores;
}

// Writes each of |provider|'s results' MRFU cache score into that result's
// scoring as its MRFU result score. Does nothing if |results| holds no entry
// for |provider|.
void MrfuResultRanker::UpdateResultRanks(ResultsMap& results,
                                         ProviderType provider) {
  const auto provider_entry = results.find(provider);
  if (provider_entry != results.end()) {
    for (auto& search_result : provider_entry->second) {
      const auto mrfu_score = mrfu_->Get(search_result->id());
      search_result->scoring().set_mrfu_result_score(mrfu_score);
    }
  }
}

// Records a use of |launch.id| in the MRFU cache. Only launches that came from
// the search box are recorded; launches from any other source are ignored.
void MrfuResultRanker::Train(const LaunchData& launch) {
  const bool launched_from_search_box =
      launch.launched_from == ash::AppListLaunchedFrom::kLaunchedFromSearchBox;
  if (launched_from_search_box)
    mrfu_->Use(launch.id);
}

// MrfuCategoryRanker ----------------------------------------------------------

// Creates the category ranker. |proto| is moved into a newly constructed
// MrfuCache, which is configured with |params|; the ranker owns that cache
// through |mrfu_|. Cache keys are category names as produced by
// CategoryToString().
MrfuCategoryRanker::MrfuCategoryRanker(
    MrfuCache::Params params,
    ash::PersistentProto<MrfuCacheProto> proto)
    : mrfu_(std::make_unique<MrfuCache>(std::move(proto), params)) {}

MrfuCategoryRanker::~MrfuCategoryRanker() = default;

// Called when a search starts. If the cache has finished initializing but
// holds no data yet, seeds it with the default category ordering. |query| and
// |categories| are not consulted.
void MrfuCategoryRanker::Start(const std::u16string& query,
                               CategoriesList& categories) {
  // Nothing to seed until the cache is ready, and never overwrite real usage.
  if (!mrfu_->initialized() || !mrfu_->empty())
    return;
  SetDefaultCategoryScores();
}

// Returns one score per entry of |categories|, in the same order. A category's
// score is its normalized score from the MRFU cache, or 0.0 if the cache has
// no entry for it. If the cache has not finished initializing, every score is
// 0.0. |results| and |provider| are not used.
std::vector<double> MrfuCategoryRanker::GetCategoryRanks(
    const ResultsMap& results,
    const CategoriesList& categories,
    ProviderType provider) {
  if (!mrfu_->initialized())
    return std::vector<double>(categories.size(), 0.0);

  // Map the cache's string keys back to categories so they can be looked up
  // per entry of |categories| below.
  base::flat_map<Category, double> scores_map;
  for (const auto& id_score : mrfu_->GetAllNormalized())
    scores_map[StringToCategory(id_score.first)] = id_score.second;

  std::vector<double> scores;
  // One score is emitted per category, so size the buffer once.
  scores.reserve(categories.size());
  for (const auto& category : categories) {
    const auto it = scores_map.find(category.category);
    // Categories never recorded in the cache are treated as unseen.
    scores.push_back(it != scores_map.end() ? it->second : 0.0);
  }
  DCHECK_EQ(scores.size(), categories.size());
  return scores;
}

void MrfuCategoryRanker::UpdateCategoryRanks(const ResultsMap& results,
                                             CategoriesList& categories,
                                             ProviderType provider) {
  const auto& scores = GetCategoryRanks(results, categories, provider);
  DCHECK_EQ(scores.size(), categories.size());
  if (scores.size() != categories.size())
    return;
  for (size_t i = 0; i < categories.size(); ++i)
    categories[i].score = scores[i];
}

// Records a use of the launched item's category, keyed by its
// CategoryToString() name. Only launches that came from the search box are
// recorded; launches from any other source are ignored.
void MrfuCategoryRanker::Train(const LaunchData& launch) {
  const bool launched_from_search_box =
      launch.launched_from == ash::AppListLaunchedFrom::kLaunchedFromSearchBox;
  if (launched_from_search_box)
    mrfu_->Use(CategoryToString(launch.category));
}

void MrfuCategoryRanker::SetDefaultCategoryScores() {
  // Default category prioritization:
  //
  //   P1: Apps
  //       Play Store
  //       Settings
  //   P2: Web
  //       Files (local and Drive)
  //       Shortcuts (which are under the Help category)
  //   P3: Everything else
  //
  // Achieve this by training once each, in reverse order.
  mrfu_->Use(CategoryToString(Category::kHelp));
  mrfu_->Use(CategoryToString(Category::kFiles));
  mrfu_->Use(CategoryToString(Category::kWeb));
  mrfu_->Use(CategoryToString(Category::kSettings));
  mrfu_->Use(CategoryToString(Category::kPlayStore));
  mrfu_->Use(CategoryToString(Category::kApps));
}

}  // namespace app_list