chromium/services/webnn/webnn_graph_mojolpm_fuzzer.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>

#include "base/check.h"
#include "base/command_line.h"
#include "base/files/scoped_temp_dir.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ref.h"
#include "base/memory/scoped_refptr.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/allow_check_is_test_for_testing.h"
#include "base/test/task_environment.h"
#include "base/test/test_timeouts.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "content/test/fuzzer/mojolpm_fuzzer_support.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-mojolpm.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/tflite/context_impl_tflite.h"
#include "services/webnn/tflite/graph_builder_tflite.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_builder_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_graph_mojolpm_fuzzer.pb.h"
#include "third_party/libprotobuf-mutator/src/src/libfuzzer/libfuzzer_macro.h"

#if BUILDFLAG(IS_WIN)
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/graph_builder_dml.h"
#include "services/webnn/dml/graph_impl_dml.h"
#endif

// Core ML is an Apple framework; its backend is only available on macOS.
#if BUILDFLAG(IS_MAC)
#include "services/webnn/coreml/graph_builder_coreml.h"
#endif

namespace {
// Process-wide state shared by every fuzzer iteration. The constructor sets up
// Mojo, the command line, test timeouts and a mock-time TaskEnvironment, and on
// Windows tries to acquire a DirectML GPU adapter.
struct InitGlobals {
  InitGlobals() {
    // Mojo and the command line come first.
    mojo::core::Init();
    const bool command_line_ready = base::CommandLine::Init(0, nullptr);
    CHECK(command_line_ready);

    // Set up before the TaskEnvironment below is created; keep this ordering.
    TestTimeouts::Initialize();
    base::test::AllowCheckIsTestForTesting();

    // Mock time lets the driver fast-forward instead of waiting in real time.
    task_environment = std::make_unique<base::test::TaskEnvironment>(
        base::test::TaskEnvironment::MainThreadType::DEFAULT,
        base::test::TaskEnvironment::TimeSource::MOCK_TIME);

#if BUILDFLAG(IS_WIN)
    // Failure is tolerated here: `adapter` is simply left null.
    auto gpu_instance = webnn::dml::Adapter::GetGpuInstanceForTesting();
    if (gpu_instance.has_value()) {
      adapter = gpu_instance.value();
    }
#endif
  }

  // Owns the main-thread task environment used by all fuzzer tasks.
  std::unique_ptr<base::test::TaskEnvironment> task_environment;
#if BUILDFLAG(IS_WIN)
  // Null when Adapter::GetGpuInstanceForTesting() did not return a value.
  scoped_refptr<webnn::dml::Adapter> adapter;
#endif
};

// Created once at static-initialization time and deliberately never deleted,
// so the shared environment stays alive for the whole process.
InitGlobals* const init_globals = new InitGlobals();

// Returns the process-wide TaskEnvironment owned by `init_globals`.
base::test::TaskEnvironment& GetEnvironment() {
  base::test::TaskEnvironment* const environment =
      init_globals->task_environment.get();
  return *environment;
}

#if BUILDFLAG(IS_WIN)
// Returns the DirectML adapter acquired during global initialization, or null
// if none could be created.
scoped_refptr<webnn::dml::Adapter> GetAdapter() {
  const scoped_refptr<webnn::dml::Adapter>& shared_adapter =
      init_globals->adapter;
  return shared_adapter;
}
#endif

// The main-thread task runner of the shared TaskEnvironment; all fuzzer work
// is posted here.
scoped_refptr<base::SingleThreadTaskRunner> GetFuzzerTaskRunner() {
  base::test::TaskEnvironment& environment = GetEnvironment();
  return environment.GetMainThreadTaskRunner();
}

class WebnnGraphLPMFuzzer {
 public:
  explicit WebnnGraphLPMFuzzer(
      const services::fuzzing::webnn_graph::proto::Testcase& testcase)
      : testcase_(testcase) {}

  void NextAction() {
    const auto& action = testcase_->actions(action_index_);
    const auto& create_graph = action.create_graph();

    auto graph_info_ptr = webnn::mojom::GraphInfo::New();
    mojolpm::FromProto(create_graph.graph_info(), graph_info_ptr);

#if BUILDFLAG(IS_POSIX)
    auto coreml_properties =
        webnn::WebNNContextImpl::IntersectWithBaseProperties(
            webnn::coreml::GraphBuilderCoreml::GetContextProperties());
    if (webnn::WebNNGraphBuilderImpl::ValidateGraph(coreml_properties,
                                                    *graph_info_ptr)
            .has_value()) {
      // Test the Core ML graph builder.
      base::ScopedTempDir temp_dir;
      CHECK(temp_dir.CreateUniqueTempDir());
      auto coreml_graph_builder =
          webnn::coreml::GraphBuilderCoreml::CreateAndBuild(
              *graph_info_ptr, std::move(coreml_properties),
              temp_dir.GetPath());
    }
#endif

    auto tflite_properties =
        webnn::WebNNContextImpl::IntersectWithBaseProperties(
            webnn::tflite::GraphBuilderTflite::GetContextProperties());
    if (webnn::WebNNGraphBuilderImpl::ValidateGraph(tflite_properties,
                                                    *graph_info_ptr)
            .has_value()) {
      // Test the TFLite graph builder.
      auto flatbuffer = webnn::tflite::GraphBuilderTflite::CreateAndBuild(
          std::move(tflite_properties), *graph_info_ptr);
    }

#if BUILDFLAG(IS_WIN)
    CHECK(GetAdapter());
    auto dml_properties = webnn::WebNNContextImpl::IntersectWithBaseProperties(
        webnn::dml::ContextImplDml::GetProperties(
            GetAdapter()->max_supported_feature_level()));
    if (webnn::WebNNGraphBuilderImpl::ValidateGraph(dml_properties,
                                                    *graph_info_ptr)
            .has_value()) {
      // Graph compilation relies on IDMLDevice1::CompileGraph introduced in
      // DirectML version 1.2 (DML_FEATURE_LEVEL_2_1).
      CHECK(GetAdapter()->IsDMLDeviceCompileGraphSupportedForTesting());

      webnn::dml::GraphBuilderDml graph_builder(GetAdapter()->dml_device());
      std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map;
      webnn::dml::GraphImplDml::GraphBufferBindingInfo
          graph_buffer_binding_info;
      auto create_operator_result =
          webnn::dml::GraphImplDml::CreateAndBuildInternal(
              dml_properties, GetAdapter(), graph_info_ptr, graph_builder,
              constant_id_to_input_index_map, graph_buffer_binding_info);
      if (create_operator_result.has_value()) {
        auto dml_graph_builder = graph_builder.Compile(DML_EXECUTION_FLAG_NONE);
      }
    }
#endif
    ++action_index_;
  }

  bool IsFinished() { return action_index_ >= testcase_->actions_size(); }

 private:
  const raw_ref<const services::fuzzing::webnn_graph::proto::Testcase>
      testcase_;
  int action_index_ = 0;
};

// Runs one pending action of `testcase` and re-posts itself to the fuzzer
// task runner. Once every action has run, invokes `fuzzer_run_loop` instead.
void NextAction(WebnnGraphLPMFuzzer* testcase,
                base::OnceClosure fuzzer_run_loop) {
  if (testcase->IsFinished()) {
    std::move(fuzzer_run_loop).Run();
    return;
  }

  testcase->NextAction();
  base::OnceClosure continuation =
      base::BindOnce(&NextAction, base::Unretained(testcase),
                     std::move(fuzzer_run_loop));
  GetFuzzerTaskRunner()->PostTask(FROM_HERE, std::move(continuation));
}

// Runs every action of `testcase` on the fuzzer task runner and returns once
// the NextAction() chain has signalled completion.
//
// The fuzzer entry point invokes this from a task it posted while its own
// RunLoop is running, so `fuzzer_run_loop` is a nested loop. A nested loop of
// the default type does not run application tasks. kNestableTasksAllowed lets
// any task still pending after the fast-forward run instead of hanging.
void RunTestcase(WebnnGraphLPMFuzzer* testcase) {
  base::RunLoop fuzzer_run_loop(base::RunLoop::Type::kNestableTasksAllowed);
  GetFuzzerTaskRunner()->PostTask(
      FROM_HERE, base::BindOnce(NextAction, base::Unretained(testcase),
                                fuzzer_run_loop.QuitClosure()));
  // Make sure that all callbacks have completed.
  constexpr base::TimeDelta kTimeout = base::Seconds(5);
  GetEnvironment().FastForwardBy(kTimeout);
  fuzzer_run_loop.Run();
}

// libprotobuf-mutator entry point: replays one deserialized Testcase on the
// fuzzer task runner and blocks until it has been fully processed.
DEFINE_BINARY_PROTO_FUZZER(
    const services::fuzzing::webnn_graph::proto::Testcase& testcase) {
  // An empty testcase has nothing to exercise.
  if (testcase.actions_size() == 0) {
    return;
  }

  WebnnGraphLPMFuzzer fuzzer(testcase);
  base::RunLoop run_loop;
  base::OnceClosure run_testcase =
      base::BindOnce(RunTestcase, base::Unretained(&fuzzer));
  GetFuzzerTaskRunner()->PostTaskAndReply(FROM_HERE, std::move(run_testcase),
                                          run_loop.QuitClosure());
  run_loop.Run();
}

}  // namespace