#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include <array>
#include <numeric>
#include <optional>
#include <utility>
#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/memory/raw_ref.h"
#include "base/notreached.h"
#include "base/test/scoped_feature_list.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/pending_associated_receiver.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "mojo/public/cpp/bindings/unique_associated_receiver_set.h"
#include "mojo/public/cpp/system/message_pipe.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/features.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom-blink.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/platform/browser_interface_broker_proxy.h"
#include "third_party/blink/renderer/bindings/core/v8/native_value_traits.h"
#include "third_party/blink/renderer/bindings/core/v8/native_value_traits_impl.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_tester.h"
#include "third_party/blink/renderer/bindings/core/v8/script_value.h"
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
#include "third_party/blink/renderer/bindings/core/v8/v8_dom_exception.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_compute_result.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_activation.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/core/dom/dom_exception.h"
#include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_typed_array.h"
#include "third_party/blink/renderer/modules/ml/ml.h"
#include "third_party/blink/renderer/modules/ml/ml_context.h"
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_buffer.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder_test_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
#include "third_party/blink/renderer/platform/bindings/v8_binding.h"
#include "third_party/blink/renderer/platform/heap/garbage_collected.h"
#include "third_party/blink/renderer/platform/heap/persistent.h"
#include "third_party/blink/renderer/platform/testing/task_environment.h"
#include "third_party/blink/renderer/platform/wtf/functional.h"
#include "third_party/blink/renderer/platform/wtf/hash_map.h"
#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h"
#include "third_party/blink/renderer/platform/wtf/vector.h"
#include "third_party/blink/renderer/platform/wtf/wtf_size_t.h"
namespace blink {
blink_mojom;
class FakeWebNNBuffer;
namespace {
struct BuildResult { … };
struct ComputeResult { … };
template <typename T>
struct OperandInfo { … };
webnn::OperandDescriptor ToDescriptor(webnn::OperandDataType data_type,
base::span<const uint32_t> shape) { … }
template <typename T>
T* V8ToObject(V8TestingScope* scope, ScriptValue value) { … }
String ExceptionCodeToString(ExceptionCode exception_code) { … }
std::pair<String, String> GetErrorNameAndMessage(V8TestingScope* scope,
ScriptValue value) { … }
template <typename T>
void SetArrayBufferViewValues(NotShared<DOMArrayBufferView> array_buffer_view,
const Vector<T>& values) { … }
NotShared<DOMArrayBufferView> CreateArrayBufferViewForOperand(
const MLOperand* operand) { … }
template <typename T>
NotShared<DOMArrayBufferView> CreateArrayBufferViewForOperand(
const MLOperand* operand,
const Vector<T>& values) { … }
template <typename T>
Vector<T> GetArrayBufferViewValues(
NotShared<DOMArrayBufferView> array_buffer_view) { … }
MLContext* CreateContext(V8TestingScope& scope, MLContextOptions* options) { … }
std::pair<String, String> ComputeGraph(V8TestingScope& scope,
MLGraph* graph,
MLNamedArrayBufferViews& inputs,
MLNamedArrayBufferViews& outputs) { … }
template <typename T>
MLOperand* BuildConstant(MLGraphBuilder* builder,
const Vector<uint32_t>& dimensions,
V8MLOperandDataType::Enum data_type,
const Vector<T>& values,
ExceptionState& exception_state) { … }
MLOperand* BuildConv2d(
V8TestingScope& scope,
MLGraphBuilder* builder,
const MLOperand* input,
const MLOperand* filter,
const MLConv2dOptions* options = MLConv2dOptions::Create()) { … }
MLOperand* BuildGemm(V8TestingScope& scope,
MLGraphBuilder* builder,
const MLOperand* a,
const MLOperand* b,
const MLGemmOptions* options = MLGemmOptions::Create()) { … }
MLOperand* BuildElementWiseBinaryOperator(
MLGraphBuilder* builder,
V8TestingScope& scope,
const MLOperand* a,
const MLOperand* b,
webnn::mojom::blink::ElementWiseBinary::Kind kind,
const MLOperatorOptions* options) { … }
MLOperand* BuildElementWiseBinary(
V8TestingScope& scope,
MLGraphBuilder* builder,
webnn::mojom::blink::ElementWiseBinary::Kind kind,
const MLOperand* a,
const MLOperand* b,
const MLOperatorOptions* options = MLOperatorOptions::Create()) { … }
}
class MLGraphTest : public testing::Test { … };
class WebNNContextHelper { … };
class FakeWebNNGraph : public blink_mojom::WebNNGraph { … };
class FakeWebNNBuffer : public blink_mojom::WebNNBuffer { … };
class FakeWebNNGraphBuilder : public blink_mojom::WebNNGraphBuilder { … };
class FakeWebNNContext : public blink_mojom::WebNNContext { … };
class FakeWebNNContextProvider : public blink_mojom::WebNNContextProvider { … };
class ScopedWebNNServiceBinder { … };
ScriptPromise<MLGraph> BuildSimpleGraph(V8TestingScope& scope,
MLContextOptions* context_options) { … }
bool IsBufferDataEqual(DOMArrayBuffer* array_buffer,
base::span<const uint8_t> expected_data) { … }
MaybeShared<DOMArrayBufferView> CreateArrayBufferViewFromBytes(
DOMArrayBuffer* array_buffer,
base::span<const uint8_t> data) { … }
bool DownloadMLBufferAndCheck(V8TestingScope& scope,
MLContext* context,
MLBuffer* src_buffer,
base::span<const uint8_t> expected_data) { … }
MLBuffer* CreateMLBufferForOperand(V8TestingScope& scope,
MLContext* ml_context,
const MLOperand* operand) { … }
Vector<uint8_t> GetMLBufferValues(V8TestingScope& scope,
MLContext* ml_context,
MLBuffer* ml_buffer) { … }
TEST_F(MLGraphTest, BuildTest) { … }
struct ArrayBufferViewHelper { … };
TEST_F(MLGraphTest, CreateNamedArrayBufferViewsTest) { … }
TEST_F(MLGraphTest, ComputeTest) { … }
TEST_F(MLGraphTest, CreateWebNNBufferTest) { … }
TEST_F(MLGraphTest, WriteWebNNBufferTest) { … }
TEST_F(MLGraphTest, WriteWebNNBufferThenDestroyTest) { … }
TEST_F(MLGraphTest, ReadWebNNBufferThenDestroyTest) { … }
TEST_F(MLGraphTest, WebNNGraphDispatchTest) { … }
TEST_F(MLGraphTest, CreateWebNNGraphTest) { … }
struct ClampOptions { … };
struct SoftmaxTester { … };
TEST_F(MLGraphTest, SoftmaxTest) { … }
template <typename T>
struct ConstantTester { … };
TEST_F(MLGraphTest, ConstantTest) { … }
struct CastTester { … };
TEST_F(MLGraphTest, CastTester) { … }
TEST_F(MLGraphTest, WebNNGraphComputeTest) { … }
}