#include <stdint.h>
#include <cmath>
#include <type_traits>
#include "base/containers/fixed_flat_set.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/test/bind.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "build/build_config.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-shared.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/fp16/src/include/fp16.h"
#if BUILDFLAG(IS_WIN)
#include "base/containers/fixed_flat_map.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/graph_impl_dml.h"
#include "services/webnn/dml/test_base.h"
#include "services/webnn/dml/utils.h"
#include "third_party/microsoft_dxheaders/include/directml.h"
#include <wrl.h>
#endif
#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif
#if BUILDFLAG(IS_CHROMEOS)
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#endif
namespace webnn::test {
namespace {
float16;
enum class BuildAndComputeExpectation { … };
void BuildAndCompute(
mojom::GraphInfoPtr graph_info,
base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
base::flat_map<std::string, mojo_base::BigBuffer>& named_outputs,
BuildAndComputeExpectation expectation =
BuildAndComputeExpectation::kSuccess,
mojom::CreateContextOptions::Device device =
mojom::CreateContextOptions::Device::kGpu) { … }
template <typename T>
mojo_base::BigBuffer VectorToBigBuffer(const std::vector<T>& data) { … }
template <typename T>
std::vector<T> BigBufferToVector(mojo_base::BigBuffer big_buffer) { … }
void VerifyFloatDataIsEqual(base::span<const float> data,
base::span<const float> expected_data) { … }
std::vector<float16> Float16FromFloat32(const std::vector<float>& fp32_data) { … }
std::vector<float> Float16ToFloat32(const std::vector<float16>& fp16_data) { … }
std::vector<float> GetFloatOutputData(mojo_base::BigBuffer big_buffer,
OperandDataType type) { … }
template <typename T>
struct OperandInfo { … };
void VerifyIsEqual(mojo_base::BigBuffer actual,
const OperandInfo<float>& expected) { … }
template <typename T>
void VerifyIsEqual(mojo_base::BigBuffer actual,
const OperandInfo<T>& expected) { … }
}
#if BUILDFLAG(IS_WIN)
class WebNNGraphImplBackendTest : public dml::TestBase {
public:
WebNNGraphImplBackendTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
void SetUp() override;
protected:
base::test::ScopedFeatureList scoped_feature_list_;
scoped_refptr<dml::Adapter> adapter_;
};
void WebNNGraphImplBackendTest::SetUp() {
  // Every test in this fixture needs a GPU DirectML adapter that can compile
  // graphs; skip otherwise.
  SKIP_TEST_IF(!dml::UseGPUInTests());
  dml::Adapter::EnableDebugLayerForTesting();
  auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting();
  SKIP_TEST_IF(!adapter_creation_result.has_value());
  adapter_ = adapter_creation_result.value();
  SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting());

  // Minimum DirectML feature level required by specific tests, keyed by test
  // name. Tests not listed here only need the checks above. The table is a
  // compile-time constant, so it is not rebuilt for every test.
  // NOTE(review): no TEST_F named "BuildAndComputeSingleOperatorMatmul"
  // appears in this file; that entry may be stale.
  static constexpr auto kRequiredFeatureLevels =
      base::MakeFixedFlatMap<std::string_view, DML_FEATURE_LEVEL>({
          {"FuseStandaloneActivationIntoBatchNormalization",
           DML_FEATURE_LEVEL_3_1},
          {"FuseStandaloneActivationIntoGemm", DML_FEATURE_LEVEL_4_0},
          {"BuildAndComputeMultipleOperatorGemm", DML_FEATURE_LEVEL_4_0},
          {"BuildOneInputAndOneConstantOperand", DML_FEATURE_LEVEL_4_0},
          {"BuildOneGraphToComputeMultipleTimes", DML_FEATURE_LEVEL_4_0},
          {"BuildSingleOperatorLayerNormalization", DML_FEATURE_LEVEL_3_1},
          {"BuildAndComputeSingleOperatorMatmul", DML_FEATURE_LEVEL_4_0},
          {"FuseStandaloneOperationsIntoMatmul", DML_FEATURE_LEVEL_4_0},
          {"BuildMultipleInputsAppendingConstants", DML_FEATURE_LEVEL_4_0},
          {"BuildMultipleConstantsAppendingInputs", DML_FEATURE_LEVEL_4_0},
          {"BuildGemmWithReshapedConstantOperand", DML_FEATURE_LEVEL_4_0},
          {"BuildMaxPooingAsThirdOperator", DML_FEATURE_LEVEL_4_0},
          {"BuildMaxPooingAsSecondOperator", DML_FEATURE_LEVEL_4_0},
          {"BuildMaxPooingAsFirstOperator", DML_FEATURE_LEVEL_4_0}});

  const std::string_view current_test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  const auto it = kRequiredFeatureLevels.find(current_test_name);
  if (it != kRequiredFeatureLevels.end()) {
    SKIP_TEST_IF(!adapter_->IsDMLFeatureLevelSupported(it->second));
  }
}
#endif
#if BUILDFLAG(IS_MAC)
class WebNNGraphImplBackendTest : public testing::Test {
public:
WebNNGraphImplBackendTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
void SetUp() override;
protected:
base::test::ScopedFeatureList scoped_feature_list_;
base::test::TaskEnvironment task_environment_;
};
void WebNNGraphImplBackendTest::SetUp() {
  // This fixture treats macOS releases older than 14.0 as not supporting
  // WebNN.
  if (base::mac::MacOSVersion() < 14'00'00) {
    GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
                 << base::mac::MacOSVersion();
  }

  // Allow-list of test names that run on macOS; all other tests in this
  // fixture are skipped. A compile-time constant, so it cannot be mutated and
  // needs no runtime initialization.
  // NOTE(review): no TEST_F named "BuildAndComputeSingleOperatorRelu" appears
  // in this file; that entry may be stale.
  static constexpr auto kSupportedTests =
      base::MakeFixedFlatSet<std::string_view>({
          "BuildAndComputeSingleOperatorClamp",
          "BuildAndComputeConcatWithConstants",
          "BuildAndComputeSingleOperatorRelu",
          "BuildAndComputeSingleOperatorTanh",
          "BuildAndComputeGraphWithTwoTranspose",
      });

  const std::string_view current_test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  if (!kSupportedTests.contains(current_test_name)) {
    GTEST_SKIP() << "Skipping test because the operator is not yet supported.";
  }
}
#endif
#if BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_WIN)
class WebNNGraphImplBackendTest : public testing::Test { … };
void WebNNGraphImplBackendTest::SetUp() { … }
#endif
struct FusibleOperationDescriptor { … };
void BuildFusibleOperation(GraphInfoBuilder& builder,
const FusibleOperationDescriptor& operation,
uint64_t input_operand_id,
uint64_t output_operand_id) { … }
template <typename T>
struct BatchNormalizationTester { … };
TEST_F(WebNNGraphImplBackendTest,
FuseStandaloneActivationIntoBatchNormalization) { … }
template <typename T>
struct Conv2dTester { … };
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoConv2d) { … }
template <typename I, typename O = I>
struct ElementWiseBinaryTester { … };
TEST_F(WebNNGraphImplBackendTest,
FuseStandaloneActivationIntoElementWiseBinaryAdd) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithSplitAndReshape) { … }
template <typename T>
struct UnaryOperatorTester { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorClamp) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSigmoid) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSwish) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorTanh) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoRelu) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithReshapeAsLastNode) { … }
TEST_F(WebNNGraphImplBackendTest,
BuildAndComputeGraphWithReshapeAsIntermediateNode) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoReshape) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoOutputs) { … }
struct GemmAttributes { … };
template <typename T>
struct GemmTester { … };
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoGemm) { … }
template <typename T>
struct GruTester { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGru) { … }
template <typename T>
struct GruCellTester { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGruCell) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeMultipleOperatorGemm) { … }
TEST_F(WebNNGraphImplBackendTest, BuildOneInputAndOneConstantOperand) { … }
TEST_F(WebNNGraphImplBackendTest, BuildOneGraphToComputeMultipleTimes) { … }
template <typename T>
struct InstanceNormalizationTester { … };
TEST_F(WebNNGraphImplBackendTest,
FuseStandaloneActivationIntoInstanceNormalization) { … }
template <typename T>
struct LayerNormalizationTester { … };
TEST_F(WebNNGraphImplBackendTest,
FuseStandaloneActivationIntoLayerNormalization) { … }
TEST_F(WebNNGraphImplBackendTest, BuildSingleOperatorLayerNormalization) { … }
template <typename T>
struct LstmTester { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstm) { … }
struct LstmCellAttributes { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstmCell) { … }
template <typename T>
struct MatmulTester { … };
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneOperationsIntoMatmul) { … }
TEST_F(WebNNGraphImplBackendTest, BuildMultipleInputsAppendingConstants) { … }
TEST_F(WebNNGraphImplBackendTest, BuildMultipleConstantsAppendingInputs) { … }
TEST_F(WebNNGraphImplBackendTest, BuildGemmWithReshapedConstantOperand) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAddWithReshapedConstantOperand) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReluWithOnlyConstantInput) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeAddWithOnlyConstantInputs) { … }
TEST_F(WebNNGraphImplBackendTest,
BuildAndComputeAddAndMulWithOnlyConstantInputs) { … }
struct Pool2dAttributes { … };
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsThirdOperator) { … }
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsSecondOperator) { … }
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsFirstOperator) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReshapeConcatAndClamp) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeConcatWithConstants) { … }
template <typename T>
struct Resample2dTester { … };
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorResample2d) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoTranspose) { … }
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTransposeAndRelu) { … }
TEST_F(WebNNGraphImplBackendTest,
BuildAndComputeGraphWithTransposeAndTwoReshape) { … }
TEST_F(WebNNGraphImplBackendTest,
BuildAndComputeGraphWithTransposeAndTwoOutputs) { … }
TEST_F(WebNNGraphImplBackendTest,
MultipleOutputsCanNotFuseStandaloneActivation) { … }
}