chromium/services/webnn/tflite/context_impl_cros.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/tflite/context_impl_cros.h"

#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "services/webnn/tflite/buffer_impl_tflite.h"
#include "services/webnn/tflite/context_impl_tflite.h"
#include "services/webnn/tflite/graph_builder_tflite.h"
#include "services/webnn/tflite/graph_impl_cros.h"
#include "services/webnn/webnn_context_impl.h"

namespace webnn::tflite {

// Constructs the ChromeOS context. All work is delegated to the
// WebNNContextImpl base class, which receives the mojo `receiver`, the owning
// `context_provider`, the context properties reported by
// GraphBuilderTflite::GetContextProperties(), and the caller-supplied
// `options`.
ContextImplCrOS::ContextImplCrOS(
    mojo::PendingReceiver<mojom::WebNNContext> receiver,
    WebNNContextProviderImpl* context_provider,
    mojom::CreateContextOptionsPtr options)
    : WebNNContextImpl(std::move(receiver),
                       context_provider,
                       GraphBuilderTflite::GetContextProperties(),
                       std::move(options)) {}

// Defaulted destructor: no explicit teardown is performed here; members are
// destroyed by their own destructors.
ContextImplCrOS::~ContextImplCrOS() = default;

// Returns a weak pointer to this context, vended by `weak_factory_`.
// Debug builds verify the call happens on the sequence tracked by
// `sequence_checker_`.
base::WeakPtr<WebNNContextImpl> ContextImplCrOS::AsWeakPtr() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return weak_factory_.GetWeakPtr();
}

// Sends `model_content` to the ML Service model loader and reports the outcome
// through `callback`.
//
// If the connection to the model loader has not been bound yet, a new loader is
// requested from the ML Service first and the actual load is deferred to
// OnModelLoaderCreated().
void ContextImplCrOS::LoadModel(
    flatbuffers::DetachedBuffer model_content,
    ml::model_loader::mojom::ModelLoader::LoadCallback callback) {
  namespace ml_mojom = ml::model_loader::mojom;

  // Fast path: the model loader connection already exists, so load directly.
  if (model_loader_remote_.is_bound()) {
    model_loader_remote_->Load(
        mojo_base::BigBuffer(base::make_span(model_content)),
        std::move(callback));
    return;
  }

  // Slow path: bootstrap the mojo connection to the ML Service.
  //
  // `loader_remote` is the end kept by the WebNN Service; the matching
  // receiver end is handed over to the ML Service.
  mojo::PendingRemote<ml_mojom::ModelLoader> loader_remote;
  auto loader_receiver = loader_remote.InitWithNewPipeAndPassReceiver();

  // A thread count of "0" lets the backend pick the number of threads itself.
  // TODO(crbug.com/330380801): Support other device types.
  auto loader_options = ml_mojom::CreateModelLoaderOptions::New(
      /*num_threads=*/0, ml_mojom::ModelFormat::kTfLite,
      ml_mojom::DevicePreference::kCpu);

  // The pending remote, the model bytes and the caller's callback travel with
  // the completion callback so the load can proceed once the loader exists.
  // The weak pointer drops the reply if this context is destroyed meanwhile.
  auto on_loader_created = base::BindOnce(
      &ContextImplCrOS::OnModelLoaderCreated, weak_factory_.GetWeakPtr(),
      std::move(loader_remote), std::move(model_content), std::move(callback));

  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .CreateWebPlatformModelLoader(std::move(loader_receiver),
                                    std::move(loader_options),
                                    std::move(on_loader_created));
}

// Completion handler for the CreateWebPlatformModelLoader() request issued by
// LoadModel().
//
// `webnn_service_remote` is the not-yet-bound remote end of the pipe whose
// receiver was sent to the ML Service, `model_content` is the model that was
// waiting for the loader, and `callback` is the caller's load callback.
// `result` is the outcome of the loader creation:
//   - kUnknownError / kNotSupported: `callback` is run immediately with the
//     corresponding LoadModelResult, a null remote and a null model info.
//   - kOk: the loader connection is bound (if it is not already) and the
//     model is loaded through it; `callback` is forwarded to Load().
void ContextImplCrOS::OnModelLoaderCreated(
    mojo::PendingRemote<ml::model_loader::mojom::ModelLoader>
        webnn_service_remote,
    flatbuffers::DetachedBuffer model_content,
    ml::model_loader::mojom::ModelLoader::LoadCallback callback,
    ml::model_loader::mojom::CreateModelLoaderResult result) {
  switch (result) {
    case ml::model_loader::mojom::CreateModelLoaderResult::kUnknownError: {
      std::move(callback).Run(
          ml::model_loader::mojom::LoadModelResult::kUnknownError,
          mojo::NullRemote(), nullptr);
      return;
    }
    case ml::model_loader::mojom::CreateModelLoaderResult::kNotSupported: {
      std::move(callback).Run(
          ml::model_loader::mojom::LoadModelResult::kNotSupported,
          mojo::NullRemote(), nullptr);
      return;
    }
    case ml::model_loader::mojom::CreateModelLoaderResult::kOk: {
      // LoadModel() starts a fresh bootstrap every time it is called while
      // `model_loader_remote_` is still unbound, so more than one successful
      // reply can arrive. Binding an already-bound mojo::Remote is a
      // programming error, so only the first successful reply binds; later
      // ones let their redundant `webnn_service_remote` go out of scope
      // (closing that extra pipe) and reuse the established loader.
      if (!model_loader_remote_.is_bound()) {
        model_loader_remote_.Bind(std::move(webnn_service_remote));
      }

      model_loader_remote_->Load(
          mojo_base::BigBuffer(base::make_span(model_content)),
          std::move(callback));
      return;
    }
  }
}

// Builds a graph for this context by delegating to
// GraphImplCrOS::CreateAndBuild(). This context is passed along with
// `graph_info` and `compute_resource_info`; `callback` is handed to
// CreateAndBuild(), which is presumably responsible for reporting the result
// (not visible in this file).
void ContextImplCrOS::CreateGraphImpl(
    mojom::GraphInfoPtr graph_info,
    WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
    CreateGraphImplCallback callback) {
  GraphImplCrOS::CreateAndBuild(this, std::move(graph_info),
                                std::move(compute_resource_info),
                                std::move(callback));
}

// Creates a TFLite-backed buffer bound to `receiver` for this context and
// synchronously hands the result of BufferImplTflite::Create() to `callback`.
void ContextImplCrOS::CreateBufferImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNBuffer> receiver,
    mojom::BufferInfoPtr buffer_info,
    CreateBufferImplCallback callback) {
  // Build the buffer first, then report it; the callback always runs before
  // this method returns.
  auto created_buffer = BufferImplTflite::Create(std::move(receiver), this,
                                                 std::move(buffer_info));
  std::move(callback).Run(std::move(created_buffer));
}

}  // namespace webnn::tflite