// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/components/heatmap/heatmap_palm_detector_impl.h"

#include <algorithm>

#include "base/strings/strcat.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
namespace ash {
using ::chromeos::machine_learning::mojom::HeatmapPalmRejectionConfig;
using ::chromeos::machine_learning::mojom::HeatmapProcessedEventPtr;
using ::chromeos::machine_learning::mojom::LoadHeatmapPalmRejectionResult;
namespace {
// On-device directory that holds the TFLite palm-rejection model files.
constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";

// A touch record and a heatmap frame are only paired when their timestamps
// are strictly closer than this (see CanBeMatched()).
constexpr base::TimeDelta kTimestampDiffThreshold = base::Milliseconds(20);

// Reconnect backoff for the ML service connection: the first retry waits
// kReconnectInitialDelay; kReconnectMaxDelay is the intended upper limit.
constexpr base::TimeDelta kReconnectInitialDelay = base::Seconds(1);
constexpr base::TimeDelta kReconnectMaxDelay = base::Minutes(10);
// Static description of one heatmap palm-rejection model, copied into the
// HeatmapPalmRejectionConfig sent to the ML service.
//
// Members carry default initializers so a default-constructed instance is
// never indeterminate; the struct remains an aggregate, so designated
// initializers still work.
struct HeatmapModelMetadata {
  // Model file name; Start() prefixes it with kSystemModelDir.
  std::string model_file;
  // Forwarded as HeatmapPalmRejectionConfig::input_node.
  int input_node = 0;
  // Forwarded as HeatmapPalmRejectionConfig::output_node.
  int output_node = 0;
  // Forwarded as HeatmapPalmRejectionConfig::palm_threshold (presumably the
  // model score at which a frame counts as a palm — TODO confirm).
  double palm_threshold = 0.0;
};
using MetadataMap =
    std::map<HeatmapPalmDetectorImpl::ModelId, HeatmapModelMetadata>;

// Returns the metadata for every supported model, keyed by model ID. Model
// file names are relative to kSystemModelDir.
MetadataMap GetHeatmapModelMetadata() {
  MetadataMap metadata_by_model;
  metadata_by_model.emplace(
      HeatmapPalmDetectorImpl::ModelId::kRex,
      HeatmapModelMetadata{
          .model_file =
              "mlservice-model-poncho_palm_rejection-20230907-v0.tflite",
          .input_node = 0,
          .output_node = 23,
          .palm_threshold = 0.6,
      });
  metadata_by_model.emplace(
      HeatmapPalmDetectorImpl::ModelId::kGeralt,
      HeatmapModelMetadata{
          .model_file =
              "mlservice-model-poncho-palm_rejection_g-20240313-v0.tflite",
          .input_node = 0,
          .output_node = 21,
          .palm_threshold = 0.6,
      });
  return metadata_by_model;
}
// Returns whether two timestamps, in either order, are strictly closer than
// kTimestampDiffThreshold.
bool CanBeMatched(base::Time t1, base::Time t2) {
  const base::TimeDelta gap = (t1 - t2).magnitude();
  return gap < kTimestampDiffThreshold;
}
} // namespace
HeatmapPalmDetectorImpl::HeatmapPalmDetectorImpl()
: reconnect_delay_(kReconnectInitialDelay), client_(this) {}
HeatmapPalmDetectorImpl::~HeatmapPalmDetectorImpl() = default;
// Asks the ML service to load the palm-rejection model for `model_id` on the
// touchscreen at `hidraw_path`, optionally cropping the heatmap with
// `crop_heatmap`. The result arrives in OnLoadHeatmapPalmRejection().
//
// The arguments are stored so OnConnectionError() can restart with the same
// configuration. For an unknown `model_id` an error is logged and no load is
// attempted.
//
// Calling Start() again, or while a reconnect is pending, replaces the
// previous session. It does not re-bind the still-bound receiver or leave a
// stale restart scheduled.
void HeatmapPalmDetectorImpl::Start(ModelId model_id,
                                    std::string_view hidraw_path,
                                    std::optional<CropHeatmap> crop_heatmap) {
  crop_heatmap_ = crop_heatmap;
  model_id_ = model_id;
  hidraw_path_ = hidraw_path;

  const MetadataMap model_metadata = GetHeatmapModelMetadata();
  const auto metadata_lookup = model_metadata.find(model_id);
  if (metadata_lookup == model_metadata.end()) {
    LOG(ERROR) << "Invalid model ID: " << static_cast<int>(model_id);
    return;
  }
  const HeatmapModelMetadata& metadata = metadata_lookup->second;

  auto config = HeatmapPalmRejectionConfig::New();
  config->tf_model_path = base::StrCat({kSystemModelDir, metadata.model_file});
  config->input_node = metadata.input_node;
  config->output_node = metadata.output_node;
  config->palm_threshold = metadata.palm_threshold;
  config->heatmap_hidraw_device = hidraw_path;
  if (crop_heatmap) {
    // `config` was just created, so its crop field is always unset here.
    config->crop_heatmap =
        ::chromeos::machine_learning::mojom::CropHeatmap::New();
    config->crop_heatmap->bottom_crop = crop_heatmap->bottom_crop;
    config->crop_heatmap->left_crop = crop_heatmap->left_crop;
    config->crop_heatmap->right_crop = crop_heatmap->right_crop;
    config->crop_heatmap->top_crop = crop_heatmap->top_crop;
  }

  // An explicit Start() supersedes a restart scheduled by OnConnectionError().
  // Otherwise that restart would fire later with the old arguments. When this
  // call *is* that restart, the timer is no longer running and Stop() is a
  // no-op.
  delay_timer_.Stop();

  if (!ml_service_) {
    chromeos::machine_learning::ServiceConnection::GetInstance()
        ->BindMachineLearningService(ml_service_.BindNewPipeAndPassReceiver());
  }
  ml_service_.set_disconnect_handler(base::BindOnce(
      &HeatmapPalmDetectorImpl::OnConnectionError, weak_factory_.GetWeakPtr()));

  // The receiver must be unbound before BindNewPipeAndPassRemote(). It is
  // already unbound on the first start and after a connection error; this
  // only matters when Start() is called again while a session is active. The
  // detector is not ready until the new load reports success.
  client_.reset();
  is_ready_ = false;
  ml_service_->LoadHeatmapPalmRejection(
      std::move(config), client_.BindNewPipeAndPassRemote(),
      base::BindOnce(&HeatmapPalmDetectorImpl::OnLoadHeatmapPalmRejection,
                     weak_factory_.GetWeakPtr()));
}
// Disconnect handler for the ML service connection.
//
// Drops all per-connection state: the service remote, the client receiver,
// readiness, queued touch records and palm IDs. It then schedules Start() with
// the last-used arguments after `reconnect_delay_`. The delay doubles after
// each disconnect and is capped at kReconnectMaxDelay. A successful reply in
// OnLoadHeatmapPalmRejection() resets it.
void HeatmapPalmDetectorImpl::OnConnectionError() {
  ml_service_.reset();
  client_.reset();
  is_ready_ = false;
  std::queue<TouchRecord>().swap(touch_records_);
  palm_tracking_ids_.clear();
  delay_timer_.Start(FROM_HERE, reconnect_delay_,
                     base::BindOnce(&HeatmapPalmDetectorImpl::Start,
                                    weak_factory_.GetWeakPtr(), model_id_,
                                    hidraw_path_, crop_heatmap_));
  // Exponential backoff, clamped after doubling so the delay never exceeds
  // kReconnectMaxDelay. Merely skipping the doubling once the cap is reached
  // would still overshoot (512 s -> 1024 s).
  reconnect_delay_ = std::min(reconnect_delay_ * 2, kReconnectMaxDelay);
}
// Reply callback for LoadHeatmapPalmRejection().
//
// Any reply means the service connection is alive, so the reconnect backoff
// restarts from kReconnectInitialDelay.
// - On success, touch records and palm IDs gathered before the (re)load are
//   discarded and the detector becomes ready.
// - On failure, the result code is logged and the ready state is unchanged.
void HeatmapPalmDetectorImpl::OnLoadHeatmapPalmRejection(
    LoadHeatmapPalmRejectionResult result) {
  reconnect_delay_ = kReconnectInitialDelay;
  if (result != LoadHeatmapPalmRejectionResult::OK) {
    LOG(ERROR) << "Failed to load heatmap palm rejection model, result: "
               << static_cast<int>(result);
    return;
  }
  std::queue<TouchRecord>().swap(touch_records_);
  palm_tracking_ids_.clear();
  is_ready_ = true;
}
// Pairs a processed heatmap frame with the queued touch record closest to it
// in time. If the model classified the frame as a palm, every tracking ID in
// that record is flagged as a palm.
//
// Records are consumed oldest-first:
// - Records at or before the frame are popped, and the latest of them is the
//   candidate.
// - The first record after the frame replaces the candidate if it is strictly
//   closer.
// - A record newer than the frame that is too far away to match stays queued
//   for later frames.
// If the chosen record is not within kTimestampDiffThreshold of the frame,
// the event is ignored.
void HeatmapPalmDetectorImpl::OnHeatmapProcessedEvent(
    HeatmapProcessedEventPtr event) {
  if (touch_records_.empty()) {
    return;
  }
  const base::Time frame_time = event->timestamp;
  TouchRecord best_match = touch_records_.front();
  if (best_match.timestamp > frame_time) {
    // Every queued record is newer than the frame, so only the oldest one can
    // be the closest. Leave it queued if it is too far away.
    if (!CanBeMatched(best_match.timestamp, frame_time)) {
      return;
    }
    touch_records_.pop();
  } else {
    // Consume every record at or before the frame, keeping the latest. `<=`
    // ensures a record whose timestamp equals the frame's is consumed; with
    // `<` it would stay queued and could be matched again by a later frame.
    while (!touch_records_.empty() &&
           touch_records_.front().timestamp <= frame_time) {
      best_match = std::move(touch_records_.front());
      touch_records_.pop();
    }
    // The next record, if any, is after the frame. Prefer it when it is
    // strictly closer and still within the matching threshold.
    if (!touch_records_.empty() &&
        touch_records_.front().timestamp - frame_time <
            frame_time - best_match.timestamp &&
        CanBeMatched(touch_records_.front().timestamp, frame_time)) {
      best_match = std::move(touch_records_.front());
      touch_records_.pop();
    }
    if (!CanBeMatched(best_match.timestamp, frame_time)) {
      // Cannot find a matching record.
      return;
    }
  }
  if (event->is_palm) {
    for (int id : best_match.tracking_ids) {
      palm_tracking_ids_.insert(id);
    }
  }
}
// A tracking ID stays flagged from the heatmap frame that classified it as a
// palm until RemoveTouch(), a successful (re)load, or a connection error
// clears it.
bool HeatmapPalmDetectorImpl::IsPalm(int tracking_id) const {
  return palm_tracking_ids_.count(tracking_id) != 0;
}

// True after the ML service reported a successful model load. Reset when the
// service connection is lost.
bool HeatmapPalmDetectorImpl::IsReady() const {
  return is_ready_;
}
// Queues the tracking IDs seen at `timestamp` so a later heatmap result can be
// matched against them (see OnHeatmapProcessedEvent()).
//
// Records are skipped while the detector is not ready. Touch records are only
// drained by heatmap events. Anything queued before readiness is discarded
// when the load succeeds or the connection drops. Queuing it anyway would
// only let the queue grow without bound if the model never becomes ready
// (e.g. after a failed load).
void HeatmapPalmDetectorImpl::AddTouchRecord(
    base::Time timestamp,
    const std::vector<int>& tracking_ids) {
  if (!is_ready_) {
    return;
  }
  touch_records_.emplace(timestamp, tracking_ids);
}
// Forgets any palm classification for `tracking_id`; a no-op when the ID was
// never flagged.
void HeatmapPalmDetectorImpl::RemoveTouch(int tracking_id) {
  const auto flagged = palm_tracking_ids_.find(tracking_id);
  if (flagged != palm_tracking_ids_.end()) {
    palm_tracking_ids_.erase(flagged);
  }
}
} // namespace ash