// chromium/services/webnn/webnn_graph_impl_backend_test.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <stdint.h>

#include <cmath>
#include <type_traits>

#include "base/containers/fixed_flat_set.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/test/bind.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "build/build_config.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-shared.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/fp16/src/include/fp16.h"

#if BUILDFLAG(IS_WIN)
#include "base/containers/fixed_flat_map.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/graph_impl_dml.h"
#include "services/webnn/dml/test_base.h"
#include "services/webnn/dml/utils.h"
#include "third_party/microsoft_dxheaders/include/directml.h"

// Windows SDK headers should be included after DirectX headers.
#include <wrl.h>

#endif  // BUILDFLAG(IS_WIN)

#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif  // BUILDFLAG(IS_MAC)

#if BUILDFLAG(IS_CHROMEOS)
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#endif  // BUILDFLAG(IS_CHROMEOS)

namespace webnn::test {

namespace {

// Since there is no float16 data type in C++, use uint16_t to represent the
// binary data (IEEE 754 half-precision bit pattern).
using float16 = uint16_t;

// Expected outcome of BuildAndCompute(): the graph builds and computes
// successfully, or fails while building, or fails while computing.
// kSuccess must exist: it is the default argument of BuildAndCompute().
enum class BuildAndComputeExpectation {
  kSuccess,
  kCreateGraphFailure,
  kComputeFailure,
};
// Builds the graph described by `graph_info`, runs one compute pass with
// `named_inputs`, and stores the results in `named_outputs`. `expectation`
// selects whether the build/compute is expected to succeed, and `device`
// selects the context device (GPU by default). Note the default argument
// requires `BuildAndComputeExpectation` to declare `kSuccess`.
// TODO(review): body elided in this revision — restore from upstream.
void BuildAndCompute(
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
    base::flat_map<std::string, mojo_base::BigBuffer>& named_outputs,
    BuildAndComputeExpectation expectation =
        BuildAndComputeExpectation::kSuccess,
    mojom::CreateContextOptions::Device device =
        mojom::CreateContextOptions::Device::kGpu) {}

// Copies the raw bytes of `data` into a `mojo_base::BigBuffer`.
// TODO(review): body elided in this revision — restore from upstream.
template <typename T>
mojo_base::BigBuffer VectorToBigBuffer(const std::vector<T>& data) {}

// Reinterprets the bytes of `big_buffer` as a vector of `T` (inverse of
// VectorToBigBuffer). TODO(review): body elided in this revision.
template <typename T>
std::vector<T> BigBufferToVector(mojo_base::BigBuffer big_buffer) {}

// Asserts that `data` matches `expected_data` element-wise, presumably with a
// floating-point tolerance — TODO(review): body elided; confirm against
// upstream.
void VerifyFloatDataIsEqual(base::span<const float> data,
                            base::span<const float> expected_data) {}

// Convert a vector of 32-bit floating-point data to a vector of 16-bit
// floating-point data, both in IEEE precision format.
// TODO(review): body elided in this revision — restore from upstream.
std::vector<float16> Float16FromFloat32(const std::vector<float>& fp32_data) {}

// Convert a vector of 16-bit floating-point data to a vector of 32-bit
// floating-point data, both in IEEE precision format.
// TODO(review): body elided in this revision — restore from upstream.
std::vector<float> Float16ToFloat32(const std::vector<float16>& fp16_data) {}

// Get the output data from a `mojo_base::BigBuffer` as 32-bit floating-point
// number. `type` tells how to interpret the buffer's bytes.
// TODO(review): body elided in this revision — restore from upstream.
std::vector<float> GetFloatOutputData(mojo_base::BigBuffer big_buffer,
                                      OperandDataType type) {}

// Describes an operand used by the testers below (presumably data type,
// dimensions and values — TODO(review): fields elided in this revision).
template <typename T>
struct OperandInfo {};

// Asserts that the bytes of `actual` equal `expected.values`; this float
// overload presumably compares with a floating-point tolerance — TODO(review):
// bodies elided; confirm against upstream.
void VerifyIsEqual(mojo_base::BigBuffer actual,
                   const OperandInfo<float>& expected) {}
// Generic overload for non-float operand types (exact comparison).
template <typename T>
void VerifyIsEqual(mojo_base::BigBuffer actual,
                   const OperandInfo<T>& expected) {}
}  // namespace

#if BUILDFLAG(IS_WIN)
// Backend test fixture for Windows: runs the WebNN graph tests against the
// DirectML (DML) backend. The constructor enables the WebNN feature flag for
// the fixture's lifetime; SetUp() acquires the DML adapter stored in
// `adapter_` and skips tests the adapter cannot run.
class WebNNGraphImplBackendTest : public dml::TestBase {
 public:
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  void SetUp() override;

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  // GPU adapter acquired in SetUp(); unset if the environment is unsupported.
  scoped_refptr<dml::Adapter> adapter_;
};

// Skips the current test unless a usable DML GPU adapter with the required
// DirectML feature level is available. The SKIP_TEST_IF checks are ordered:
// the adapter must be created before its capabilities can be queried.
void WebNNGraphImplBackendTest::SetUp() {
  SKIP_TEST_IF(!dml::UseGPUInTests());

  dml::Adapter::EnableDebugLayerForTesting();
  auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting();
  // If the adapter creation result has no value, it's most likely because
  // platform functions were not properly loaded.
  SKIP_TEST_IF(!adapter_creation_result.has_value());
  adapter_ = adapter_creation_result.value();
  // Graph compilation relies on IDMLDevice1::CompileGraph introduced in
  // DirectML version 1.2 or DML_FEATURE_LEVEL_2_1, so skip the tests if the
  // DirectML version doesn't support this feature.
  SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting());

  // Skip a test if the required feature level is not supported for the
  // operator being tested. The map keys must match the TEST_F names exactly.
  auto kRequiredFeatureLevels = base::MakeFixedFlatMap<std::string_view,
                                                       DML_FEATURE_LEVEL>(
      {// DML_BATCHNORMALIZATION_OPERATOR_DESC support for 1~8 dimension counts
       // was introduced in DML_FEATURE_LEVEL_3_1.
       {"FuseStandaloneActivationIntoBatchNormalization",
        DML_FEATURE_LEVEL_3_1},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"FuseStandaloneActivationIntoGemm", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildAndComputeMultipleOperatorGemm", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildOneInputAndOneConstantOperand", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildOneGraphToComputeMultipleTimes", DML_FEATURE_LEVEL_4_0},
       // DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC support for 1~8
       // dimension counts was introduced in DML_FEATURE_LEVEL_3_1.
       {"BuildSingleOperatorLayerNormalization", DML_FEATURE_LEVEL_3_1},
       // DML_GEMM_OPERATOR_DESC support for 2~4 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildAndComputeSingleOperatorMatmul", DML_FEATURE_LEVEL_4_0},
       {"FuseStandaloneOperationsIntoMatmul", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMultipleInputsAppendingConstants", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMultipleConstantsAppendingInputs", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildGemmWithReshapedConstantOperand", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPooingAsThirdOperator", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPooingAsSecondOperator", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPooingAsFirstOperator", DML_FEATURE_LEVEL_4_0}});
  // Look up the currently running test by name; tests not listed above have
  // no extra feature-level requirement.
  auto it = kRequiredFeatureLevels.find(
      ::testing::UnitTest::GetInstance()->current_test_info()->name());
  if (it != kRequiredFeatureLevels.end()) {
    const auto& required_feature_level = it->second;
    SKIP_TEST_IF(!adapter_->IsDMLFeatureLevelSupported(required_feature_level));
  }
}
#endif  // BUILDFLAG(IS_WIN)

#if BUILDFLAG(IS_MAC)
// Backend test fixture for macOS. The constructor enables the WebNN feature
// flag for the fixture's lifetime; SetUp() skips tests on unsupported OS
// versions and for operators not yet implemented on this backend.
class WebNNGraphImplBackendTest : public testing::Test {
 public:
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  void SetUp() override;

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;
};

// Skips the current test when running on a macOS version without WebNN
// support (pre-14.0.0), or when the operator under test is not in the
// allowlist of operators this backend implements so far.
void WebNNGraphImplBackendTest::SetUp() {
  if (base::mac::MacOSVersion() < 14'00'00) {
    GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
                 << base::mac::MacOSVersion();
  }

  // Allowlist of runnable tests. Keep this list sorted by the operator being
  // tested.
  static auto kSupportedTests = base::MakeFixedFlatSet<std::string_view>({
      "BuildAndComputeSingleOperatorClamp",
      "BuildAndComputeConcatWithConstants",
      "BuildAndComputeSingleOperatorRelu",
      "BuildAndComputeSingleOperatorTanh",
      "BuildAndComputeGraphWithTwoTranspose",
  });
  const std::string_view test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  if (!kSupportedTests.contains(test_name)) {
    GTEST_SKIP() << "Skipping test because the operator is not yet supported.";
  }
}
#endif  // BUILDFLAG(IS_MAC)

// TODO(crbug.com/325612086): Parameterize these tests for different backends.
#if BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_WIN)
// Backend test fixture for the TFLite backend. Mirrors the other platform
// fixtures: enables the WebNN feature flag and provides a task environment.
// NOTE(review): the guard excludes IS_WIN but not IS_MAC — if WEBNN_USE_TFLITE
// can be set on macOS this class collides with the one above; confirm.
class WebNNGraphImplBackendTest : public testing::Test {
 public:
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  // Must be declared here: the out-of-class definition below otherwise fails
  // to compile (SetUp was defined without a matching declaration).
  void SetUp() override;

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;
};

void WebNNGraphImplBackendTest::SetUp() {}
#endif  // BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_WIN)

// Describes one fusible operation to append to a graph under construction.
// TODO(review): fields elided in this revision — restore from upstream.
struct FusibleOperationDescriptor {};

// Appends the operation described by `operation` to `builder`, reading from
// `input_operand_id` and writing to `output_operand_id`.
// TODO(review): body elided in this revision — restore from upstream.
void BuildFusibleOperation(GraphInfoBuilder& builder,
                           const FusibleOperationDescriptor& operation,
                           uint64_t input_operand_id,
                           uint64_t output_operand_id) {}

// Helper for building/running batchNormalization graphs in tests.
// TODO(review): members elided in this revision — restore from upstream.
template <typename T>
struct BatchNormalizationTester {};

// Test building and computing a graph of fusing a standalone activation into
// batchNormalization automatically.
// TODO(review): test body elided in this revision — restore from upstream.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoBatchNormalization) {}

// Helper for building/running conv2d graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct Conv2dTester {};

// Test building and computing a graph of fusing a standalone activation
// into conv2d automatically.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoConv2d) {}

// I is the type of the inputs, both of which must be the same.
// O is the type of the output, which by default is the same as the input.
// Logical operators, however, have uint8_t (bool) as outputs.
// TODO(review): members elided in this revision.
template <typename I, typename O = I>
struct ElementWiseBinaryTester {};

// Test building and computing a graph of fusing a standalone activation
// into elementwise binary add automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoElementWiseBinaryAdd) {}

// NOTE(review): all TEST_F bodies below are elided in this revision — restore
// from upstream. The test names are load-bearing: the platform SetUp()
// implementations skip or gate tests by exact name.

// Test building and computing a graph in the following topology.
//         [input]
//            |
//          split
//        /       \
//   [output1]  reshape
//                 |
//             [output2]
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithSplitAndReshape) {}

// Helper for building/running single-operator unary graphs (clamp, tanh,
// hardSigmoid, ...). TODO(review): members elided in this revision.
template <typename T>
struct UnaryOperatorTester {};

// Test building and computing a graph with single operator clamp.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorClamp) {}

// Test building and computing a graph with single operator hardSigmoid.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSigmoid) {}

// Test building and computing a graph with single operator hardSwish.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSwish) {}

// Test building and computing a graph with single operator tanh.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorTanh) {}

// Test building and computing a graph with two relu operators.
//    [input]
//       |
//      relu1
//       |
//      relu2
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoRelu) {}

// Test building and computing a graph with two operators (reshape as the
// last node).
//    [input]
//       |
//      relu
//       |
//     reshape
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithReshapeAsLastNode) {}

// Test building and computing a graph with two operators (reshape as an
// intermediate node).
//    [input]
//       |
//    reshape
//       |
//      relu
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithReshapeAsIntermediateNode) {}

// Test building and computing a graph with two reshape operators
//    [input]
//       |
//    reshape1
//       |
//    reshape2
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoReshape) {}

// Test building and computing a graph with two operators and two outputs
//      [input]
//       /   \
//  reshape   relu
//     |        |
// [output1] [output2]
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoOutputs) {}

// Optional attributes of the gemm operator (alpha/beta/transpose flags
// presumably — TODO(review): fields elided in this revision).
struct GemmAttributes {};

// Helper for building/running gemm graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct GemmTester {};

// Test building and computing a graph of fusing a standalone activation
// into gemm automatically.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoGemm) {}

// Helper for building/running gru graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct GruTester {};

// Test building and computing a graph with single operator gru.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGru) {}

// TODO(https://issues.chromium.org/issues/331250158): Delete the test cases
// after the WPT conformance tests are completed.
template <typename T>
struct GruCellTester {};

// Test building and computing a graph with single operator gruCell.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGruCell) {}

// Test building and computing a graph with three gemm operations.
//    [input_a] [input_b] [input_a] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeMultipleOperatorGemm) {}

// Test building and computing a graph with one input and one constant.
TEST_F(WebNNGraphImplBackendTest, BuildOneInputAndOneConstantOperand) {}

// Test building a graph with one input and one constant to compute for
// multiple times.
TEST_F(WebNNGraphImplBackendTest, BuildOneGraphToComputeMultipleTimes) {}

// Helper for building/running instanceNormalization graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct InstanceNormalizationTester {};

// Test building and computing a graph of fusing a standalone activation into
// instanceNormalization automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoInstanceNormalization) {}

// Helper for building/running layerNormalization graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct LayerNormalizationTester {};

// Test building and computing a graph of fusing a standalone activation into
// layerNormalization automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoLayerNormalization) {}

// Test building and computing a graph with single operator
// layerNormalization.
TEST_F(WebNNGraphImplBackendTest, BuildSingleOperatorLayerNormalization) {}

// Helper for building/running lstm graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct LstmTester {};

// Test building and computing a graph with single operator lstm.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstm) {}

// Optional attributes of the lstmCell operator.
// TODO(review): fields elided in this revision.
struct LstmCellAttributes {};

// TODO(crbug.com/331250158): Remove this test after the WPT conformance tests
// are completed.
// Test building and computing a graph with single operator lstmCell.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstmCell) {}

// Helper for building/running matmul graphs in tests.
// TODO(review): members elided in this revision.
template <typename T>
struct MatmulTester {};

// Test building and computing a graph of fusing standalone operations
// into matmul when possible.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneOperationsIntoMatmul) {}

// NOTE(review): all TEST_F bodies below are elided in this revision — restore
// from upstream.

// Test building and computing a graph with two inputs and two constant in
// the following topology.
//    [input_a] [constant_a] [input_b] [constant_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildMultipleInputsAppendingConstants) {}

// Test building and computing a graph with two inputs and two constant in
// the following topology.
//    [constant_a] [input_a] [constant_b] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildMultipleConstantsAppendingInputs) {}

// Test building and computing a graph whose gemm operator takes a reshaped
// constant operand c in the following topology:
//                        [constant_c]
//                         |
//     [input_a] [input_b] reshape
//             \    |     /
//                 gemm
// This test case could reproduce the issue of ResNetV2 50 model of WebNN image
// classification sample:
// https://bugs.chromium.org/p/chromium/issues/detail?id=1509747
TEST_F(WebNNGraphImplBackendTest, BuildGemmWithReshapedConstantOperand) {}

// Test building a graph whose add operator takes a reshaped
// constant operand b in the following topology:
//              [constant_b]
//                 |
//    [input_a]  reshape
//           \    /
//            add
TEST_F(WebNNGraphImplBackendTest, BuildAddWithReshapedConstantOperand) {}

// Test building and computing a graph whose relu operator only has a
// constant operand input, as the following topology:
//    [constant]
//         |
//       relu
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReluWithOnlyConstantInput) {}

// Test building and computing a graph whose add operator only has constant
// operand inputs, as the following topology:
//    [constant_a]  [constant_b]
//               \  /
//               add
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeAddWithOnlyConstantInputs) {}

// Test building and computing a graph whose add and mul operators only have
// constant and intermediate operand inputs, as the following topology:
//    [constant_a]  [constant_b]
//               \  /
//               add    [constant_c]
//                  \  /
//                   mul
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeAddAndMulWithOnlyConstantInputs) {}

// Optional attributes of the pool2d operators.
// TODO(review): fields elided in this revision.
struct Pool2dAttributes {};

// NOTE(review): "Pooing" [sic] in the three test names below is load-bearing:
// the Windows SetUp() keys its kRequiredFeatureLevels map on these exact
// strings, so renaming the tests requires updating that map too.

// Test building a graph in the following topology.
//    [input_a] [input_b]
//           \    /
//            add
//             |
//            relu
//             |
//          max pooling
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsThirdOperator) {}

// Test building a graph in the following topology.
//    [input_a] [input_b]
//           \    /
//            add
//             |
//          max pooling
//             |
//            relu
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsSecondOperator) {}

// Test building a graph in the following topology.
//      [input_a]
//          |
//      max pooling
//                  [input_b]
//           \        /
//               add
//                |
//               relu
TEST_F(WebNNGraphImplBackendTest, BuildMaxPooingAsFirstOperator) {}

// Test building and computing a graph with float 16 data type in the
// following topology.
//     [input_a]
//         |
//      reshape    [input_b]
//          \         /
//             concat
//               |
//             clamp
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReshapeConcatAndClamp) {}

// Test building and computing a graph in the following topology.
//      [input]   [constant_a]
//          \          /
//             concat   [constant_b]
//               \           /
//                   concat
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeConcatWithConstants) {}

// Helper for building/running resample2d graphs in tests.
// TODO(review): members elided in this revision — restore from upstream.
template <typename T>
struct Resample2dTester {};

// Test building and computing a graph with single operator resample2d.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorResample2d) {}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//     transpose
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoTranspose) {}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//       relu
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTransposeAndRelu) {}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//      reshape
//         |
//      reshape
//         |
//     transpose
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithTransposeAndTwoReshape) {}

// Test building and computing a graph in the following topology.
//         [input]
//            |
//           relu
//          /    \
//     reshape    transpose
//        |           |
//    [output1]   [output2]
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithTransposeAndTwoOutputs) {}

// Test building and computing a graph which can't be automatically fused
// because the output of conv2d is used by two operations or as graph's output.
TEST_F(WebNNGraphImplBackendTest,
       MultipleOutputsCanNotFuseStandaloneActivation) {}

}  // namespace webnn::test