chromium/chrome/browser/ash/power/ml/smart_dim/download_worker.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/power/ml/smart_dim/download_worker.h"

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/task/task_traits.h"
#include "chrome/browser/ash/power/ml/smart_dim/metrics.h"
#include "chrome/browser/ash/power/ml/smart_dim/ml_agent_util.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "components/assist_ranker/proto/example_preprocessor.pb.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "ui/base/resource/resource_bundle.h"

namespace ash {
namespace power {
namespace ml {

namespace {
using chromeos::machine_learning::mojom::FlatBufferModelSpec;
}  // namespace

// |metrics_model_name_| is default-constructed (empty). IsReady() stays false
// until OnJsonParsed() fills it in from the component metadata. The base is
// still explicitly value-initialized, as before.
DownloadWorker::DownloadWorker() : SmartDimWorker() {}

DownloadWorker::~DownloadWorker() = default;

// Non-owning view of the parsed preprocessor config. This is nullptr before
// InitializeFromComponent() runs, or after it failed to parse the proto.
const assist_ranker::ExamplePreprocessorConfig*
DownloadWorker::GetPreprocessorConfig() {
  const assist_ranker::ExamplePreprocessorConfig* const config =
      preprocessor_config_.get();
  return config;
}

// Exposes the graph executor remote owned by this worker. It is only bound
// once LoadModelAndCreateGraphExecutor() has run.
const mojo::Remote<chromeos::machine_learning::mojom::GraphExecutor>&
DownloadWorker::GetExecutor() {
  return this->executor_;
}

// Completion callback for LoadFlatBufferModel().
//
// Nothing is logged on success; kSuccess is only reported once the graph
// executor has been created. Any non-OK result records kLoadModelError.
void DownloadWorker::LoadModelCallback(
    chromeos::machine_learning::mojom::LoadModelResult result) {
  using chromeos::machine_learning::mojom::LoadModelResult;
  if (result == LoadModelResult::OK)
    return;

  LogLoadComponentEvent(LoadComponentEvent::kLoadModelError);
  DVLOG(1) << "Failed to load Smart Dim flatbuffer model.";
}

// Completion callback for CreateGraphExecutor().
//
// This is the last step of component loading: OK records kSuccess, and any
// other result records kCreateGraphExecutorError.
void DownloadWorker::CreateGraphExecutorCallback(
    chromeos::machine_learning::mojom::CreateGraphExecutorResult result) {
  using chromeos::machine_learning::mojom::CreateGraphExecutorResult;
  if (result == CreateGraphExecutorResult::OK) {
    LogLoadComponentEvent(LoadComponentEvent::kSuccess);
    return;
  }

  LogLoadComponentEvent(LoadComponentEvent::kCreateGraphExecutorError);
  DVLOG(1) << "Failed to create a Smart Dim graph executor.";
}

// True only when every prerequisite for inference is present: a parsed
// preprocessor config, bound model and executor remotes, a positive expected
// feature size, and a non-empty metrics model name from the metadata.
bool DownloadWorker::IsReady() {
  if (!preprocessor_config_ || !model_ || !executor_)
    return false;
  return expected_feature_size_ > 0 && !metrics_model_name_.empty();
}

void DownloadWorker::InitializeFromComponent(
    const ComponentFileContents& contents) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

  auto [metadata_json, preprocessor_proto, model_flatbuffer] = contents;

  preprocessor_config_ =
      std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
  if (!preprocessor_config_->ParseFromString(preprocessor_proto)) {
    LogLoadComponentEvent(LoadComponentEvent::kLoadPreprocessorError);
    DVLOG(1) << "Failed to load preprocessor_config.";
    preprocessor_config_.reset();
    return;
  }

  // Meta data contains necessary info to construct FlatBufferModelSpec, and
  // other optional info.
  data_decoder::DataDecoder::ParseJsonIsolated(
      std::move(metadata_json),
      base::BindOnce(&DownloadWorker::OnJsonParsed, base::Unretained(this),
                     std::move(model_flatbuffer)));
}

// Test hook. Stores |on_ready|, which LoadModelAndCreateGraphExecutor() runs
// once (and then clears) after it has issued the model and executor requests.
// Replaces any previously stored closure.
void DownloadWorker::SetOnReadyForTest(base::OnceClosure on_ready) {
  on_ready_for_test_ = std::move(on_ready);
}

// Continues initialization once the component metadata JSON has been parsed.
//
// |result| must hold a dictionary that ParseMetaInfoFromJsonObject() accepts.
// That call fills |metrics_model_name_|, |dim_threshold_|,
// |expected_feature_size_|, |inputs_| and |outputs_|. On any failure,
// kLoadMetadataError is logged and initialization stops. On success, model
// loading is posted as a BEST_EFFORT task to the UI thread.
void DownloadWorker::OnJsonParsed(
    const std::string& model_flatbuffer,
    const data_decoder::DataDecoder::ValueOrError result) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  if (!result.has_value() || !result->is_dict() ||
      !ParseMetaInfoFromJsonObject(*result, &metrics_model_name_,
                                   &dim_threshold_, &expected_feature_size_,
                                   &inputs_, &outputs_)) {
    LogLoadComponentEvent(LoadComponentEvent::kLoadMetadataError);
    DVLOG(1) << "Failed to parse meta info from metadata_json.";
    return;
  }
  // |model_flatbuffer| is a const reference, so it cannot be moved from. It
  // is copied into the posted task explicitly; std::move() here would still
  // copy while suggesting otherwise.
  content::GetUIThreadTaskRunner({base::TaskPriority::BEST_EFFORT})
      ->PostTask(
          FROM_HERE,
          base::BindOnce(&DownloadWorker::LoadModelAndCreateGraphExecutor,
                         base::Unretained(this), model_flatbuffer));
}

// Asks the ML service to load |model_flatbuffer| into |model_|, then requests
// a graph executor for it into |executor_|.
//
// Must run on the UI thread with neither remote already bound. Results arrive
// through LoadModelCallback() and CreateGraphExecutorCallback(). Only
// |executor_| gets a disconnect handler here (OnConnectionError). The test
// hook in |on_ready_for_test_| runs right after the requests are issued; it
// does not wait for either callback.
void DownloadWorker::LoadModelAndCreateGraphExecutor(
    const std::string& model_flatbuffer) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  DCHECK(!model_.is_bound() && !executor_.is_bound());

  // |model_flatbuffer| is a const reference, so the spec receives a copy.
  // Wrapping it in std::move() would not change that.
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadFlatBufferModel(
          FlatBufferModelSpec::New(model_flatbuffer, inputs_, outputs_,
                                   metrics_model_name_),
          model_.BindNewPipeAndPassReceiver(),
          base::BindOnce(&DownloadWorker::LoadModelCallback,
                         base::Unretained(this)));
  model_->CreateGraphExecutor(
      chromeos::machine_learning::mojom::GraphExecutorOptions::New(),
      executor_.BindNewPipeAndPassReceiver(),
      base::BindOnce(&DownloadWorker::CreateGraphExecutorCallback,
                     base::Unretained(this)));
  executor_.set_disconnect_handler(base::BindOnce(
      &DownloadWorker::OnConnectionError, base::Unretained(this)));
  if (on_ready_for_test_) {
    std::move(on_ready_for_test_).Run();
  }
}

}  // namespace ml
}  // namespace power
}  // namespace ash