chromium/services/webnn/coreml/buffer_impl_coreml.mm

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/coreml/buffer_impl_coreml.h"

#import <CoreML/CoreML.h>

#include <optional>

#include "mojo/public/cpp/base/big_buffer.h"
#include "services/webnn/coreml/context_impl_coreml.h"
#include "services/webnn/coreml/utils_coreml.h"
#include "services/webnn/public/cpp/operand_descriptor.h"

namespace webnn::coreml {

namespace {

MLMultiArrayDataType ToMLMultiArrayDataType(OperandDataType data_type) {
  switch (data_type) {
    case OperandDataType::kFloat32:
      return MLMultiArrayDataTypeFloat32;
    case OperandDataType::kFloat16:
      if (__builtin_available(macOS 12, *)) {
        return MLMultiArrayDataTypeFloat16;
      }
      NOTREACHED();
    case OperandDataType::kInt32:
      return MLMultiArrayDataTypeInt32;
    case OperandDataType::kUint32:
    case OperandDataType::kInt64:
    case OperandDataType::kUint64:
    case OperandDataType::kInt8:
    case OperandDataType::kUint8:
      // Unsupported data types for MLMultiArrays in CoreML.
      NOTREACHED();
  }
}

}  // namespace

// static
base::expected<std::unique_ptr<WebNNBufferImpl>, mojom::ErrorPtr>
BufferImplCoreml::Create(
    mojo::PendingAssociatedReceiver<mojom::WebNNBuffer> receiver,
    WebNNContextImpl* context,
    mojom::BufferInfoPtr buffer_info) {
  // TODO(crbug.com/343638938): Check `MLBufferUsageFlags` and use an
  // IOSurface to facilitate zero-copy buffer sharing with WebGPU when possible.

  // TODO(crbug.com/329482489): Move this check to the renderer and throw a
  // TypeError.
  if (buffer_info->descriptor.Rank() > 5) {
    LOG(ERROR) << "[WebNN] Buffer rank is too large.";
    return base::unexpected(mojom::Error::New(
        mojom::Error::Code::kNotSupportedError, "Buffer rank is too large."));
  }

  // Limit to INT_MAX for security reasons (similar to PartitionAlloc).
  //
  // TODO(crbug.com/356670455): Consider relaxing this restriction, especially
  // if partial reads and writes of an MLBuffer are supported.
  //
  // TODO(crbug.com/356670455): Consider moving this check to the renderer and
  // throwing a TypeError.
  if (!base::IsValueInRangeForNumericType<int>(
          buffer_info->descriptor.PackedByteLength())) {
    LOG(ERROR) << "[WebNN] Buffer is too large to create.";
    return base::unexpected(mojom::Error::New(
        mojom::Error::Code::kUnknownError, "Buffer is too large to create."));
  }

  NSMutableArray<NSNumber*>* ns_shape = [[NSMutableArray alloc] init];
  for (uint32_t dimension : buffer_info->descriptor.shape()) {
    [ns_shape addObject:[[NSNumber alloc] initWithUnsignedLong:dimension]];
  }

  NSError* error = nil;
  MLMultiArray* multi_array = [[MLMultiArray alloc]
      initWithShape:ns_shape
           dataType:ToMLMultiArrayDataType(buffer_info->descriptor.data_type())
              error:&error];
  if (error) {
    LOG(ERROR) << "[WebNN] Failed to allocate buffer: " << error;
    return base::unexpected(mojom::Error::New(mojom::Error::Code::kUnknownError,
                                              "Failed to allocate buffer."));
  }

  // `MLMultiArray` doesn't initialize its contents.
  __block bool block_executing_synchronously = true;
  [multi_array getMutableBytesWithHandler:^(void* mutable_bytes, NSInteger size,
                                            NSArray<NSNumber*>* strides) {
    // TODO(crbug.com/333392274): Refactor this method to assume the handler may
    // be invoked on some other thread. We should not assume that the block
    // will always run synchronously.
    CHECK(block_executing_synchronously);
    memset(mutable_bytes, 0, size);
  }];
  block_executing_synchronously = false;

  return base::WrapUnique(
      new BufferImplCoreml(std::move(receiver), context, std::move(buffer_info),
                           multi_array, base::PassKey<BufferImplCoreml>()));
}

BufferImplCoreml::BufferImplCoreml(
    mojo::PendingAssociatedReceiver<mojom::WebNNBuffer> receiver,
    WebNNContextImpl* context,
    mojom::BufferInfoPtr buffer_info,
    MLMultiArray* multi_array,
    base::PassKey<BufferImplCoreml> /*pass_key*/)
    : WebNNBufferImpl(std::move(receiver), context, std::move(buffer_info)),
      multi_array_(multi_array) {}

BufferImplCoreml::~BufferImplCoreml() = default;

void BufferImplCoreml::ReadBufferImpl(
    mojom::WebNNBuffer::ReadBufferCallback callback) {
  mojo_base::BigBuffer output_buffer(PackedByteLength());
  ReadFromMLMultiArray(multi_array_, output_buffer);

  std::move(callback).Run(
      mojom::ReadBufferResult::NewBuffer(std::move(output_buffer)));
}

void BufferImplCoreml::WriteBufferImpl(mojo_base::BigBuffer src_buffer) {
  WriteToMLMultiArray(multi_array_, src_buffer);
}

MLFeatureValue* BufferImplCoreml::AsFeatureValue() {
  return [MLFeatureValue featureValueWithMultiArray:multi_array_];
}

}  // namespace webnn::coreml