chromium/services/webnn/webnn_graph_impl_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/webnn_graph_impl.h"

#include <cmath>
#include <limits>

#include "base/containers/contains.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_buffer_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/webnn_buffer_impl.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_graph_builder_impl.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace webnn {

namespace {

// A fake WebNNGraph Mojo interface implementation that binds a pipe for
// computing graph message.
class FakeWebNNGraphImpl final : public WebNNGraphImpl {};

// A fake WebNNBuffer Mojo interface implementation that binds a pipe for
// buffer creation message.
class FakeWebNNBufferImpl final : public WebNNBufferImpl {};

// A fake WebNNContext Mojo interface implementation that binds a pipe for
// creating graph message.
class FakeWebNNContextImpl final : public WebNNContextImpl {};

// Helper class to create the FakeWebNNContext that is intended to test
// the graph validation steps and computation resources.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting {};

bool ValidateInputsForComputing(
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, mojo_base::BigBuffer> inputs) {}

struct CreateBufferSuccess {};

CreateBufferSuccess CreateWebNNBuffer(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    OperandDataType data_type,
    std::vector<uint32_t> shape) {}

mojo::Remote<mojom::WebNNContext> CreateWebNNContext(
    mojo::Remote<mojom::WebNNContextProvider>& webnn_context_provider) {}

// Converts inputs and outputs to MLBuffer then dispatches them.
bool ValidateDispatch(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, CreateBufferSuccess> inputs,
    base::flat_map<std::string, CreateBufferSuccess> outputs) {}

OperandDataType kAllOperandDataTypes[] =;

}  // namespace

class WebNNGraphImplTest : public testing::Test {};

struct OperandInfo {};

struct ArgMinMaxTester {};

TEST_F(WebNNGraphImplTest, ArgMinMaxTest) {}

struct ClampTester {};

TEST_F(WebNNGraphImplTest, ClampTest) {}

struct HardSigmoidTester {};

TEST_F(WebNNGraphImplTest, HardSigmoidTest) {}

struct BatchNormalizationTester {};

TEST_F(WebNNGraphImplTest, BatchNormalizationTest) {}

struct ConcatTester {};

TEST_F(WebNNGraphImplTest, ConcatTest) {}

struct Conv2dTester {};

TEST_F(WebNNGraphImplTest, Conv2dTest) {}

TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) {}

struct ElementWiseBinaryTester {};

TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) {}

TEST_F(WebNNGraphImplTest, ElementWiseBinaryLogicalTest) {}

struct ElementWiseUnaryTester {};

// Test the data type support for element-wise unary operators.
// The data type support is defined in the first parameter of the tuple
// as a std::pair of mojom::ElementWiseUnary::Kind and array of
// datatypes supported by the operator.
class ElementWiseUnaryDataTypeFixture
    : public testing::TestWithParam<
          std::tuple<std::pair<mojom::ElementWiseUnary::Kind,
                               std::vector<OperandDataType>>,
                     OperandDataType,
                     OperandDataType>> {};

TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) {}

TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) {}

INSTANTIATE_TEST_SUITE_P();

TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) {}

struct EluTester {};

TEST_F(WebNNGraphImplTest, EluTest) {}

struct ExpandTester {};

TEST_F(WebNNGraphImplTest, ExpandTest) {}

struct GatherAttributes {};

struct GatherTester {};

TEST_F(WebNNGraphImplTest, GatherTest) {}

struct GatherElementsTester {};

TEST_F(WebNNGraphImplTest, GatherElementsTest) {}

struct GeluTester {};

TEST_F(WebNNGraphImplTest, GeluTest) {}

struct GemmTester {};

TEST_F(WebNNGraphImplTest, GemmTest) {}

struct GruTester {};

TEST_F(WebNNGraphImplTest, GruTest) {}

struct GruCellTester {};

TEST_F(WebNNGraphImplTest, GruCellTest) {}

struct InstanceNormalizationTester {};

TEST_F(WebNNGraphImplTest, InstanceNormalizationTest) {}

struct LayerNormalizationTester {};

TEST_F(WebNNGraphImplTest, LayerNormalizationTest) {}

struct LstmTester {};

TEST_F(WebNNGraphImplTest, LstmTest) {}

struct LstmCellTester {};

TEST_F(WebNNGraphImplTest, LstmCellTest) {}

struct MatmulTester {};

TEST_F(WebNNGraphImplTest, MatmulTest) {}

struct PadTester {};

TEST_F(WebNNGraphImplTest, PadTest) {}

struct Pool2dTester {};

TEST_F(WebNNGraphImplTest, Pool2dTest) {}

struct PreluTester {};

TEST_F(WebNNGraphImplTest, PreluTest) {}

struct ReduceTester {};

TEST_F(WebNNGraphImplTest, ReduceTest) {}

struct ReluTester {};

TEST_F(WebNNGraphImplTest, ReluTest) {}

struct Resample2dTester {};

TEST_F(WebNNGraphImplTest, Resample2dTest) {}

struct ReshapeTester {};

TEST_F(WebNNGraphImplTest, ReshapeTest) {}
struct SliceTester {};

TEST_F(WebNNGraphImplTest, SliceTest) {}

enum class FloatingPointUnaryKind {};

struct FloatingPointUnaryTester {};

TEST_F(WebNNGraphImplTest, FloatingPointUnaryTest) {}

struct SoftmaxTester {};

TEST_F(WebNNGraphImplTest, SoftmaxTest) {}

struct SoftplusTester {};

TEST_F(WebNNGraphImplTest, SoftplusTest) {}

struct SoftsignTester {};

TEST_F(WebNNGraphImplTest, SoftsignTest) {}

struct SplitTester {};

TEST_F(WebNNGraphImplTest, ValidateSplitTest) {}

struct TransposeTester {};

TEST_F(WebNNGraphImplTest, TransposeTest) {}

struct TriangularTester {};

TEST_F(WebNNGraphImplTest, TriangularTest) {}

struct WhereTester {};

TEST_F(WebNNGraphImplTest, WhereTest) {}

TEST_F(WebNNGraphImplTest, ValidateInputsTest) {}

TEST_F(WebNNGraphImplTest, ValidateDispatchTest) {}

struct ConstantOperandTester {};

TEST_F(WebNNGraphImplTest, ValidateConstantOperandTest) {}

// Test building a graph with two inputs and two constant in the following
// topology.
//    [input_a] [constant_a] [input_b] [constant_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplTest, BuildMultipleInputsAppendingConstants) {}

// Test building a graph with two inputs and two constant in the following
// topology.
//    [constant_a] [input_a] [constant_b] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplTest, BuildMultipleConstantsAppendingInputs) {}

TEST_F(WebNNGraphImplTest, BuildOperationWithNonexistentInputs) {}

}  // namespace webnn