// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/app_list/search/util/ftrl_optimizer.h"
#include <cmath>
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/numerics/safe_conversions.h"
namespace app_list {
namespace {
// A version number to be incremented each time a code change invalidates the
// state stored on-disk in |proto_|.
constexpr int kVersion = 1;
double Total(const google::protobuf::RepeatedField<double>& values) {
double total = 0.0;
for (double v : values)
total += v;
return total;
}
// Normalize |values| to sum to 1 in-place.
void Normalize(google::protobuf::RepeatedField<double>& values) {
double total = Total(values);
if (total == 0.0)
return;
for (int i = 0; i < values.size(); ++i)
values[i] = values[i] / total;
}
} // namespace
FtrlOptimizer::FtrlOptimizer(FtrlOptimizer::Proto proto, const Params& params)
: params_(params), proto_(std::move(proto)) {
DCHECK_GT(params.alpha, 0.0);
DCHECK_GE(params.gamma, 0.0);
DCHECK_LE(params.gamma, 1.0);
DCHECK_GT(params.num_experts, 0u);
// `proto_` is a class member so it is safe to call `RegisterOnInitUnsafe()`.
proto_.RegisterOnInitUnsafe(
base::BindOnce(&FtrlOptimizer::OnProtoInit, base::Unretained(this)));
proto_.Init();
}
FtrlOptimizer::~FtrlOptimizer() {}
void FtrlOptimizer::Clear() {
last_expert_scores_.clear();
}
std::vector<double> FtrlOptimizer::Score(
std::vector<std::string>&& items,
std::vector<std::vector<double>>&& expert_scores) {
size_t num_items = items.size();
size_t num_experts = params_.num_experts;
std::vector<double> result(num_items, 0.0);
if (!proto_.initialized())
return result;
const auto& weights = proto_->weights();
DCHECK_EQ(expert_scores.size(), num_experts);
DCHECK_GE(weights.size(), 0);
DCHECK_EQ(static_cast<size_t>(weights.size()), num_experts);
for (size_t i = 0; i < num_items; ++i) {
last_expert_scores_[items[i]] = {};
for (size_t j = 0; j < num_experts; ++j) {
result[i] += weights[j] * expert_scores[j][i];
last_expert_scores_[items[i]].emplace_back(expert_scores[j][i]);
}
}
return result;
}
void FtrlOptimizer::Train(const std::string& item) {
// If |last_items_| is empty, experts had no chance at prediction and we
// should early exit. This could happen if |proto_| finishes initializing
// after Score but before Train.
if (!proto_.initialized() || last_expert_scores_.empty()) {
return;
}
// Compute the loss of each expert and update weights.
auto& weights = *proto_->mutable_weights();
for (int i = 0; i < weights.size(); ++i) {
double loss = Loss(i, item);
double fixed_share = params_.gamma / weights.size();
double weight_factor = (1.0 - params_.gamma) * exp(-params_.alpha * loss);
weights[i] = fixed_share + weight_factor * weights[i];
}
// Re-normalize the weights.
Normalize(weights);
DCHECK_LE(std::abs(Total(proto_->weights()) - 1.0), 1.0e-5);
proto_.StartWrite();
}
double FtrlOptimizer::Loss(size_t expert, const std::string& item) {
size_t num_experts = params_.num_experts;
size_t num_items = last_expert_scores_.size();
DCHECK_GT(num_items, 0u);
DCHECK_LT(expert, num_experts);
// Find the score of the launched item.
double score = {0.0};
if (last_expert_scores_.find(item) != last_expert_scores_.end()) {
DCHECK_EQ(last_expert_scores_[item].size(), num_experts);
score = last_expert_scores_[item][expert];
}
// Find the rank of the item, ie. the number of items with higher score.
size_t rank = 0;
for (const auto& scores : last_expert_scores_) {
if (scores.second[expert] > score) {
++rank;
}
}
// The loss is linear in the |rank|. A loss of 1.0 means |item| wasn't
// included at all.
DCHECK(!last_expert_scores_.empty());
return static_cast<double>(rank) / last_expert_scores_.size();
}
void FtrlOptimizer::OnProtoInit() {
if (!proto_->has_version() || proto_->version() != kVersion ||
params_.num_experts !=
base::checked_cast<size_t>(proto_->weights_size())) {
proto_.Purge();
proto_->set_version(kVersion);
for (size_t i = 0; i < params_.num_experts; ++i)
proto_->add_weights(1.0 / params_.num_experts);
}
DCHECK_LE(std::abs(Total(proto_->weights()) - 1.0), 1.0e-5);
}
} // namespace app_list