// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/power/ml/smart_dim/ml_agent.h"

#include <cmath>
#include <cstddef>
#include <memory>

#include "ash/constants/ash_features.h"
#include "base/containers/flat_map.h"
#include "base/metrics/field_trial_params.h"
#include "base/no_destructor.h"
#include "chrome/browser/ash/power/ml/smart_dim/metrics.h"
#include "chrome/browser/ash/power/ml/smart_dim/ml_agent_util.h"
#include "chrome/browser/ash/power/ml/user_activity_ukm_logger_helpers.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/assist_ranker/proto/example_preprocessor.pb.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
namespace ash {
namespace power {
namespace ml {
namespace {
using chromeos::machine_learning::mojom::ExecuteResult;
using chromeos::machine_learning::mojom::FloatList;
using chromeos::machine_learning::mojom::Int64List;
using chromeos::machine_learning::mojom::Tensor;
using chromeos::machine_learning::mojom::TensorPtr;
using chromeos::machine_learning::mojom::ValueList;
// Maps a raw model |score| onto an integer percentage in [0, 100]: the score
// is squashed with the logistic sigmoid, scaled by 100 and rounded down.
int ScoreToProbability(float score) {
  // Use the std:: overloads explicitly so the computation stays in single
  // precision no matter which C math declarations are visible at global
  // scope.
  const float sigmoid = 1.0f / (1.0f + std::exp(-score));
  // The float -> int narrowing is intentional; spell it out.
  return static_cast<int>(std::floor(sigmoid * 100.0f));
}
// Callback for completed ML Service calls to Execute() on a model's
// GraphExecutor.
void ExecuteCallback(const double threshold,
DimDecisionCallback decision_callback,
ExecuteResult result,
std::optional<std::vector<TensorPtr>> outputs) {
UserActivityEvent::ModelPrediction prediction;
if (result != ExecuteResult::OK) {
DVLOG(1) << "Smart Dim inference execution failed.";
prediction.set_response(UserActivityEvent::ModelPrediction::MODEL_ERROR);
LogPowerMLSmartDimModelResult(SmartDimModelResult::kOtherError);
} else {
float inactivity_score =
(outputs.value())[0]->data->get_float_list()->value[0];
prediction.set_decision_threshold(ScoreToProbability(threshold));
prediction.set_inactivity_score(ScoreToProbability(inactivity_score));
prediction.set_response(inactivity_score >= threshold
? UserActivityEvent::ModelPrediction::DIM
: UserActivityEvent::ModelPrediction::NO_DIM);
LogPowerMLSmartDimModelResult(SmartDimModelResult::kSuccess);
}
std::move(decision_callback).Run(prediction);
}
// Fills |example| from |features|. Returns false only when a feature that is
// present in |features| and must be bucketized has no bucketized value;
// otherwise returns true.
bool PopulateRankerExample(const UserActivityEvent::Features& features,
                           assist_ranker::RankerExample* example) {
  CHECK(example);

  // Training examples were built from the bucketized values that get logged
  // to UKM, so the same bucketization has to be applied here.
  //
  // A feature that is absent from |features| is simply skipped. A feature
  // that is present, however, must have a bucketized counterpart in
  // |bucketized|; if it does not, population fails.
  const std::map<std::string, int> bucketized =
      UserActivityUkmLoggerBucketizer::BucketizeUserActivityEventFeatures(
          features);
  auto& out = *example->mutable_features();

  // Copies the bucketized value stored under |name| into the example. Returns
  // false, without touching the example, when there is no such value.
  const auto copy_bucket = [&bucketized, &out](const auto& name) {
    const auto entry = bucketized.find(name);
    if (entry == bucketized.end())
      return false;
    out[name].set_int32_value(entry->second);
    return true;
  };

  if (features.has_battery_percent() && !copy_bucket(kBatteryPercent))
    return false;

  if (features.has_device_management())
    out["DeviceManagement"].set_int32_value(features.device_management());

  if (features.has_device_mode())
    out["DeviceMode"].set_int32_value(features.device_mode());

  if (features.has_device_type())
    out["DeviceType"].set_int32_value(features.device_type());

  if (features.has_key_events_in_last_hour() &&
      !copy_bucket(kKeyEventsInLastHour)) {
    return false;
  }

  if (features.has_last_activity_day())
    out["LastActivityDay"].set_int32_value(features.last_activity_day());

  if (features.has_last_activity_time_sec() &&
      !copy_bucket(kLastActivityTime)) {
    return false;
  }

  if (features.has_last_user_activity_time_sec() &&
      !copy_bucket(kLastUserActivityTime)) {
    return false;
  }

  if (features.has_mouse_events_in_last_hour() &&
      !copy_bucket(kMouseEventsInLastHour)) {
    return false;
  }

  if (features.has_on_battery()) {
    // The model consumes this flag as an int.
    out["OnBattery"].set_int32_value(features.on_battery());
  }

  // These three are written unconditionally.
  out["PreviousNegativeActionsCount"].set_int32_value(
      features.previous_negative_actions_count());
  out["PreviousPositiveActionsCount"].set_int32_value(
      features.previous_positive_actions_count());
  out["RecentTimeActive"].set_int32_value(features.recent_time_active_sec());

  if (features.has_video_playing_time_sec() &&
      !copy_bucket(kRecentVideoPlayingTime)) {
    return false;
  }

  if (features.has_on_to_dim_sec())
    out["ScreenDimDelay"].set_int32_value(features.on_to_dim_sec());

  if (features.has_dim_to_screen_off_sec()) {
    out["ScreenDimToOffDelay"].set_int32_value(
        features.dim_to_screen_off_sec());
  }

  if (features.has_time_since_last_key_sec()) {
    out["TimeSinceLastKey"].set_int32_value(
        features.time_since_last_key_sec());
  }

  if (features.has_time_since_last_mouse_sec()) {
    out["TimeSinceLastMouse"].set_int32_value(
        features.time_since_last_mouse_sec());
  }

  if (features.has_time_since_video_ended_sec() &&
      !copy_bucket(kTimeSinceLastVideoEnded)) {
    return false;
  }

  if (features.has_engagement_score())
    out["SiteEngagementScore"].set_int32_value(features.engagement_score());

  if (features.has_has_form_entry())
    out["HasFormEntry"].set_bool_value(features.has_form_entry());

  // "HasTabs" mirrors whether a tab domain was supplied.
  const bool has_tab_domain = features.has_tab_domain();
  if (has_tab_domain)
    out["TabDomain"].set_string_value(features.tab_domain());
  out["HasTabs"].set_bool_value(has_tab_domain);

  return true;
}
// Vectorizes |features| into |vectorized_features| using the preprocessor
// described by |preprocessor_config|.
//
// Returns:
//   kSuccess                 - |vectorized_features| holds the feature vector.
//   kOtherError              - |features| could not be turned into a
//                              RankerExample.
//   kPreprocessorOtherError  - the preprocessor reported an error other than
//                              kNoFeatureIndexFound, or produced no
//                              vectorized feature.
// |vectorized_features| is only written on kSuccess and must not be null.
SmartDimModelResult PreprocessInput(
    const assist_ranker::ExamplePreprocessorConfig& preprocessor_config,
    const UserActivityEvent::Features& features,
    std::vector<float>* vectorized_features) {
  DCHECK(vectorized_features);

  assist_ranker::RankerExample ranker_example;
  if (!PopulateRankerExample(features, &ranker_example)) {
    return SmartDimModelResult::kOtherError;
  }

  const int preprocessor_result = assist_ranker::ExamplePreprocessor::Process(
      preprocessor_config, &ranker_example, true);
  // kNoFeatureIndexFound can occur normally (e.g., when the domain name
  // isn't known to the model or a rarely seen enum value is used).
  if (preprocessor_result != assist_ranker::ExamplePreprocessor::kSuccess &&
      preprocessor_result !=
          assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
    return SmartDimModelResult::kPreprocessorOtherError;
  }

  // Look the vectorized feature up with find() instead of at(): at() aborts
  // on a missing key, whereas a missing vector should just be reported as a
  // preprocessing failure.
  const auto& processed_features = ranker_example.features();
  const auto vectorized = processed_features.find(
      assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName);
  if (vectorized == processed_features.end()) {
    DVLOG(1) << "Smart Dim preprocessor produced no vectorized feature.";
    return SmartDimModelResult::kPreprocessorOtherError;
  }

  const auto& extracted_features =
      vectorized->second.float_list().float_value();
  vectorized_features->assign(extracted_features.begin(),
                              extracted_features.end());
  return SmartDimModelResult::kSuccess;
}
} // namespace
// Construction and destruction are compiler-generated; no additional setup or
// teardown is done here.
SmartDimMlAgent::SmartDimMlAgent() = default;
SmartDimMlAgent::~SmartDimMlAgent() = default;
// Returns the shared SmartDimMlAgent. The object is a function-local static
// wrapped in base::NoDestructor, so it is created the first time this is
// called and the same pointer is returned on every call.
SmartDimMlAgent* SmartDimMlAgent::GetInstance() {
  static base::NoDestructor<SmartDimMlAgent> instance;
  return instance.get();
}
// Returns whether |download_worker_| reports itself as ready.
bool SmartDimMlAgent::IsDownloadWorkerReady() {
return download_worker_.IsReady();
}
// Hands the component file |contents| to |download_worker_| so it can
// initialize itself from them.
void SmartDimMlAgent::OnComponentReady(const ComponentFileContents& contents) {
  // |contents| is a const reference, so std::move() could never turn this
  // into a move; pass it through as-is.
  download_worker_.InitializeFromComponent(contents);
}
// Asks the Smart Dim model for a decision on |features| and delivers the
// resulting prediction through |callback|.
//
// |callback| replaces whatever was previously held by
// |dim_decision_callback_|. If preprocessing fails, or the feature vector
// does not have the size the worker expects, the callback is run right away
// with a MODEL_ERROR prediction. Otherwise the vector is sent to the worker's
// executor and ExecuteCallback() produces the prediction once it finishes.
void SmartDimMlAgent::RequestDimDecision(
    const UserActivityEvent::Features& features,
    DimDecisionCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  dim_decision_callback_.Reset(std::move(callback));

  SmartDimWorker* const active_worker = GetWorker();

  // Prediction sent back on every early-exit path below.
  UserActivityEvent::ModelPrediction error_prediction;
  error_prediction.set_response(
      UserActivityEvent::ModelPrediction::MODEL_ERROR);

  // Records |reason| and answers the pending request with |error_prediction|.
  const auto fail_with = [this,
                          &error_prediction](SmartDimModelResult reason) {
    LogPowerMLSmartDimModelResult(reason);
    dim_decision_callback_.callback().Run(error_prediction);
  };

  DCHECK(active_worker->GetPreprocessorConfig());
  std::vector<float> model_input;
  const SmartDimModelResult preprocess_status = PreprocessInput(
      *active_worker->GetPreprocessorConfig(), features, &model_input);
  if (preprocess_status != SmartDimModelResult::kSuccess) {
    fail_with(preprocess_status);
    return;
  }

  if (model_input.size() != active_worker->expected_feature_size()) {
    DVLOG(1) << "Smart Dim vectorized features not of correct size.";
    fail_with(SmartDimModelResult::kMismatchedFeatureSizeError);
    return;
  }

  DCHECK(active_worker->GetExecutor());

  // Build the single input tensor: shape {1, N}, with the feature values
  // widened to double for the FloatList.
  auto input_tensor = Tensor::New();
  input_tensor->shape = Int64List::New();
  input_tensor->shape->value = {1, static_cast<int64_t>(model_input.size())};
  input_tensor->data = ValueList::NewFloatList(FloatList::New(
      std::vector<double>(model_input.begin(), model_input.end())));

  base::flat_map<std::string, TensorPtr> model_inputs;
  model_inputs.emplace(std::string(kSmartDimInputNodeName),
                       std::move(input_tensor));
  std::vector<std::string> output_names{std::string(kSmartDimOutputNodeName)};

  // The threshold comes from the finch experiment parameter, with the
  // worker's own threshold as the fallback. Which of the two ended up being
  // used is logged to UMA.
  const double dim_threshold = base::GetFieldTrialParamByFeatureAsDouble(
      features::kUserActivityPrediction, "dim_threshold",
      active_worker->dim_threshold());
  const bool matches_worker_default =
      std::abs(dim_threshold - active_worker->dim_threshold()) < 1e-10;
  LogPowerMLSmartDimParameterResult(
      matches_worker_default ? SmartDimParameterResult::kUseDefaultValue
                             : SmartDimParameterResult::kSuccess);

  active_worker->GetExecutor()->Execute(
      std::move(model_inputs), std::move(output_names),
      base::BindOnce(&ExecuteCallback, dim_threshold,
                     dim_decision_callback_.callback()));
}
// Cancels |dim_decision_callback_|, i.e. the callback installed by the most
// recent RequestDimDecision() call. Must be called on the sequence checked by
// |sequence_checker_|.
void SmartDimMlAgent::CancelPreviousRequest() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
dim_decision_callback_.Cancel();
}
// Test-only helper: resets both the built-in and the download worker.
void SmartDimMlAgent::ResetForTesting() {
builtin_worker_.Reset();
download_worker_.Reset();
}
// Picks the worker to serve a request and logs which kind was chosen. The
// download worker is preferred whenever it is ready; otherwise the built-in
// worker is returned.
SmartDimWorker* SmartDimMlAgent::GetWorker() {
  if (!download_worker_.IsReady()) {
    LogWorkerType(WorkerType::kBuiltinWorker);
    return &builtin_worker_;
  }

  // With the download worker ready, the built-in worker is no longer useful,
  // so release it to save memory.
  builtin_worker_.Reset();
  LogWorkerType(WorkerType::kDownloadWorker);
  return &download_worker_;
}
} // namespace ml
} // namespace power
} // namespace ash