chromium/services/webnn/coreml/utils_coreml.mm

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/coreml/utils_coreml.h"

#include <CoreML/CoreML.h>

#include <vector>

#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/containers/span_reader.h"
#include "base/containers/span_writer.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/ref_counted.h"
#include "base/task/sequenced_task_runner.h"
#include "mojo/public/cpp/base/big_buffer.h"

namespace webnn::coreml {

namespace {

// Returns the size in bytes of one element of `data_type`. The enum cases not
// listed here (`MLMultiArrayDataTypeFloat` and `MLMultiArrayDataTypeFloat64`)
// are aliases of `Float32` and `Double`, so the switch covers every distinct
// value.
uint32_t GetDataTypeByteSize(MLMultiArrayDataType data_type) {
  switch (data_type) {
    case MLMultiArrayDataTypeDouble:
      return 8;
    case MLMultiArrayDataTypeFloat32:
    case MLMultiArrayDataTypeInt32:
      return 4;
    case MLMultiArrayDataTypeFloat16:
      return 2;
  }
}

std::vector<uint32_t> ToStdVector(NSArray<NSNumber*>* ns_array) {
  std::vector<uint32_t> std_vector;
  std_vector.reserve(ns_array.count);
  for (NSNumber* number in ns_array) {
    std_vector.push_back(number.unsignedIntegerValue);
  }
  return std_vector;
}

// Extracts data from an `MLMultiArray` - which may not be contiguous - using
// its `shape` and `strides` as appropriate.
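//
// For illustration, consider a hypothetical 2x2 view backed by a padded 2x3
// float32 buffer, i.e. `shape` = {2, 2} and `strides` = {3, 1} (strides are
// counted in elements):
//
//   backing elements: [a, b, pad, c, d, pad]
//
// Each outer iteration reads one stride's worth of input (3 elements) and
// recurses; the stride-1 inner dimension then copies {a, b} and {c, d}
// contiguously, producing the packed output [a, b, c, d].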
void RecursivelyReadFromMLMultiArray(
    base::span<const uint8_t> multi_array_backed_input_buffer,
    uint32_t item_byte_size,
    base::span<const uint32_t> shape,
    base::span<const uint32_t> strides,
    base::span<uint8_t> output_buffer) {
  // Data is packed, copy the whole thing.
  //
  // On the last dimension, the bytes left to read could be more than the bytes
  // left to write because of strides from the previous dimension, but as long
  // as the current stride is 1, we can copy contiguously.
  if (multi_array_backed_input_buffer.size() == output_buffer.size() ||
      (shape.size() == 1 && strides[0] == 1)) {
    output_buffer.copy_from(
        multi_array_backed_input_buffer.first(output_buffer.size()));
    return;
  }

  CHECK_EQ(output_buffer.size() % shape[0], 0u);
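  // Bytes of packed output produced per slice along this dimension.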
  size_t subspan_size = output_buffer.size() / shape[0];

  base::SpanReader<const uint8_t> reader(multi_array_backed_input_buffer);
  base::SpanWriter<uint8_t> writer(output_buffer);
  for (uint32_t i = 0; i < shape[0]; i++) {
    auto output_subspan = writer.Skip(subspan_size);
    CHECK(output_subspan);
    auto input_subspan = reader.Read(strides[0] * item_byte_size);
    CHECK(input_subspan);
    if (shape.size() == 1) {
      output_subspan->copy_from(input_subspan->first(item_byte_size));
    } else {
      RecursivelyReadFromMLMultiArray(*input_subspan, item_byte_size,
                                      shape.subspan(1u), strides.subspan(1u),
                                      *output_subspan);
    }
  }
}

// Copies data to an `MLMultiArray` - which may not be contiguous - using its
// `shape` and `strides` as appropriate.
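//
// Mirroring the read path above: with a hypothetical `shape` = {2, 2} and
// `strides` = {3, 1}, the packed input [a, b, c, d] is written out as
// [a, b, pad, c, d, pad], where the `pad` elements are skipped over (left
// untouched) rather than written.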
void RecursivelyWriteToMLMultiArray(
    base::span<const uint8_t> input_buffer,
    uint32_t item_byte_size,
    base::span<const uint32_t> shape,
    base::span<const uint32_t> strides,
    base::span<uint8_t> multi_array_backed_output_buffer) {
  // Data is packed, copy the whole thing.
  //
  // On the last dimension, the bytes left to write could be more than the
  // bytes left to read because of strides from the previous dimension, but as
  // long as the current stride is 1, we can copy contiguously.
  if (input_buffer.size() == multi_array_backed_output_buffer.size() ||
      (shape.size() == 1 && strides[0] == 1)) {
    multi_array_backed_output_buffer.copy_prefix_from(input_buffer);
    return;
  }

  CHECK_EQ(input_buffer.size() % shape[0], 0u);
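  // Bytes of packed input consumed per slice along this dimension.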
  size_t subspan_size = input_buffer.size() / shape[0];

  base::SpanReader<const uint8_t> reader(input_buffer);
  base::SpanWriter<uint8_t> writer(multi_array_backed_output_buffer);
  for (uint32_t i = 0; i < shape[0]; i++) {
    auto output_subspan = writer.Skip(strides[0] * item_byte_size);
    CHECK(output_subspan);
    auto input_subspan = reader.Read(subspan_size);
    CHECK(input_subspan);
    if (shape.size() == 1) {
      output_subspan->copy_from(input_subspan->first(item_byte_size));
    } else {
      RecursivelyWriteToMLMultiArray(*input_subspan, item_byte_size,
                                     shape.subspan(1u), strides.subspan(1u),
                                     *output_subspan);
    }
  }
}

}  // namespace
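
// Example usage (a sketch; `array` and `packed` are hypothetical names, not
// part of this API): reading a float32 `MLMultiArray` into a packed buffer
// sized from its shape.
//
//   size_t element_count = 1;
//   for (NSNumber* dimension in array.shape) {
//     element_count *= dimension.unsignedIntegerValue;
//   }
//   std::vector<uint8_t> packed(element_count * sizeof(float));
//   ReadFromMLMultiArray(array, packed);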

void ReadFromMLMultiArray(MLMultiArray* multi_array,
                          base::span<uint8_t> buffer) {
  __block bool block_executing_synchronously = true;
  [multi_array getBytesWithHandler:^(const void* bytes, NSInteger size) {
    // TODO(crbug.com/333392274): Refactor this method to assume the handler may
    // be invoked on some other thread. We should not assume that the block
    // will always run synchronously.
    CHECK(block_executing_synchronously);

    std::vector<uint32_t> shape = ToStdVector(multi_array.shape);
    std::vector<uint32_t> strides = ToStdVector(multi_array.strides);
    CHECK_EQ(shape.size(), strides.size());

    // SAFETY: -[MLMultiArray getBytesWithHandler:] guarantees that `bytes`
    // points to at least `size` valid bytes.
    auto multi_array_data = UNSAFE_BUFFERS(base::span(
        static_cast<const uint8_t*>(bytes), base::checked_cast<size_t>(size)));
    RecursivelyReadFromMLMultiArray(multi_array_data,
                                    GetDataTypeByteSize(multi_array.dataType),
                                    shape, strides, buffer);
  }];

  block_executing_synchronously = false;
}
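
// Example usage (a sketch, mirroring the one above; `array` and `packed` are
// hypothetical): writing a packed buffer into an `MLMultiArray`.
//
//   WriteToMLMultiArray(array, packed);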

void WriteToMLMultiArray(MLMultiArray* multi_array,
                         base::span<const uint8_t> bytes_to_write) {
  __block bool block_executing_synchronously = true;
  [multi_array getMutableBytesWithHandler:^(void* mutable_bytes, NSInteger size,
                                            NSArray<NSNumber*>* strides) {
    // TODO(crbug.com/333392274): Refactor this method to assume the handler may
    // be invoked on some other thread. We should not assume that the block
    // will always run synchronously.
    CHECK(block_executing_synchronously);

    std::vector<uint32_t> shape = ToStdVector(multi_array.shape);
    std::vector<uint32_t> std_strides = ToStdVector(strides);
    CHECK_EQ(shape.size(), std_strides.size());

    // SAFETY: -[MLMultiArray getMutableBytesWithHandler:] guarantees that
    // `mutable_bytes` points to at least `size` valid bytes.
    auto mutable_multi_array_data =
        UNSAFE_BUFFERS(base::span(static_cast<uint8_t*>(mutable_bytes),
                                  base::checked_cast<size_t>(size)));
    RecursivelyWriteToMLMultiArray(
        bytes_to_write, GetDataTypeByteSize(multi_array.dataType), shape,
        std_strides, mutable_multi_array_data);
  }];

  block_executing_synchronously = false;
}

}  // namespace webnn::coreml