chromium/chrome/browser/ash/app_list/search/ranking/ftrl_ranker.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/app_list/search/ranking/ftrl_ranker.h"

#include "base/strings/string_number_conversions.h"
#include "chrome/browser/ash/app_list/search/chrome_search_result.h"
#include "chrome/browser/ash/app_list/search/ranking/util.h"
#include "chrome/browser/ash/app_list/search/util/ftrl_optimizer.h"

namespace app_list {

// FtrlRanker ------------------------------------------------------------

FtrlRanker::FtrlRanker(FtrlRanker::RankingKind kind,
                       FtrlOptimizer::Params params,
                       FtrlOptimizer::Proto proto)
    : kind_(kind),
      ftrl_(std::make_unique<FtrlOptimizer>(std::move(proto), params)) {}

FtrlRanker::~FtrlRanker() = default;

void FtrlRanker::AddExpert(std::unique_ptr<Ranker> ranker) {
  rankers_.push_back(std::move(ranker));
}

void FtrlRanker::Start(const std::u16string& query,
                       CategoriesList& categories) {
  for (auto& ranker : rankers_)
    ranker->Start(query, categories);

  ftrl_->Clear();
}

void FtrlRanker::Train(const LaunchData& launch) {
  for (auto& ranker : rankers_)
    ranker->Train(launch);

  switch (kind_) {
    case FtrlRanker::RankingKind::kResults:
      ftrl_->Train(launch.id);
      break;
    case FtrlRanker::RankingKind::kCategories:
      ftrl_->Train(CategoryToString(launch.category));
      break;
  }
}

void FtrlRanker::UpdateResultRanks(ResultsMap& results, ProviderType provider) {
  if (kind_ != FtrlRanker::RankingKind::kResults)
    return;

  const auto it = results.find(provider);
  DCHECK(it != results.end());
  const auto& new_results = it->second;

  // Create a vector of result ids.
  std::vector<std::string> ids;
  for (const auto& result : new_results)
    ids.push_back(result->id());

  // Create a vector of vectors for each expert's scores.
  std::vector<std::vector<double>> expert_scores;
  for (auto& ranker : rankers_)
    expert_scores.push_back(ranker->GetResultRanks(results, provider));

  // Get the final scores from the FTRL optimizer set them on the results.
  std::vector<double> result_scores =
      ftrl_->Score(std::move(ids), std::move(expert_scores));
  DCHECK_EQ(new_results.size(), result_scores.size());
  for (size_t i = 0; i < new_results.size(); ++i)
    new_results[i]->scoring().set_ftrl_result_score(result_scores[i]);
}

void FtrlRanker::UpdateCategoryRanks(const ResultsMap& results,
                                     CategoriesList& categories,
                                     ProviderType provider) {
  if (kind_ != FtrlRanker::RankingKind::kCategories)
    return;

  // Create a vector of category strings.
  std::vector<std::string> category_strings;
  for (const auto& category : categories)
    category_strings.push_back(CategoryToString(category.category));

  // Create a vector of vectors for each expert's scores.
  std::vector<std::vector<double>> expert_scores;
  for (auto& ranker : rankers_) {
    expert_scores.push_back(
        ranker->GetCategoryRanks(results, categories, provider));
  }

  // Get the final scores from the FTRL optimizer set them on the categories.
  std::vector<double> category_scores =
      ftrl_->Score(std::move(category_strings), std::move(expert_scores));
  DCHECK_EQ(categories.size(), category_scores.size());
  for (size_t i = 0; i < categories.size(); ++i)
    categories[i].score = category_scores[i];
}

// ResultScoringShim ----------------------------------------------------------

ResultScoringShim::ResultScoringShim(ResultScoringShim::ScoringMember member)
    : member_(member) {}

std::vector<double> ResultScoringShim::GetResultRanks(const ResultsMap& results,
                                                      ProviderType provider) {
  const auto it = results.find(provider);
  if (it == results.end())
    return {};

  std::vector<double> scores;
  for (const auto& result : it->second) {
    double score;
    switch (member_) {
      case ResultScoringShim::ScoringMember::kNormalizedRelevance:
        score = result->scoring().normalized_relevance();
        break;
      case ResultScoringShim::ScoringMember::kMrfuResultScore:
        score = result->scoring().mrfu_result_score();
        break;
    }
    scores.push_back(score);
  }
  return scores;
}

// BestResultCategoryRanker ----------------------------------------------------

BestResultCategoryRanker::BestResultCategoryRanker() = default;
BestResultCategoryRanker::~BestResultCategoryRanker() = default;

void BestResultCategoryRanker::Start(const std::u16string& query,
                                     CategoriesList& categories) {
  current_category_scores_.clear();
  for (const auto& category : categories)
    current_category_scores_[category.category] = 0.0;
}

std::vector<double> BestResultCategoryRanker::GetCategoryRanks(
    const ResultsMap& results,
    const CategoriesList& categories,
    ProviderType provider) {
  const auto it = results.find(provider);
  DCHECK(it != results.end());

  for (const auto& result : it->second) {
    // Ignore best match, answer card results, and filtered results for the
    // purposes of deciding category scores, because they are not displayed in
    // their category.
    if (result->best_match() || result->scoring().filtered() ||
        result->display_type() ==
            ChromeSearchResult::DisplayType::kAnswerCard) {
      continue;
    }
    current_category_scores_[result->category()] =
        std::max(result->scoring().normalized_relevance(),
                 current_category_scores_[result->category()]);
  }

  // Collect those scores into a vector with the same ordering as |categories|.
  // For categories with no results in them, defaults to 0.0.
  std::vector<double> result;
  for (const auto& category : categories)
    result.push_back(current_category_scores_[category.category]);
  return result;
}

}  // namespace app_list