// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/app_list/search/util/mrfu_cache.h"
#include <cmath>
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/time/time.h"
namespace app_list {
namespace {

// Version of the on-disk proto format. Protos with a different (or missing)
// version are purged on load.
constexpr int kVersion = 2;

// We boost scores with the equation
//
//   score = score + (1 - score) * k
//
// where k is a boost coefficient. This is hard to reason about, so instead we
// set a 'boost factor', which is the answer to the question: "how many
// consecutive uses should it take for a score to start at 0 and reach 0.8?".
// We can then define k based on the boost factor. Note 0.8 is chosen
// arbitrarily, but it's a reasonably high score.
//
// Terminology:
//  - x is the score
//  - f(x) is one decay followed by one boost
//  - f_n(x) is f applied n times to x
//  - k is the boost coefficient
//  - D is the decay coefficient
//  - N is the boost factor
//
// We want an equation for the score after using a new item N times, i.e.
// f_N(0), and then solve for k.
//
//   f(x) = Dx + (1-Dx)k
//        = k + (1-k)Dx
//
//   f_n(x) = (1-k)^n D^n x + k sum((1-k)^i D^i, 0 <= i < n)
//
//   f_N(0) = k sum((1-k)^i D^i, 0 <= i < N)
//          = k (1 - D^N (1-k)^N) / (1 + D(k-1))      (geometric series)
//
// Therefore we're looking for the value of k that satisfies
//
//   k (1 - D^N (1-k)^N) / (1 + D(k-1)) = 0.8
//
// which has no convenient closed form. The left-hand side grows with k, so
// this function finds k by bisection over [0, 1] with a fixed number of
// halvings and returns the midpoint of the final interval.
float ApproximateBoostCoefficient(float decay_coefficient, float boost_factor) {
  constexpr float kTarget = 0.8f;
  constexpr int kIterations = 10;

  const float decay = decay_coefficient;
  const float uses = boost_factor;
  // D^N does not depend on k, so compute it once outside the loop.
  const auto decay_pow = pow(decay, uses);

  float lo = 0.0f;
  float hi = 1.0f;
  for (int step = 0; step < kIterations; ++step) {
    const float mid = (lo + hi) / 2.0f;
    const float reached =
        mid * (1 - decay_pow * pow(1 - mid, uses)) / (1 + decay * mid - decay);
    if (reached < kTarget) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return (lo + hi) / 2.0f;
}

}  // namespace
MrfuCache::MrfuCache(MrfuCache::Proto proto, const Params& params)
    : proto_(std::move(proto)) {
  // Derive the scoring constants from `params`. The decay coefficient D is
  // chosen so that D^half_life == 0.5: a score that is not used halves after
  // `half_life` updates. The boost coefficient is then solved for so that
  // `boost_factor` consecutive uses take a new item to a score of about 0.8.
  // See the header comment for more detail.
  decay_coeff_ = exp(log(0.5f) / params.half_life);
  boost_coeff_ = ApproximateBoostCoefficient(decay_coeff_, params.boost_factor);
  max_items_ = params.max_items;
  min_score_ = params.min_score;

  // `proto_` is a class member so it is safe to call `RegisterOnInitUnsafe()`.
  proto_.RegisterOnInitUnsafe(
      base::BindOnce(&MrfuCache::OnProtoInit, base::Unretained(this)));
  proto_.Init();
}
MrfuCache::~MrfuCache() = default;
void MrfuCache::Sort(Items& items) {
  // Orders `items` from highest to lowest score. std::sort is not stable, so
  // the relative order of equally-scored items is unspecified.
  const auto by_descending_score = [](const auto& lhs, const auto& rhs) {
    return lhs.second > rhs.second;
  };
  std::sort(items.begin(), items.end(), by_descending_score);
}
void MrfuCache::Use(const std::string& item) {
  // Records one use of `item`: decays its score to the current time and then
  // boosts it. No-op until the proto has finished initializing.
  if (!proto_.initialized())
    return;

  // Look up the score for `item`, default-constructing an empty Score the
  // first time the item is seen.
  Score& score = (*proto_->mutable_items())[item];

  // Order matters: first move 'time' forward one step, then decay the score
  // up to that time, then add the boost for the current use.
  proto_->set_update_count(proto_->update_count() + 1);
  Decay(&score);

  const float boost = boost_coeff_ * (1.0f - score.score());
  score.set_score(score.score() + boost);
  proto_->set_total_score(proto_->total_score() + boost);

  // Cleanup may rebuild the item map, so `score` must not be used after this.
  MaybeCleanup();
  proto_.QueueWrite();
}
float MrfuCache::Get(const std::string& item) {
  // Returns the current (decayed) score of `item`, or 0 if the proto is not
  // yet initialized or the item is unknown. The stored score may be stale, so
  // it is brought up to date (and persisted) as a side effect.
  if (proto_.initialized()) {
    auto* items = proto_->mutable_items();
    if (auto entry = items->find(item); entry != items->end()) {
      Decay(&entry->second);
      return entry->second.score();
    }
  }
  return 0.0f;
}
float MrfuCache::GetNormalized(const std::string& item) {
  // Returns the score of `item` divided by the sum of all scores, or 0 if the
  // proto is uninitialized or the total is zero.
  if (!proto_.initialized() || proto_->total_score() == 0.0f)
    return 0.0f;

  // Get() may decay the item's score and adjust total_score() as a side
  // effect. The operands of `/` are evaluated in an unspecified order, so call
  // Get() in its own statement to guarantee the denominator reflects that
  // update.
  const float score = Get(item);
  const float total = proto_->total_score();
  // Decay can drive the total to zero; avoid dividing by it.
  if (total == 0.0f)
    return 0.0f;
  return score / total;
}
MrfuCache::Items MrfuCache::GetAll() {
  // Returns every (item, score) pair with each score decayed to the current
  // time, in map iteration order. Empty if the proto is not yet initialized.
  MrfuCache::Items results;
  if (!proto_.initialized())
    return results;

  for (auto& entry : *proto_->mutable_items()) {
    Decay(&entry.second);
    results.emplace_back(entry.first, entry.second.score());
  }
  return results;
}
MrfuCache::Items MrfuCache::GetAllNormalized() {
  // Returns every (item, score / total_score) pair. Empty if the proto is
  // uninitialized or the total score is zero.
  if (!proto_.initialized() || proto_->total_score() == 0.0f)
    return {};

  // GetAll() decays every score and adjusts total_score() accordingly, so the
  // total must be read afterwards. Re-check it, because the decay can bring it
  // down to zero.
  auto results = GetAll();
  const float total = proto_->total_score();
  if (total == 0.0f)
    return {};

  for (auto& pair : results)
    pair.second /= total;
  return results;
}
void MrfuCache::Delete(const std::string& item) {
  // Removes `item` from the cache and subtracts its score from the total.
  if (!proto_.initialized())
    return;

  // Get() decays the item's score and folds that decay into total_score().
  // It must run before the total is read: with `total_score() - Get(item)` the
  // evaluation order is unspecified. If the total were read first, the decay
  // adjustment made inside Get() would be overwritten, leaving
  // (1 - decay) * old_score stuck in the total forever.
  const float score = Get(item);
  proto_->set_total_score(proto_->total_score() - score);
  proto_->mutable_items()->erase(item);
  proto_.QueueWrite();
}
void MrfuCache::ResetWithItems(const Items& items) {
DCHECK(proto_.initialized());
proto_->Clear();
proto_->set_update_count(0);
float total_score = 0.0f;
auto* proto_items = proto_->mutable_items();
for (const auto& item_score : items) {
Score score;
score.set_score(item_score.second);
score.set_last_update_count(0);
proto_items->insert({item_score.first, score});
total_score += item_score.second;
}
proto_->set_total_score(total_score);
proto_.QueueWrite();
}
void MrfuCache::Decay(Score* score) {
  // Brings `score` up to the proto's current update count by multiplying it by
  // decay_coeff_ once per elapsed update. The total score is adjusted by the
  // same amount.
  const int64_t now = proto_->update_count();
  const int64_t elapsed = now - score->last_update_count();
  if (elapsed <= 0)
    return;  // Already current; nothing to do.

  const float factor = std::pow(decay_coeff_, elapsed);
  const float old_score = score->score();
  proto_->set_total_score(proto_->total_score() +
                          (factor - 1.0f) * old_score);
  score->set_score(old_score * factor);
  score->set_last_update_count(now);
  proto_.QueueWrite();
}
void MrfuCache::MaybeCleanup() {
if (base::checked_cast<size_t>(proto_->items_size()) < 2u * max_items_)
return;
// Ensure all scores are up to date, and then keep all those over the
// |min_score_| threshold.
std::vector<std::pair<std::string, Score>> kept_items;
for (auto& item_score : *proto_->mutable_items()) {
Score& score = item_score.second;
Decay(&score);
if (score.score() > min_score_)
kept_items.emplace_back(item_score.first, item_score.second);
}
// Sort them high-to-low by score.
std::sort(kept_items.begin(), kept_items.end(),
[](auto const& a, auto const& b) {
return a.second.score() > b.second.score();
});
// Clear the proto and reinsert at most |max_items_| items.
float new_total = 0.0f;
proto_->clear_items();
for (size_t i = 0; i < std::min(max_items_, kept_items.size()); ++i) {
proto_->mutable_items()->insert(
{kept_items[i].first, kept_items[i].second});
new_total += kept_items[i].second.score();
}
proto_->set_total_score(new_total);
proto_.QueueWrite();
}
void MrfuCache::OnProtoInit() {
  // Called once the proto has loaded. Discards any state that is unversioned
  // or was written with a different kVersion, then stamps the current version.
  const bool version_matches =
      proto_->has_version() && proto_->version() == kVersion;
  if (!version_matches)
    proto_.Purge();
  proto_->set_version(kVersion);
}
} // namespace app_list