chromium/chrome/browser/ash/app_list/search/util/keyword_cache.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/app_list/search/util/keyword_cache.h"

#include <algorithm>

#include "chrome/browser/ash/app_list/search/chrome_search_result.h"

namespace app_list {

KeywordCache::KeywordCache() = default;
KeywordCache::~KeywordCache() = default;

void KeywordCache::Clear() {
  last_item_scores_.clear();
  last_matched_items_.clear();
}

void KeywordCache::ScoreAndRecordMatch(ResultsMap& results,
                                       ProviderType provider,
                                       double keyword_multiplier) {
  const auto it = results.find(provider);

  if (it == results.end()) {
    return;
  }

  for (auto& result : it->second) {
    result->scoring().set_keyword_multiplier(keyword_multiplier);
    last_item_scores_[result->id()] =
        result->scoring().ftrl_result_score() * keyword_multiplier;
    last_matched_items_.insert(result->id());
  }
}

void KeywordCache::RecordNonMatch(ResultsMap& results, ProviderType provider) {
  const auto it = results.find(provider);

  if (it == results.end()) {
    return;
  }

  for (auto& result : it->second) {
    last_item_scores_[result->id()] = result->scoring().ftrl_result_score();
  }
}

void KeywordCache::Train(const std::string& item) {
  if (last_item_scores_.empty()) {
    return;
  }
  boost_factor_ += CalculateBoostChange(item);
}

// If the change is positive, boost_factor_ increase as (0.5 - original
// boost_factor) * (num_item_above) / rank. If the change is negative,
// boost_factor_ decreases as (original boost_factor) * (num_item_above) / rank.
double KeywordCache::CalculateBoostChange(const std::string& item) const {
  bool is_boosted = IsBoosted(item);

  auto it = last_item_scores_.find(item);
  if (it == last_item_scores_.end()) {
    return 0;
  }
  double selected_item_score = it->second;

  int rank = 0;
  int num_item_above = 0;

  for (const auto& [cached_item, cached_item_score] : last_item_scores_) {
    if (cached_item_score > selected_item_score) {
      rank++;
      if (IsBoosted(cached_item) && !is_boosted) {
        num_item_above--;
      }
      if (!IsBoosted(cached_item) && is_boosted) {
        num_item_above++;
      }
    }
  }
  if (rank == 0) {
    return 0;
  }
  if (num_item_above >= 0) {
    return (kMaxBoostFactor - boost_factor_) *
           static_cast<double>(num_item_above) / static_cast<double>(rank);
  }
  return boost_factor_ * static_cast<double>(num_item_above) /
         static_cast<double>(rank);
}

bool KeywordCache::IsBoosted(const std::string& item) const {
  return find(last_matched_items_.begin(), last_matched_items_.end(), item) !=
         last_matched_items_.end();
}

}  // namespace app_list