#include "third_party/blink/renderer/modules/ml/ml_context.h"
#include "base/feature_list.h"
#include "base/numerics/checked_math.h"
#include "base/types/expected_macros.h"
#include "base/types/pass_key.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/features.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom-blink.h"
#include "third_party/blink/public/platform/task_type.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_binary_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_lost_info.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_device_type.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_logical_not_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_op_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_prelu_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_single_input_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_support_limits.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_where_support_limits.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h"
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_buffer.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_error.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
namespace blink {
// File-local helpers (internal linkage via the anonymous namespace).
namespace {
// Takes a webnn::SupportedDataTypes set and returns an MLSupportLimits*.
// NOTE(review): body elided in this view; presumably builds a new limits
// object describing |supported_data_types| — confirm against the definition.
MLSupportLimits* SupportedDataTypesToSupportLimits(
const webnn::SupportedDataTypes& supported_data_types) { … }
// Maps a webnn::InputOperandLayout value to the Blink-side
// V8MLInputOperandLayout::Enum. NOTE(review): the mapping itself is not
// visible here — verify every enumerator is handled.
blink::V8MLInputOperandLayout::Enum InputOperandLayoutToBlink(
webnn::InputOperandLayout layout) { … }
}  // namespace
// Constructs an MLContext for |execution_context| with the requested device
// type, power preference and thread count. |create_context_success| is taken
// by value (a mojom Ptr), presumably carrying the remote context endpoint and
// its properties. NOTE(review): the initializer list and body are elided in
// this view — confirm what is stored.
MLContext::MLContext(
ExecutionContext* execution_context,
const V8MLDeviceType device_type,
const V8MLPowerPreference power_preference,
const unsigned int num_threads,
webnn::mojom::blink::CreateContextSuccessPtr create_context_success)
: … { … }
// Defaulted destructor; member cleanup is left to the members themselves.
MLContext::~MLContext() = default;
// Const accessors for the options the context was created with (device type,
// power preference, thread count). NOTE(review): bodies elided; presumably
// they return the values passed to the constructor.
V8MLDeviceType MLContext::GetDeviceType() const { … }
V8MLPowerPreference MLContext::GetPowerPreference() const { … }
unsigned int MLContext::GetNumThreads() const { … }
// Tracing hook invoked with a Visitor. NOTE(review): body elided; presumably
// traces this object's garbage-collected members — keep it in sync with any
// new GC member fields.
void MLContext::Trace(Visitor* visitor) const { … }
// Returns a promise typed on MLContextLostInfo. NOTE(review): body elided;
// presumably it resolves when the context is lost (see OnLost) — confirm.
ScriptPromise<MLContextLostInfo> MLContext::lost(ScriptState* script_state) { … }
// Script-facing destroy(). Takes an ExceptionState, so it may report
// exceptions to script. NOTE(review): body elided — confirm which conditions
// throw and whether repeated calls are tolerated.
void MLContext::destroy(ScriptState* script_state,
ExceptionState& exception_state) { … }
// Script-facing compute(): runs |graph| with named ArrayBufferView |inputs|
// and |outputs| and returns a promise of MLComputeResult. Errors may be
// reported through |exception_state|. NOTE(review): body elided — validation
// and buffer-transfer behavior are not visible here.
ScriptPromise<MLComputeResult> MLContext::compute(
ScriptState* script_state,
MLGraph* graph,
const MLNamedArrayBufferViews& inputs,
const MLNamedArrayBufferViews& outputs,
ExceptionState& exception_state) { … }
// Returns an MLGraphBuilder* for this context and may report failures through
// |exception_state|. NOTE(review): body elided; whether it can return null on
// error is not visible here — callers should verify.
MLGraphBuilder* MLContext::CreateWebNNGraphBuilder(
ScriptState* script_state,
ExceptionState& exception_state) { … }
// Handler taking a numeric |custom_reason| and a human-readable |description|.
// NOTE(review): body elided; presumably invoked when the backing context is
// lost, and settles the promise returned by lost() — confirm.
void MLContext::OnLost(uint32_t custom_reason, const std::string& description) { … }
// Script-facing opSupportLimits(): returns a const MLOpSupportLimits*.
// NOTE(review): body elided; presumably built from the context properties
// with the file-local conversion helpers — confirm.
const MLOpSupportLimits* MLContext::opSupportLimits(ScriptState* script_state) { … }
// Notification hook that receives a newly created MLGraph. NOTE(review): body
// elided — what the context records about |graph| is not visible here.
void MLContext::OnGraphCreated(MLGraph* graph) { … }
// Script-facing createBuffer(): returns a promise of MLBuffer for
// |descriptor|; synchronous errors go through |exception_state|.
// NOTE(review): body elided; completion is presumably delivered through
// DidCreateWebNNBuffer — confirm.
ScriptPromise<MLBuffer> MLContext::createBuffer(
ScriptState* script_state,
const MLBufferDescriptor* descriptor,
ExceptionState& exception_state) { … }
// writeBuffer() overload: copies from an ArrayBufferView into |dst_buffer|,
// starting at |src_element_offset| (measured in elements, not bytes) with no
// explicit count. NOTE(review): body elided; presumably it forwards to
// WriteWebNNBuffer with no element count.
void MLContext::writeBuffer(
ScriptState* script_state,
MLBuffer* dst_buffer,
const MaybeShared<DOMArrayBufferView>& src_data_view,
uint64_t src_element_offset,
ExceptionState& exception_state) { … }
// writeBuffer() overload: like the view/offset form, but also limits the copy
// to |src_element_count| elements. NOTE(review): body elided — confirm that
// offset + count is bounds-checked against the view length (overflow-safe).
void MLContext::writeBuffer(
ScriptState* script_state,
MLBuffer* dst_buffer,
const MaybeShared<DOMArrayBufferView>& src_data_view,
uint64_t src_element_offset,
uint64_t src_element_count,
ExceptionState& exception_state) { … }
// writeBuffer() overload taking a raw DOMArrayBufferBase source, starting at
// |src_byte_offset| (measured in bytes). NOTE(review): body elided;
// presumably it forwards with an element size of one byte.
void MLContext::writeBuffer(ScriptState* script_state,
MLBuffer* dst_buffer,
const DOMArrayBufferBase* src_data_base,
uint64_t src_byte_offset,
ExceptionState& exception_state) { … }
// writeBuffer() overload taking a raw DOMArrayBufferBase source, copying
// |src_byte_size| bytes starting at |src_byte_offset|. NOTE(review): body
// elided — confirm that offset + size is range-checked.
void MLContext::writeBuffer(ScriptState* script_state,
MLBuffer* dst_buffer,
const DOMArrayBufferBase* src_data_base,
uint64_t src_byte_offset,
uint64_t src_byte_size,
ExceptionState& exception_state) { … }
// readBuffer() overload that returns the contents of |src_buffer| as a
// promise of a new DOMArrayBuffer. NOTE(review): body elided.
ScriptPromise<DOMArrayBuffer> MLContext::readBuffer(
ScriptState* script_state,
MLBuffer* src_buffer,
ExceptionState& exception_state) { … }
// readBuffer() overload that reads |src_buffer| into the caller-supplied
// |dst_data| (a DOMArrayBufferBase). Its promise resolves to undefined.
// NOTE(review): body elided — size-mismatch handling is not visible here.
ScriptPromise<IDLUndefined> MLContext::readBuffer(
ScriptState* script_state,
MLBuffer* src_buffer,
DOMArrayBufferBase* dst_data,
ExceptionState& exception_state) { … }
// readBuffer() overload that reads |src_buffer| into the caller-supplied
// ArrayBufferView |dst_data| (taken by value). Its promise resolves to
// undefined. NOTE(review): body elided.
ScriptPromise<IDLUndefined> MLContext::readBuffer(
ScriptState* script_state,
MLBuffer* src_buffer,
MaybeShared<DOMArrayBufferView> dst_data,
ExceptionState& exception_state) { … }
// Internal write helper. It takes the source as a byte span, an element offset,
// the element size in bytes, and an optional element count (std::nullopt
// meaning none was given). NOTE(review): body elided; presumably the shared
// implementation behind the writeBuffer() overloads. Confirm that
// offset * size is computed with overflow checks.
void MLContext::WriteWebNNBuffer(ScriptState* script_state,
MLBuffer* dst_buffer,
base::span<const uint8_t> src_data,
uint64_t src_element_offset,
unsigned src_data_type_size_bytes,
std::optional<uint64_t> src_element_count,
ExceptionState& exception_state) { … }
// Script-facing dispatch(): runs |graph| with named MLBuffer |inputs| and
// |outputs|. It returns void, so no result promise is produced here; errors go
// through |exception_state|. NOTE(review): body elided.
void MLContext::dispatch(ScriptState* script_state,
MLGraph* graph,
const MLNamedBuffers& inputs,
const MLNamedBuffers& outputs,
ExceptionState& exception_state) { … }
// Completion handler that receives a CreateBufferResultPtr together with the
// MLBuffer promise |resolver|, the already validated operand descriptor, the
// usage flags, and a ScopedMLTrace taken by value. NOTE(review): body elided;
// presumably the createBuffer() callback that settles |resolver| — confirm.
void MLContext::DidCreateWebNNBuffer(
ScopedMLTrace scoped_trace,
ScriptPromiseResolver<blink::MLBuffer>* resolver,
webnn::OperandDescriptor validated_descriptor,
webnn::MLBufferUsage usage,
webnn::mojom::blink::CreateBufferResultPtr result) { … }
}