chromium/services/webnn/BUILD.gn

# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import("//build/buildflag_header.gni")
import("//mojo/public/tools/fuzzers/mojolpm.gni")
import("//services/webnn/features.gni")
import("//third_party/protobuf/proto_library.gni")
import("//third_party/tflite/features.gni")

# Generates "buildflags.h" so that C++ code can test the values of the GN
# variables `webnn_use_tflite` and `webnn_enable_tflite_profiler`.
buildflag_header("buildflags") {
  header = "buildflags.h"

  # Each entry has the form "<FLAG_NAME>=<value>"; the value is the GN
  # variable expanded into the string at gen time.
  flags = [
    "WEBNN_USE_TFLITE=${webnn_use_tflite}",
    "WEBNN_ENABLE_TFLITE_PROFILER=${webnn_enable_tflite_profiler}",
  ]
}

# Core ML graph builder sources. The target is declared on every POSIX
# platform rather than only on macOS: within this file it is consumed by
# :webnn_service when is_mac and by :webnn_graph_mojolpm_fuzzer whenever
# is_posix.
if (is_posix) {
  source_set("coreml_graph_builder") {
    sources = [
      "coreml/graph_builder_coreml.cc",
      "coreml/graph_builder_coreml.h",
    ]
    deps = [
      ":webnn_utils",
      "//base",
      "//services/webnn/public/mojom",
      "//third_party/coremltools:modelformat_proto",
      "//third_party/fp16",
    ]
  }
}

# TFLite graph builder sources. This target is not behind any platform or
# feature condition; within this file, :webnn_service depends on it only when
# webnn_use_tflite is set, while :webnn_graph_mojolpm_fuzzer depends on it
# unconditionally.
source_set("tflite_graph_builder") {
  sources = [
    "tflite/graph_builder_tflite.cc",
    "tflite/graph_builder_tflite.h",
  ]

  # Listed as public_deps, so targets that depend on this one may also include
  # headers from these dependencies.
  # NOTE(review): presumably graph_builder_tflite.h includes them - confirm.
  public_deps = [
    ":webnn_utils",
    "//base",
    "//services/webnn/public/mojom",
    "//third_party/flatbuffers",
    "//third_party/tflite:tflite_public_headers",
  ]
}

# Header-only target containing webnn_switches.h; :webnn_service depends on it.
# NOTE(review): presumably the header declares WebNN command-line switches -
# confirm against the header.
source_set("webnn_switches") {
  sources = [ "webnn_switches.h" ]
}

# The WebNN service component. The platform-independent sources and deps are
# always built; the conditionals below add one backend each:
#   * is_win            -> the dml/ backend
#   * is_mac            -> the coreml/ backend
#   * webnn_use_tflite  -> the tflite/ backend (with a ChromeOS-specific
#                          variant of the context/graph implementation)
component("webnn_service") {
  sources = [
    "error.h",
    "webnn_buffer_impl.cc",
    "webnn_buffer_impl.h",
    "webnn_context_impl.cc",
    "webnn_context_impl.h",
    "webnn_context_provider_impl.cc",
    "webnn_context_provider_impl.h",
    "webnn_graph_builder_impl.cc",
    "webnn_graph_builder_impl.h",
    "webnn_graph_impl.cc",
    "webnn_graph_impl.h",
    "webnn_object_impl.h",
  ]

  # Applied only while compiling this component's own sources.
  # NOTE(review): presumably consumed by an export macro - confirm.
  defines = [ "IS_WEBNN_SERVICE_IMPL" ]

  deps = [
    ":buildflags",
    ":webnn_switches",
    ":webnn_utils",
    "//base",
    "//gpu/command_buffer/service:gles2",
    "//gpu/config",
    "//mojo/public/cpp/bindings",
    "//services/webnn/public/mojom",
  ]

  # Windows: dml/ backend sources, plus the GL and DirectX header deps they
  # need and the dxgi system library.
  if (is_win) {
    sources += [
      "dml/adapter.cc",
      "dml/adapter.h",
      "dml/buffer_impl_dml.cc",
      "dml/buffer_impl_dml.h",
      "dml/command_queue.cc",
      "dml/command_queue.h",
      "dml/command_recorder.cc",
      "dml/command_recorder.h",
      "dml/context_impl_dml.cc",
      "dml/context_impl_dml.h",
      "dml/error.h",
      "dml/graph_builder_dml.cc",
      "dml/graph_builder_dml.h",
      "dml/graph_impl_dml.cc",
      "dml/graph_impl_dml.h",
      "dml/platform_functions.cc",
      "dml/platform_functions.h",
      "dml/tensor_desc.cc",
      "dml/tensor_desc.h",
      "dml/utils.cc",
      "dml/utils.h",
    ]

    deps += [
      "//third_party/fp16",
      "//third_party/microsoft_dxheaders:dxguids",
      "//ui/gl",
      "//ui/gl/init",
    ]

    libs = [ "dxgi.lib" ]
  }

  # macOS: coreml/ backend (Objective-C++ sources) linked against the Apple
  # frameworks listed below; the graph builder lives in its own target.
  if (is_mac) {
    sources += [
      "coreml/buffer_impl_coreml.h",
      "coreml/buffer_impl_coreml.mm",
      "coreml/context_impl_coreml.h",
      "coreml/context_impl_coreml.mm",
      "coreml/graph_impl_coreml.h",
      "coreml/graph_impl_coreml.mm",
      "coreml/utils_coreml.h",
      "coreml/utils_coreml.mm",
    ]

    frameworks = [
      "CoreFoundation.framework",
      "CoreML.framework",
      "Foundation.framework",
    ]

    deps += [ ":coreml_graph_builder" ]
  }

  # TFLite backend, enabled by the webnn_use_tflite GN variable.
  if (webnn_use_tflite) {
    sources += [
      # TODO(crbug.com/333392274): Use these headers on Mac, too.
      "queueable_resource_state.h",
      "queueable_resource_state_base.cc",
      "queueable_resource_state_base.h",
      "resource_task.cc",
      "resource_task.h",
      "tflite/buffer_content.cc",
      "tflite/buffer_content.h",
      "tflite/buffer_impl_tflite.cc",
      "tflite/buffer_impl_tflite.h",
    ]

    if (is_chromeos) {
      # ChromeOS uses the *_cros context/graph implementations, which depend
      # on the ChromeOS machine learning service client library.
      sources += [
        "tflite/context_impl_cros.cc",
        "tflite/context_impl_cros.h",
        "tflite/graph_impl_cros.cc",
        "tflite/graph_impl_cros.h",
      ]

      deps += [ "//chromeos/services/machine_learning/public/cpp" ]
    } else {
      # Everywhere else the *_tflite implementations and the op resolver are
      # compiled in, together with the TFLite builtin op resolver target.
      sources += [
        "tflite/context_impl_tflite.cc",
        "tflite/context_impl_tflite.h",
        "tflite/graph_impl_tflite.cc",
        "tflite/graph_impl_tflite.h",
        "tflite/op_resolver.cc",
        "tflite/op_resolver.h",
      ]

      deps += [ "//third_party/tflite:tflite_builtin_op_resolver" ]
    }

    # Dependencies shared by both TFLite variants above.
    deps += [
      ":tflite_graph_builder",
      "//components/ml/mojom",
      "//third_party/tflite",
      "//third_party/tflite:buildflags",
    ]

    # XNNPACK is only a dependency when TFLite itself is built with it.
    if (build_tflite_with_xnnpack) {
      deps += [ "//third_party/xnnpack" ]
    }
  }
}

# Shared WebNN helper code (webnn_utils.cc / webnn_utils.h), built as its own
# component. Within this file it is a dependency of :webnn_service, of both
# graph builder targets and of :tests.
component("webnn_utils") {
  sources = [
    "webnn_utils.cc",
    "webnn_utils.h",
  ]

  # Applied only while compiling this component's own sources.
  # NOTE(review): presumably consumed by an export macro - confirm.
  defines = [ "IS_WEBNN_UTILS_IMPL" ]

  deps = [
    "//base",
    "//services/webnn/public/mojom",
  ]
}

# Unit tests for the WebNN service. The platform-independent tests are always
# compiled; the conditionals at the bottom add backend-specific tests and the
# extra dependencies they need.
source_set("tests") {
  testonly = true

  sources = [
    "webnn_context_provider_impl_unittest.cc",
    "webnn_graph_builder_impl_unittest.cc",
    "webnn_graph_impl_unittest.cc",
    "webnn_test_utils.cc",
    "webnn_test_utils.h",
  ]

  deps = [
    ":buildflags",
    ":webnn_service",
    ":webnn_utils",
    "//base",
    "//base/test:test_support",
    "//mojo/public/cpp/bindings",
    "//mojo/public/cpp/test_support:test_utils",
    "//services/webnn/public/mojom",
    "//testing/gtest",
    "//third_party/fp16",
  ]

  # Windows: tests for the dml/ backend and the GL targets they depend on.
  if (is_win) {
    sources += [
      "dml/adapter_test.cc",
      "dml/command_queue_test.cc",
      "dml/command_recorder_test.cc",
      "dml/context_impl_dml_test.cc",
      "dml/graph_builder_dml_test.cc",
      "dml/platform_functions_test.cc",
      "dml/tensor_desc_test.cc",
      "dml/test_base.cc",
      "dml/test_base.h",
    ]

    deps += [
      "//ui/gl",
      "//ui/gl/init",
    ]
  }

  # Backend tests are built whenever at least one backend is compiled into
  # :webnn_service (TFLite, macOS or Windows).
  if (webnn_use_tflite || is_mac || is_win) {
    sources += [
      "webnn_buffer_impl_backend_test.cc",
      "webnn_graph_impl_backend_test.cc",
    ]
  }

  # ChromeOS: the machine learning service client library and its stub target.
  if (is_chromeos) {
    deps += [
      "//chromeos/services/machine_learning/public/cpp",
      "//chromeos/services/machine_learning/public/cpp:stub",
    ]
  }
}

# MojoLPM fuzzer for the WebNN graph interfaces. It links the TFLite graph
# builder on every platform and additionally the Core ML graph builder on
# POSIX platforms, where that target is defined.
mojolpm_fuzzer_test("webnn_graph_mojolpm_fuzzer") {
  # Fuzzer-specific proto describing the test case, plus the generated
  # MojoLPM protos for the WebNN mojom interfaces it builds on.
  proto_source = "webnn_graph_mojolpm_fuzzer.proto"
  proto_deps = [ "//services/webnn/public/mojom:mojom_mojolpm" ]

  sources = [ "webnn_graph_mojolpm_fuzzer.cc" ]

  deps = [
    ":tflite_graph_builder",
    ":webnn_service",
    "//base",
    "//base/test:test_support",
    "//content/test/fuzzer:mojolpm_fuzzer_support",
    "//services/webnn/public/mojom",
    "//third_party/libprotobuf-mutator",
  ]

  # :coreml_graph_builder only exists when is_posix, so the dependency must be
  # guarded by the same condition.
  if (is_posix) {
    deps += [ ":coreml_graph_builder" ]
  }
}