chromium/services/webnn/tflite/graph_impl_cros.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/tflite/graph_impl_cros.h"

#include <vector>

#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "services/webnn/error.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/tflite/context_impl_cros.h"
#include "services/webnn/tflite/graph_builder_tflite.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/flatbuffers/src/include/flatbuffers/flatbuffers.h"

namespace webnn::tflite {

// static
void GraphImplCrOS::CreateAndBuild(
    ContextImplCrOS* context_impl,
    mojom::GraphInfoPtr graph_info,
    ComputeResourceInfo compute_resource_info,
    WebNNContextImpl::CreateGraphImplCallback callback) {
  base::expected<flatbuffers::DetachedBuffer, std::string> conversion_result =
      GraphBuilderTflite::CreateAndBuild(context_impl->properties(),
                                         *graph_info);
  if (!conversion_result.has_value()) {
    std::move(callback).Run(base::unexpected(mojom::Error::New(
        mojom::Error::Code::kUnknownError, conversion_result.error())));
    return;
  }

  context_impl->LoadModel(
      std::move(conversion_result.value()),
      base::BindOnce(
          [](base::WeakPtr<WebNNContextImpl> context,
             ComputeResourceInfo compute_resource_info,
             WebNNContextImpl::CreateGraphImplCallback callback,
             ml::model_loader::mojom::LoadModelResult result,
             mojo::PendingRemote<ml::model_loader::mojom::Model> pending_remote,
             ml::model_loader::mojom::ModelInfoPtr tensor_info) {
            if (!context ||
                result != ml::model_loader::mojom::LoadModelResult::kOk) {
              std::move(callback).Run(base::unexpected(
                  mojom::Error::New(mojom::Error::Code::kUnknownError,
                                    "Failed to load model with ml service.")));
              return;
            }

            // TODO(crbug.com/330806169): Pass `WebNNGraph` directly to ML
            // Service and not have to bounce through the browser process.
            std::move(callback).Run(base::WrapUnique(new GraphImplCrOS(
                static_cast<ContextImplCrOS*>(context.get()),
                std::move(compute_resource_info), std::move(pending_remote))));
          },
          context_impl->AsWeakPtr(), std::move(compute_resource_info),
          std::move(callback)));
}

GraphImplCrOS::~GraphImplCrOS() = default;

GraphImplCrOS::GraphImplCrOS(
    ContextImplCrOS* context_impl,
    ComputeResourceInfo compute_resource_info,
    mojo::PendingRemote<ml::model_loader::mojom::Model> pending_remote)
    : WebNNGraphImpl(context_impl, std::move(compute_resource_info)) {
  model_remote_.Bind(std::move(pending_remote));
}

void GraphImplCrOS::ComputeImpl(
    base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
    mojom::WebNNGraph::ComputeCallback callback) {
  // TODO(crbug.com/330856251): Optimize inference time with shared memory.
  std::vector<std::pair<std::string, std::vector<uint8_t>>> input_tensors;
  input_tensors.reserve(named_inputs.size());
  for (const auto& [name, buffer] : named_inputs) {
    input_tensors.emplace_back(
        name, std::vector<uint8_t>(buffer.begin(), buffer.end()));
  }

  model_remote_->Compute(
      std::move(input_tensors),
      base::BindOnce(
          [](mojom::WebNNGraph::ComputeCallback callback,
             ml::model_loader::mojom::ComputeResult compute_result,
             const std::optional<base::flat_map<
                 std::string, std::vector<uint8_t>>>& output_tensors) {
            if (compute_result != ml::model_loader::mojom::ComputeResult::kOk ||
                !output_tensors.has_value()) {
              std::move(callback).Run(ToError<mojom::ComputeResult>(
                  mojom::Error::Code::kUnknownError,
                  "Failed to obtain the computation result."));
              return;
            }
            std::vector<std::pair<std::string, mojo_base::BigBuffer>>
                named_outputs;
            named_outputs.reserve(output_tensors->size());
            for (const auto& [name, buffer] : *output_tensors) {
              named_outputs.emplace_back(name, mojo_base::BigBuffer(buffer));
            }

            std::move(callback).Run(mojom::ComputeResult::NewNamedOutputs(
                std::move(named_outputs)));
          },
          std::move(callback)));
}

void GraphImplCrOS::DispatchImpl(
    const base::flat_map<std::string_view, WebNNBufferImpl*>& named_inputs,
    const base::flat_map<std::string_view, WebNNBufferImpl*>& named_outputs) {
  // TODO(crbug.com/40278771): Implement MLBuffer for TFLite. Involve
  // an IPC security reviewer.
  NOTIMPLEMENTED();
}

}  // namespace webnn::tflite