#include "services/webnn/webnn_graph_impl.h"
#include <cmath>
#include <limits>
#include "base/containers/contains.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_buffer_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/webnn_buffer_impl.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_graph_builder_impl.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace webnn {
namespace {
// Fake graph implementation for these tests: a `final` subclass of
// `WebNNGraphImpl`.
// NOTE(review): presumably replaces backend execution with no-op or canned
// behavior so that only the shared validation path is exercised - confirm.
class FakeWebNNGraphImpl final : public WebNNGraphImpl { … };
// Fake buffer implementation for these tests: a `final` subclass of
// `WebNNBufferImpl`.
// NOTE(review): presumably services reads/writes without a real device
// buffer - confirm against the class body.
class FakeWebNNBufferImpl final : public WebNNBufferImpl { … };
// Fake context implementation for these tests: a `final` subclass of
// `WebNNContextImpl`.
// NOTE(review): presumably creates FakeWebNNGraphImpl / FakeWebNNBufferImpl
// objects in place of a platform backend - confirm.
class FakeWebNNContextImpl final : public WebNNContextImpl { … };
// Test backend implementing `WebNNContextProviderImpl::BackendForTesting`.
// Unlike the other fakes in this file it is not marked `final`.
// NOTE(review): presumably makes the context provider hand out
// FakeWebNNContextImpl instances - confirm.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting { … };
// Takes a graph description and a name-keyed map of input buffers (both by
// value) and returns a bool.
// NOTE(review): presumably builds the graph, issues a compute request with
// `inputs`, and returns true when the request is accepted (no bad-message
// report) - confirm the exact meaning of the return value.
bool ValidateInputsForComputing(
mojom::GraphInfoPtr graph_info,
base::flat_map<std::string, mojo_base::BigBuffer> inputs) { … }
// Result type returned by CreateWebNNBuffer() and used as the mapped value in
// ValidateDispatch()'s `inputs` / `outputs` maps.
// NOTE(review): presumably carries the created buffer's remote and its
// handle/token - confirm against the member list.
struct CreateBufferSuccess { … };
// Takes the context remote by non-const reference plus an element
// `data_type` and a `shape` (by value) and returns a `CreateBufferSuccess`.
// NOTE(review): presumably asks `webnn_context` to create a buffer with that
// type/shape and waits for the reply before returning - confirm.
CreateBufferSuccess CreateWebNNBuffer(
mojo::Remote<mojom::WebNNContext>& webnn_context,
OperandDataType data_type,
std::vector<uint32_t> shape) { … }
// Takes the context-provider remote by non-const reference and returns a
// `mojo::Remote<mojom::WebNNContext>`.
// NOTE(review): presumably requests a new context from the provider and
// blocks until it is bound - confirm, including the failure behavior.
mojo::Remote<mojom::WebNNContext> CreateWebNNContext(
mojo::Remote<mojom::WebNNContextProvider>& webnn_context_provider) { … }
// Takes the context remote by non-const reference, a graph description, and
// name-keyed maps of previously created input and output buffers (all three
// by value), and returns a bool.
// NOTE(review): presumably builds the graph on `webnn_context`, dispatches it
// with the given buffers, and returns true when the dispatch is accepted
// (no bad-message report) - confirm.
bool ValidateDispatch(
mojo::Remote<mojom::WebNNContext>& webnn_context,
mojom::GraphInfoPtr graph_info,
base::flat_map<std::string, CreateBufferSuccess> inputs,
base::flat_map<std::string, CreateBufferSuccess> outputs) { … }
// File-local (anonymous-namespace) array of operand data types.
// NOTE(review): the array is not declared `const`/`constexpr` even though the
// `k` prefix reads as a constant; consider adding the qualifier if nothing
// mutates it. Presumably lists every `OperandDataType` enumerator - confirm.
OperandDataType kAllOperandDataTypes[] = …;
}
// gtest fixture (derives from `testing::Test`) named by every `TEST_F` case
// in this file.
// NOTE(review): presumably owns the task environment, feature list and fake
// backend / context-provider setup - confirm against the class body.
class WebNNGraphImplTest : public testing::Test { … };
// NOTE(review): presumably a small aggregate describing one operand (data
// type plus dimensions) for the tester structs that follow - confirm against
// the member list.
struct OperandInfo { … };
// Tester struct paired by name with `ArgMinMaxTest`.
// NOTE(review): presumably describes one argMin/argMax case (operands,
// attributes, expected validity) - confirm.
struct ArgMinMaxTester { … };
// `WebNNGraphImplTest` case `ArgMinMaxTest`.
// NOTE(review): presumably runs ArgMinMaxTester over valid and invalid
// argMin/argMax graphs - confirm.
TEST_F(WebNNGraphImplTest, ArgMinMaxTest) { … }
// Tester struct paired by name with `ClampTest`.
// NOTE(review): presumably describes one clamp case (operands, min/max
// attributes, expected validity) - confirm.
struct ClampTester { … };
// `WebNNGraphImplTest` case `ClampTest`.
// NOTE(review): presumably runs ClampTester over valid and invalid clamp
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ClampTest) { … }
// Tester struct paired by name with `HardSigmoidTest`.
// NOTE(review): presumably describes one hardSigmoid case (operands,
// attributes, expected validity) - confirm.
struct HardSigmoidTester { … };
// `WebNNGraphImplTest` case `HardSigmoidTest`.
// NOTE(review): presumably runs HardSigmoidTester over valid and invalid
// hardSigmoid graphs - confirm.
TEST_F(WebNNGraphImplTest, HardSigmoidTest) { … }
// Tester struct paired by name with `BatchNormalizationTest`.
// NOTE(review): presumably describes one batchNormalization case (operands,
// attributes, expected validity) - confirm.
struct BatchNormalizationTester { … };
// `WebNNGraphImplTest` case `BatchNormalizationTest`.
// NOTE(review): presumably runs BatchNormalizationTester over valid and
// invalid batchNormalization graphs - confirm.
TEST_F(WebNNGraphImplTest, BatchNormalizationTest) { … }
// Tester struct paired by name with `ConcatTest`.
// NOTE(review): presumably describes one concat case (input operands, axis,
// expected validity) - confirm.
struct ConcatTester { … };
// `WebNNGraphImplTest` case `ConcatTest`.
// NOTE(review): presumably runs ConcatTester over valid and invalid concat
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ConcatTest) { … }
// Tester struct paired by name with `Conv2dTest`.
// NOTE(review): presumably describes one conv2d case (operands, attributes,
// expected validity) - confirm.
struct Conv2dTester { … };
// `WebNNGraphImplTest` case `Conv2dTest`.
// NOTE(review): presumably runs Conv2dTester over valid and invalid conv2d
// graphs - confirm.
TEST_F(WebNNGraphImplTest, Conv2dTest) { … }
// `WebNNGraphImplTest` case `ConvTranspose2dTest`.
// NOTE(review): no dedicated tester struct precedes this test; presumably it
// reuses `Conv2dTester` for transposed convolution - confirm.
TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) { … }
// Tester struct paired by name with the element-wise binary tests.
// NOTE(review): presumably describes one binary-op case (kind, lhs/rhs/output
// operands, expected validity) - confirm.
struct ElementWiseBinaryTester { … };
// `WebNNGraphImplTest` case `ElementWiseBinaryTest`.
// NOTE(review): presumably runs ElementWiseBinaryTester over valid and
// invalid element-wise binary graphs - confirm.
TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) { … }
// `WebNNGraphImplTest` case `ElementWiseBinaryLogicalTest`.
// NOTE(review): no dedicated tester struct precedes this test; presumably it
// reuses `ElementWiseBinaryTester` for the logical/comparison kinds - confirm.
TEST_F(WebNNGraphImplTest, ElementWiseBinaryLogicalTest) { … }
// Tester struct paired by name with the element-wise unary tests.
// NOTE(review): presumably describes one unary-op case (kind, input/output
// operands, expected validity) - confirm.
struct ElementWiseUnaryTester { … };
// Value-parameterized gtest fixture. Each parameter is a tuple of:
//   0: pair of (element-wise unary kind, vector of operand data types),
//   1: an operand data type,
//   2: an operand data type.
// NOTE(review): presumably element 0's vector is the set of data types the
// kind supports, and elements 1 and 2 are the input and output data types
// under test - confirm against the class body and the instantiation.
class ElementWiseUnaryDataTypeFixture
: public testing::TestWithParam<
std::tuple<std::pair<mojom::ElementWiseUnary::Kind,
std::vector<OperandDataType>>,
OperandDataType,
OperandDataType>> { … };
// Parameterized case on `ElementWiseUnaryDataTypeFixture`.
// NOTE(review): presumably checks, per parameter tuple, whether a unary op is
// accepted for the given input/output data types - confirm.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) { … }
// Parameterized case on `ElementWiseUnaryDataTypeFixture`.
// NOTE(review): presumably the scalar-operand counterpart of
// `TestUnaryOperandDataTypeSupport` - confirm.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) { … }
// Instantiates a value-parameterized test suite.
// NOTE(review): presumably targets `ElementWiseUnaryDataTypeFixture` (the
// only `TestWithParam` fixture declared above) with a combination of unary
// kinds and operand data types - confirm the generator arguments.
INSTANTIATE_TEST_SUITE_P(…);
// `WebNNGraphImplTest` case `ElementWiseUnaryTest`.
// NOTE(review): presumably runs ElementWiseUnaryTester over valid and invalid
// element-wise unary graphs - confirm.
TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) { … }
// Tester struct paired by name with `EluTest`.
// NOTE(review): presumably describes one elu case (operands, alpha, expected
// validity) - confirm.
struct EluTester { … };
// `WebNNGraphImplTest` case `EluTest`.
// NOTE(review): presumably runs EluTester over valid and invalid elu graphs -
// confirm.
TEST_F(WebNNGraphImplTest, EluTest) { … }
// Tester struct paired by name with `ExpandTest`.
// NOTE(review): presumably describes one expand case (input/output operands,
// expected validity) - confirm.
struct ExpandTester { … };
// `WebNNGraphImplTest` case `ExpandTest`.
// NOTE(review): presumably runs ExpandTester over valid and invalid expand
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ExpandTest) { … }
// NOTE(review): presumably the attribute bundle (e.g. axis) shared by the
// gather tester structs that follow - confirm against the member list.
struct GatherAttributes { … };
// Tester struct paired by name with `GatherTest`.
// NOTE(review): presumably describes one gather case (input/indices operands,
// attributes, expected validity) - confirm.
struct GatherTester { … };
// `WebNNGraphImplTest` case `GatherTest`.
// NOTE(review): presumably runs GatherTester over valid and invalid gather
// graphs - confirm.
TEST_F(WebNNGraphImplTest, GatherTest) { … }
// Tester struct paired by name with `GatherElementsTest`.
// NOTE(review): presumably describes one gatherElements case (input/indices
// operands, attributes, expected validity) - confirm.
struct GatherElementsTester { … };
// `WebNNGraphImplTest` case `GatherElementsTest`.
// NOTE(review): presumably runs GatherElementsTester over valid and invalid
// gatherElements graphs - confirm.
TEST_F(WebNNGraphImplTest, GatherElementsTest) { … }
// Tester struct paired by name with `GeluTest`.
// NOTE(review): presumably describes one gelu case (input/output operands,
// expected validity) - confirm.
struct GeluTester { … };
// `WebNNGraphImplTest` case `GeluTest`.
// NOTE(review): presumably runs GeluTester over valid and invalid gelu
// graphs - confirm.
TEST_F(WebNNGraphImplTest, GeluTest) { … }
// Tester struct paired by name with `GemmTest`.
// NOTE(review): presumably describes one gemm case (a/b/c operands,
// attributes, expected validity) - confirm.
struct GemmTester { … };
// `WebNNGraphImplTest` case `GemmTest`.
// NOTE(review): presumably runs GemmTester over valid and invalid gemm
// graphs - confirm.
TEST_F(WebNNGraphImplTest, GemmTest) { … }
// Tester struct paired by name with `GruTest`.
// NOTE(review): presumably describes one gru case (operands, steps/hidden
// size, attributes, expected validity) - confirm.
struct GruTester { … };
// `WebNNGraphImplTest` case `GruTest`.
// NOTE(review): presumably runs GruTester over valid and invalid gru graphs -
// confirm.
TEST_F(WebNNGraphImplTest, GruTest) { … }
// Tester struct paired by name with `GruCellTest`.
// NOTE(review): presumably describes one gruCell case (operands, hidden size,
// attributes, expected validity) - confirm.
struct GruCellTester { … };
// `WebNNGraphImplTest` case `GruCellTest`.
// NOTE(review): presumably runs GruCellTester over valid and invalid gruCell
// graphs - confirm.
TEST_F(WebNNGraphImplTest, GruCellTest) { … }
// Tester struct paired by name with `InstanceNormalizationTest`.
// NOTE(review): presumably describes one instanceNormalization case
// (operands, attributes, expected validity) - confirm.
struct InstanceNormalizationTester { … };
// `WebNNGraphImplTest` case `InstanceNormalizationTest`.
// NOTE(review): presumably runs InstanceNormalizationTester over valid and
// invalid instanceNormalization graphs - confirm.
TEST_F(WebNNGraphImplTest, InstanceNormalizationTest) { … }
// Tester struct paired by name with `LayerNormalizationTest`.
// NOTE(review): presumably describes one layerNormalization case (operands,
// attributes, expected validity) - confirm.
struct LayerNormalizationTester { … };
// `WebNNGraphImplTest` case `LayerNormalizationTest`.
// NOTE(review): presumably runs LayerNormalizationTester over valid and
// invalid layerNormalization graphs - confirm.
TEST_F(WebNNGraphImplTest, LayerNormalizationTest) { … }
// Tester struct paired by name with `LstmTest`.
// NOTE(review): presumably describes one lstm case (operands, steps/hidden
// size, attributes, expected validity) - confirm.
struct LstmTester { … };
// `WebNNGraphImplTest` case `LstmTest`.
// NOTE(review): presumably runs LstmTester over valid and invalid lstm
// graphs - confirm.
TEST_F(WebNNGraphImplTest, LstmTest) { … }
// Tester struct paired by name with `LstmCellTest`.
// NOTE(review): presumably describes one lstmCell case (operands, hidden
// size, attributes, expected validity) - confirm.
struct LstmCellTester { … };
// `WebNNGraphImplTest` case `LstmCellTest`.
// NOTE(review): presumably runs LstmCellTester over valid and invalid
// lstmCell graphs - confirm.
TEST_F(WebNNGraphImplTest, LstmCellTest) { … }
// Tester struct paired by name with `MatmulTest`.
// NOTE(review): presumably describes one matmul case (a/b/output operands,
// expected validity) - confirm.
struct MatmulTester { … };
// `WebNNGraphImplTest` case `MatmulTest`.
// NOTE(review): presumably runs MatmulTester over valid and invalid matmul
// graphs - confirm.
TEST_F(WebNNGraphImplTest, MatmulTest) { … }
// Tester struct paired by name with `PadTest`.
// NOTE(review): presumably describes one pad case (operands, padding amounts,
// mode, expected validity) - confirm.
struct PadTester { … };
// `WebNNGraphImplTest` case `PadTest`.
// NOTE(review): presumably runs PadTester over valid and invalid pad graphs -
// confirm.
TEST_F(WebNNGraphImplTest, PadTest) { … }
// Tester struct paired by name with `Pool2dTest`.
// NOTE(review): presumably describes one pool2d case (operands, attributes,
// expected validity) - confirm.
struct Pool2dTester { … };
// `WebNNGraphImplTest` case `Pool2dTest`.
// NOTE(review): presumably runs Pool2dTester over valid and invalid pool2d
// graphs - confirm.
TEST_F(WebNNGraphImplTest, Pool2dTest) { … }
// Tester struct paired by name with `PreluTest`.
// NOTE(review): presumably describes one prelu case (input/slope/output
// operands, expected validity) - confirm.
struct PreluTester { … };
// `WebNNGraphImplTest` case `PreluTest`.
// NOTE(review): presumably runs PreluTester over valid and invalid prelu
// graphs - confirm.
TEST_F(WebNNGraphImplTest, PreluTest) { … }
// Tester struct paired by name with `ReduceTest`.
// NOTE(review): presumably describes one reduce case (kind, operands, axes,
// keep-dimensions flag, expected validity) - confirm.
struct ReduceTester { … };
// `WebNNGraphImplTest` case `ReduceTest`.
// NOTE(review): presumably runs ReduceTester over valid and invalid reduce
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ReduceTest) { … }
// Tester struct paired by name with `ReluTest`.
// NOTE(review): presumably describes one relu case (input/output operands,
// expected validity) - confirm.
struct ReluTester { … };
// `WebNNGraphImplTest` case `ReluTest`.
// NOTE(review): presumably runs ReluTester over valid and invalid relu
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ReluTest) { … }
// Tester struct paired by name with `Resample2dTest`.
// NOTE(review): presumably describes one resample2d case (operands,
// scales/sizes, axes, expected validity) - confirm.
struct Resample2dTester { … };
// `WebNNGraphImplTest` case `Resample2dTest`.
// NOTE(review): presumably runs Resample2dTester over valid and invalid
// resample2d graphs - confirm.
TEST_F(WebNNGraphImplTest, Resample2dTest) { … }
// Tester struct paired by name with `ReshapeTest`.
// NOTE(review): presumably describes one reshape case (input/output operands,
// expected validity) - confirm.
struct ReshapeTester { … };
// `WebNNGraphImplTest` case `ReshapeTest`.
// NOTE(review): presumably runs ReshapeTester over valid and invalid reshape
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ReshapeTest) { … }
// Tester struct paired by name with `SliceTest`.
// NOTE(review): presumably describes one slice case (operands, starts/sizes,
// expected validity) - confirm.
struct SliceTester { … };
// `WebNNGraphImplTest` case `SliceTest`.
// NOTE(review): presumably runs SliceTester over valid and invalid slice
// graphs - confirm.
TEST_F(WebNNGraphImplTest, SliceTest) { … }
// Scoped enum local to this test file.
// NOTE(review): presumably selects which floating-point-only unary operator
// `FloatingPointUnaryTester` builds - confirm against the enumerator list.
enum class FloatingPointUnaryKind { … };
// Tester struct paired by name with `FloatingPointUnaryTest`.
// NOTE(review): presumably describes one case (input/output operands,
// expected validity) applied per `FloatingPointUnaryKind` - confirm.
struct FloatingPointUnaryTester { … };
// `WebNNGraphImplTest` case `FloatingPointUnaryTest`.
// NOTE(review): presumably runs FloatingPointUnaryTester over valid and
// invalid graphs for each `FloatingPointUnaryKind` - confirm.
TEST_F(WebNNGraphImplTest, FloatingPointUnaryTest) { … }
// Tester struct paired by name with `SoftmaxTest`.
// NOTE(review): presumably describes one softmax case (operands, axis,
// expected validity) - confirm.
struct SoftmaxTester { … };
// `WebNNGraphImplTest` case `SoftmaxTest`.
// NOTE(review): presumably runs SoftmaxTester over valid and invalid softmax
// graphs - confirm.
TEST_F(WebNNGraphImplTest, SoftmaxTest) { … }
// Tester struct paired by name with `SoftplusTest`.
// NOTE(review): presumably describes one softplus case (input/output
// operands, expected validity) - confirm.
struct SoftplusTester { … };
// `WebNNGraphImplTest` case `SoftplusTest`.
// NOTE(review): presumably runs SoftplusTester over valid and invalid
// softplus graphs - confirm.
TEST_F(WebNNGraphImplTest, SoftplusTest) { … }
// Tester struct paired by name with `SoftsignTest`.
// NOTE(review): presumably describes one softsign case (input/output
// operands, expected validity) - confirm.
struct SoftsignTester { … };
// `WebNNGraphImplTest` case `SoftsignTest`.
// NOTE(review): presumably runs SoftsignTester over valid and invalid
// softsign graphs - confirm.
TEST_F(WebNNGraphImplTest, SoftsignTest) { … }
// Tester struct paired by name with `ValidateSplitTest`.
// NOTE(review): presumably describes one split case (input operand, output
// operands, axis, expected validity) - confirm.
struct SplitTester { … };
// `WebNNGraphImplTest` case `ValidateSplitTest`.
// NOTE(review): presumably runs SplitTester over valid and invalid split
// graphs - confirm.
TEST_F(WebNNGraphImplTest, ValidateSplitTest) { … }
// Tester struct paired by name with `TransposeTest`.
// NOTE(review): presumably describes one transpose case (operands,
// permutation, expected validity) - confirm.
struct TransposeTester { … };
// `WebNNGraphImplTest` case `TransposeTest`.
// NOTE(review): presumably runs TransposeTester over valid and invalid
// transpose graphs - confirm.
TEST_F(WebNNGraphImplTest, TransposeTest) { … }
// Tester struct paired by name with `TriangularTest`.
// NOTE(review): presumably describes one triangular case (operands,
// upper/diagonal attributes, expected validity) - confirm.
struct TriangularTester { … };
// `WebNNGraphImplTest` case `TriangularTest`.
// NOTE(review): presumably runs TriangularTester over valid and invalid
// triangular graphs - confirm.
TEST_F(WebNNGraphImplTest, TriangularTest) { … }
// Tester struct paired by name with `WhereTest`.
// NOTE(review): presumably describes one where case (condition/true/false/
// output operands, expected validity) - confirm.
struct WhereTester { … };
// `WebNNGraphImplTest` case `WhereTest`.
// NOTE(review): presumably runs WhereTester over valid and invalid where
// graphs - confirm.
TEST_F(WebNNGraphImplTest, WhereTest) { … }
// `WebNNGraphImplTest` case `ValidateInputsTest`.
// NOTE(review): presumably exercises the file-local
// ValidateInputsForComputing() helper with matching and mismatching input
// maps - confirm.
TEST_F(WebNNGraphImplTest, ValidateInputsTest) { … }
// `WebNNGraphImplTest` case `ValidateDispatchTest`.
// NOTE(review): presumably exercises the file-local ValidateDispatch() helper
// with buffers from CreateWebNNBuffer() in valid and invalid arrangements -
// confirm.
TEST_F(WebNNGraphImplTest, ValidateDispatchTest) { … }
// Tester struct paired by name with `ValidateConstantOperandTest`.
// NOTE(review): presumably describes one constant-operand case (descriptor,
// constant data bytes, expected validity) - confirm.
struct ConstantOperandTester { … };
// `WebNNGraphImplTest` case `ValidateConstantOperandTest`.
// NOTE(review): presumably runs ConstantOperandTester over valid and invalid
// constant operands - confirm.
TEST_F(WebNNGraphImplTest, ValidateConstantOperandTest) { … }
// `WebNNGraphImplTest` case `BuildMultipleInputsAppendingConstants`.
// NOTE(review): presumably builds a graph whose input operands are added
// before its constant operands and checks that the build is accepted -
// confirm.
TEST_F(WebNNGraphImplTest, BuildMultipleInputsAppendingConstants) { … }
// `WebNNGraphImplTest` case `BuildMultipleConstantsAppendingInputs`.
// NOTE(review): presumably builds a graph whose constant operands are added
// before its input operands and checks that the build is accepted - confirm.
TEST_F(WebNNGraphImplTest, BuildMultipleConstantsAppendingInputs) { … }
// `WebNNGraphImplTest` case `BuildOperationWithNonexistentInputs`.
// NOTE(review): presumably builds an operation that references operand ids
// absent from the graph and expects the build to be rejected - confirm.
TEST_F(WebNNGraphImplTest, BuildOperationWithNonexistentInputs) { … }
}