chromium/chrome/browser/ash/app_list/search/ranking/ftrl_ranker.h

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_ASH_APP_LIST_SEARCH_RANKING_FTRL_RANKER_H_
#define CHROME_BROWSER_ASH_APP_LIST_SEARCH_RANKING_FTRL_RANKER_H_

#include <memory>
#include <string>
#include <vector>

#include "base/containers/flat_map.h"
#include "chrome/browser/ash/app_list/search/ranking/ranker.h"
#include "chrome/browser/ash/app_list/search/types.h"
#include "chrome/browser/ash/app_list/search/util/ftrl_optimizer.h"
#include "chrome/browser/ash/app_list/search/util/mrfu_cache.h"

namespace app_list {

// Combines several 'expert' sub-rankers, registered through AddExpert(), using
// a Follow the Regularized Leader (FTRL) algorithm that learns a weighting over
// those experts and chooses amongst them.
class FtrlRanker : public Ranker {
 public:
  // The kind of item this ranker orders: individual results or categories.
  enum class RankingKind {
    kResults,
    kCategories,
  };

  FtrlRanker(const FtrlRanker&) = delete;
  FtrlRanker& operator=(const FtrlRanker&) = delete;

  // NOTE(review): `params` and `proto` are presumably used to build the owned
  // FtrlOptimizer (`ftrl_`); confirm against the .cc.
  FtrlRanker(RankingKind kind,
             FtrlOptimizer::Params params,
             FtrlOptimizer::Proto proto);
  ~FtrlRanker() override;

  // Takes ownership of `ranker` and adds it to this ranker's experts.
  void AddExpert(std::unique_ptr<Ranker> ranker);

  // Ranker:
  void Start(const std::u16string& query, CategoriesList& categories) override;
  void Train(const LaunchData& launch) override;
  void UpdateResultRanks(ResultsMap& results, ProviderType provider) override;
  void UpdateCategoryRanks(const ResultsMap& results,
                           CategoriesList& categories,
                           ProviderType provider) override;

 private:
  // Whether this instance ranks results or categories.
  RankingKind kind_;

  // FTRL optimizer that chooses amongst the experts held in `rankers_`.
  std::unique_ptr<FtrlOptimizer> ftrl_;

  // The owned 'experts' of the follow-the-regularized-leader model.
  std::vector<std::unique_ptr<Ranker>> rankers_;
};

// The following classes are 'experts', i.e. sub-rankers intended to be used
// within an FtrlRanker.

// Expert that reports, for every result, one field of that result's `Scoring`
// struct; which field is selected by the `ScoringMember` passed at construction.
class ResultScoringShim : public Ranker {
 public:
  // Each enumerator names a member of a search result's `Scoring`.
  enum class ScoringMember {
    kNormalizedRelevance,
    kMrfuResultScore,
  };

  // `member` selects the `Scoring` field that GetResultRanks() exposes.
  explicit ResultScoringShim(ScoringMember member);

  // Ranker:
  std::vector<double> GetResultRanks(const ResultsMap& results,
                                     ProviderType provider) override;

 private:
  // The `Scoring` field this shim exposes.
  ScoringMember member_;
};

// Expert that ranks a category based on the normalized relevance of its best
// result.
class BestResultCategoryRanker : public Ranker {
 public:
  BestResultCategoryRanker();
  ~BestResultCategoryRanker() override;

  // Not copyable, consistent with FtrlRanker: instances hold mutable state in
  // `current_category_scores_` and, as experts, are owned through
  // std::unique_ptr<Ranker>, so a copy would only slice or duplicate state.
  BestResultCategoryRanker(const BestResultCategoryRanker&) = delete;
  BestResultCategoryRanker& operator=(const BestResultCategoryRanker&) = delete;

  // Ranker:
  void Start(const std::u16string& query, CategoriesList& categories) override;
  std::vector<double> GetCategoryRanks(const ResultsMap& results,
                                       const CategoriesList& categories,
                                       ProviderType provider) override;

 private:
  // Per-category scores for the current query.
  // NOTE(review): presumably reset in Start(); confirm against the .cc.
  base::flat_map<Category, double> current_category_scores_;
};

}  // namespace app_list

#endif  // CHROME_BROWSER_ASH_APP_LIST_SEARCH_RANKING_FTRL_RANKER_H_