#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include <cinttypes>
#include "base/containers/span.h"
#include "base/functional/callback.h"
#include "base/numerics/checked_math.h"
#include "base/types/expected_macros.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_compute_result.h"
#include "third_party/blink/renderer/core/dom/dom_exception.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h"
#include "third_party/blink/renderer/modules/ml/ml_context.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_buffer.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_error.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_hash_set.h"
#include "third_party/blink/renderer/platform/heap/persistent.h"
namespace blink {
// File-local helpers for MLGraph: early-return error macros and argument
// validators. Each validator returns base::expected<void, String>: an empty
// success value, or a String describing the failure.
namespace {
// Early-return helper for void-returning callers. NOTE(review): expansion not
// shown here; presumably evaluates `func`, and on error throws through the
// caller's `exception_state` using `msg` before returning — confirm.
#define THROW_AND_RETURN_IF_ERROR(func, msg) …
// Variant of THROW_AND_RETURN_IF_ERROR for callers that return a
// ScriptPromise. NOTE(review): presumably returns an empty promise instead of
// plain `return;` — confirm against the expansion.
#define THROW_AND_RETURN_EMPTY_PROMISE_IF_ERROR(func, msg) …
// Validates caller-supplied named ArrayBufferViews against the graph's
// expected operand descriptors.
// Returns success, or an error message String on mismatch.
// NOTE(review): exact checks not visible here; presumably the set of names
// and each view's data type / byte length vs. its descriptor — confirm.
base::expected<void, String> ValidateNamedArrayBufferViews(
const MLNamedArrayBufferViews& named_array_buffer_views,
const MLGraph::NamedOperandDescriptors& expected_named_descriptors) { … }
// Validates caller-supplied named MLBuffers against the graph's expected
// operand descriptors. `context` is the graph's MLContext.
// NOTE(review): presumably used to check that each buffer belongs to the
// same context — confirm.
// Returns success, or an error message String.
base::expected<void, String> ValidateNamedMLBuffers(
const MLContext* context,
const MLNamedBuffers& named_buffers,
const MLGraph::NamedOperandDescriptors& expected_named_descriptors) { … }
// Cross-checks the input and output MLBuffer maps of a single dispatch.
// NOTE(review): presumably rejects a buffer that is used more than once, for
// example as both an input and an output — confirm.
// Returns success, or an error message String.
base::expected<void, String> ValidateMLBufferUsage(
const MLNamedBuffers& named_inputs,
const MLNamedBuffers& named_outputs) { … }
}  // namespace
// Creates a graph object for `context` in `execution_context`.
// - `pending_graph_remote` is the mojo endpoint for the backing WebNN graph
//   service object.
// - `input_constraints` / `output_constraints` are the named operand
//   descriptors the graph expects for its inputs and outputs.
// The unnamed base::PassKey<MLGraphBuilder> parameter means only
// MLGraphBuilder can invoke this constructor.
// NOTE(review): the initializer list is not shown. Presumably the remote is
// bound here, and the constraints are stored for GetInputConstraints() and
// GetOutputConstraints() — confirm.
MLGraph::MLGraph(ExecutionContext* execution_context,
MLContext* context,
mojo::PendingAssociatedRemote<webnn::mojom::blink::WebNNGraph>
pending_graph_remote,
NamedOperandDescriptors input_constraints,
NamedOperandDescriptors output_constraints,
base::PassKey<MLGraphBuilder> )
: … { … }
// Defaulted out-of-line destructor; no custom teardown is performed here.
MLGraph::~MLGraph() = default;
// Garbage-collection tracing hook (Blink `Trace(Visitor*)` convention).
// NOTE(review): must visit every traced member of MLGraph; keep it in sync
// when fields are added.
void MLGraph::Trace(Visitor* visitor) const { … }
// Script-facing destroy().
// NOTE(review): body not visible. Presumably releases the backing WebNN
// graph (e.g. resets the mojo remote) so later Compute()/Dispatch() calls
// fail — confirm.
void MLGraph::destroy() { … }
// Returns, by const reference, the named operand descriptors expected for the
// graph's inputs.
// NOTE(review): presumably the `input_constraints` passed to the constructor —
// confirm.
const MLGraph::NamedOperandDescriptors& MLGraph::GetInputConstraints() const { … }
// Returns, by const reference, the named operand descriptors expected for the
// graph's outputs.
// NOTE(review): presumably the `output_constraints` passed to the constructor
// — confirm.
const MLGraph::NamedOperandDescriptors& MLGraph::GetOutputConstraints() const { … }
// Runs the graph with ArrayBufferView `inputs` and `outputs`. Returns a
// promise for an MLComputeResult.
// `scoped_trace` is taken by value so the trace scope moves into this call.
// NOTE(review): presumably
// - validation failures throw on `exception_state` and return an empty
//   promise (cf. THROW_AND_RETURN_EMPTY_PROMISE_IF_ERROR);
// - completion arrives in DidCompute().
// Confirm against the body.
ScriptPromise<MLComputeResult> MLGraph::Compute(
ScopedMLTrace scoped_trace,
const MLNamedArrayBufferViews& inputs,
const MLNamedArrayBufferViews& outputs,
ScriptState* script_state,
ExceptionState& exception_state) { … }
// Executes the graph using MLBuffer `inputs` and `outputs`. Returns nothing
// to the caller.
// NOTE(review): errors are presumably reported by throwing on
// `exception_state` (cf. THROW_AND_RETURN_IF_ERROR), after validation with
// ValidateNamedMLBuffers / ValidateMLBufferUsage — confirm.
void MLGraph::Dispatch(ScopedMLTrace scoped_trace,
const MLNamedBuffers& inputs,
const MLNamedBuffers& outputs,
ExceptionState& exception_state) { … }
// Returns the MLContext associated with this graph, as a const pointer.
const MLContext* MLGraph::Context() const { … }
// Completion handler that receives the backend's `mojo_result` and settles
// `resolver` (the promise returned from Compute()).
// NOTE(review): `inputs_info` / `outputs_info` pair operand names with
// ArrayBufferViewInfo. Presumably these are used to build the MLComputeResult
// and copy output data back to the caller's views — confirm.
void MLGraph::DidCompute(
ScopedMLTrace scoped_trace,
ScriptPromiseResolver<MLComputeResult>* resolver,
std::unique_ptr<Vector<std::pair<String, ArrayBufferViewInfo>>> inputs_info,
std::unique_ptr<Vector<std::pair<String, ArrayBufferViewInfo>>>
outputs_info,
webnn::mojom::blink::ComputeResultPtr mojo_result) { … }
// Connection-error handler.
// NOTE(review): presumably installed as the disconnect handler of the WebNN
// graph remote, resetting graph state when the service connection drops —
// confirm.
void MLGraph::OnConnectionError() { … }
}