#include <array>
#include <string>
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "build/build_config.h"
#include "build/buildflag.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_buffer_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom.h"
#include "services/webnn/public/mojom/webnn_context.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "testing/gtest/include/gtest/gtest.h"
#if BUILDFLAG(IS_WIN)
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/test_base.h"
#endif
#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif
#if BUILDFLAG(IS_CHROMEOS)
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#endif
namespace webnn::test {
namespace {
class BadMessageTestHelper { … };
struct CreateContextSuccess { … };
struct CreateBufferSuccess { … };
#if BUILDFLAG(IS_WIN)
// Test fixture for the Windows (DirectML) backend.
//
// The constructor enables `kWebMachineLearningNeuralNetwork` for the lifetime
// of the fixture via `scoped_feature_list_`. SetUp() skips the test unless a
// suitable DirectML adapter is available, and otherwise binds
// `webnn_provider_remote_` to a provider created for testing.
//
// NOTE(review): unlike the macOS fixture, this class declares no
// `base::test::TaskEnvironment` member; presumably one is provided by
// `dml::TestBase` — confirm against services/webnn/dml/test_base.h.
class WebNNBufferImplBackendTest : public dml::TestBase {
public:
WebNNBufferImplBackendTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
// Skips the test when GPU/DirectML prerequisites are missing; otherwise binds
// `webnn_provider_remote_`.
void SetUp() override;
void TearDown() override;
protected:
// Creates a WebNN context, returning either the success payload or the
// `webnn::mojom::Error::Code` describing the failure. A single out-of-line
// definition is shared by every platform's fixture.
base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
CreateWebNNContext();
// Keeps the WebNN feature enabled for the whole test; initialized in the
// constructor above.
base::test::ScopedFeatureList scoped_feature_list_;
// DirectML adapter acquired in SetUp(); stays null if SetUp() skipped the
// test before acquiring it.
scoped_refptr<dml::Adapter> adapter_;
// Bound in SetUp() via WebNNContextProviderImpl::CreateForTesting(); left
// unbound if SetUp() skipped the test earlier.
mojo::Remote<mojom::WebNNContextProvider> webnn_provider_remote_;
};
// Prepares the DirectML backend for each test. Any missing prerequisite
// skips the test (via SKIP_TEST_IF) rather than failing it.
void WebNNBufferImplBackendTest::SetUp() {
// Only run when GPU use is enabled for tests.
SKIP_TEST_IF(!dml::UseGPUInTests());
// NOTE(review): presumably turns on the DirectML/D3D12 debug layer so that
// validation errors surface during tests — confirm in dml/adapter.h.
dml::Adapter::EnableDebugLayerForTesting();
auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting();
// If no adapter could be created, skip rather than fail.
SKIP_TEST_IF(!adapter_creation_result.has_value());
adapter_ = adapter_creation_result.value();
// The tests require the adapter to support DML device graph compilation.
SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting());
// All prerequisites are met: bind the context provider the tests talk to.
WebNNContextProviderImpl::CreateForTesting(
webnn_provider_remote_.BindNewPipeAndPassReceiver());
}
#elif BUILDFLAG(IS_MAC)
// Test fixture for the macOS backend.
//
// The constructor enables `kWebMachineLearningNeuralNetwork` for the lifetime
// of the fixture. SetUp() skips on macOS versions older than 14.0, binds
// `webnn_provider_remote_`, and then unconditionally skips, because WebNNBuffer
// is not implemented on macOS.
class WebNNBufferImplBackendTest : public testing::Test {
public:
WebNNBufferImplBackendTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
// Skips unsupported macOS versions and, currently, every test (see SetUp()).
void SetUp() override;
void TearDown() override;
protected:
// Creates a WebNN context, returning either the success payload or the
// `webnn::mojom::Error::Code` describing the failure. A single out-of-line
// definition is shared by every platform's fixture.
base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
CreateWebNNContext();
// Keeps the WebNN feature enabled for the whole test.
base::test::ScopedFeatureList scoped_feature_list_;
// Members are destroyed in reverse declaration order, so the remote below is
// torn down before this task environment. Keep this ordering.
base::test::TaskEnvironment task_environment_;
// Bound in SetUp() via WebNNContextProviderImpl::CreateForTesting(); left
// unbound if SetUp() skipped the test on an old macOS version.
mojo::Remote<mojom::WebNNContextProvider> webnn_provider_remote_;
};
// Prepares the macOS fixture.
//
// WebNN requires macOS 14.0 or newer, so older systems skip immediately.
// Otherwise the context provider is bound, and then the test is skipped
// unconditionally because WebNNBuffer has no macOS implementation yet.
void WebNNBufferImplBackendTest::SetUp() {
  // base::mac::MacOSVersion() encodes major/minor/patch as e.g. 14'00'00.
  constexpr int kMinimumSupportedMacOSVersion = 14'00'00;
  const int running_mac_os_version = base::mac::MacOSVersion();
  if (running_mac_os_version < kMinimumSupportedMacOSVersion) {
    GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
                 << running_mac_os_version;
  }

  // NOTE(review): the provider is bound even though the skip below makes every
  // macOS test a no-op; confirm TearDown() does not rely on it before removing.
  WebNNContextProviderImpl::CreateForTesting(
      webnn_provider_remote_.BindNewPipeAndPassReceiver());

  GTEST_SKIP() << "WebNNBuffer not implemented on macOS";
}
#elif BUILDFLAG(WEBNN_USE_TFLITE)
class WebNNBufferImplBackendTest : public testing::Test { … };
#endif
void WebNNBufferImplBackendTest::TearDown() { … }
base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
WebNNBufferImplBackendTest::CreateWebNNContext() { … }
base::expected<CreateBufferSuccess, webnn::mojom::Error::Code>
CreateWebNNBuffer(mojo::Remote<mojom::WebNNContext>& webnn_context_remote,
mojom::BufferInfoPtr buffer_info) { … }
bool IsBufferDataEqual(const mojo_base::BigBuffer& a,
const mojo_base::BigBuffer& b) { … }
TEST_F(WebNNBufferImplBackendTest, CreateBufferImplTest) { … }
TEST_F(WebNNBufferImplBackendTest, CreateBufferImplManyTest) { … }
TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTest) { … }
TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTooLargeTest) { … }
TEST_F(WebNNBufferImplBackendTest, CreateContextImplManyTest) { … }
}
}