# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under both the MIT license found in the
# LICENSE-MIT file in the root directory of this source tree and the Apache
# License, Version 2.0 found in the LICENSE-APACHE file in the root directory
# of this source tree.
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
load(
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
"get_vec_android_preprocessor_flags",
)
def op_target(name, deps = []):
"""Registers an optimized implementation for an operator overload group.
An operator overload group is a set of operator overloads with a common
operator name. That common operator name should be the base name of this
target.
E.g., the "add" operator overload group, named "op_add" in this target,
might implement:
- add.Tensor
- add_.Tensor
- add.out
- add.Scalar
If an op target would like to share a header/sources with a different op
target (e.g., helpers/utilities), it should declare a separate cxx_library
and add it as a dep.
Args:
name: The name of the operator overload group; e.g.,
"op_add". This directory must contain a source file named
"<name>.cpp"; e.g., "op_add.cpp".
deps: Optional extra deps to add to the cxx_library(). Note:
- op targets may not depend on other op targets, to keep the
dependencies manageable. If two op targets would like to share
code, define a separate runtime.cxx_library that they both depend
on.
"""
# Note that this doesn't actually define the target, but helps register
# it in a table that's used to define the target.
return {
"deps": deps,
"name": name,
}
def _enforce_deps(deps, name):
"""Fails if any of the deps are not allowed.
Args:
deps: A list of build target strings.
name: The name of the target; e.g., "op_add"
"""
for dep in deps:
if dep.startswith(":op_"):
# op targets may not depend on other op targets, to keep the
# dependencies manageable. If two op targets would like to share
# code, define a separate runtime.cxx_library that they both depend
# on.
fail("op_target {} may not depend on other op_target {}".format(
name,
dep,
))
def define_op_library(name, deps):
"""Defines a cxx_library target for the named operator overload group.
Args:
name: The name of the target; e.g., "op_add"
deps: List of deps for the target.
"""
selects.apply(obj = deps, function = native.partial(_enforce_deps, name = name))
augmented_deps = deps + [
"//executorch/kernels/optimized:libvec",
"//executorch/kernels/optimized:libutils",
]
runtime.cxx_library(
name = "{}".format(name),
srcs = [
"{}.cpp".format(name),
],
visibility = [
"//executorch/kernels/portable/test/...",
"//executorch/kernels/quantized/test/...",
"//executorch/kernels/optimized/test/...",
"//executorch/kernels/test/...",
"@EXECUTORCH_CLIENTS",
],
# kernels often have helpers with no prototypes just disabling the warning here as the headers
# are codegend and linked in later
compiler_flags = ["-Wno-missing-prototypes"],
deps = [
"//executorch/runtime/kernel:kernel_includes",
] + augmented_deps,
fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(),
# sleef needs to be added as a direct dependency of the operator target when building for Android,
# or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of
# dependencies are not transitive
fbandroid_platform_deps = [
(
"^android-arm64.*$",
[
"fbsource//third-party/sleef:sleef_arm",
],
),
],
# link_whole is necessary because the operators register themselves
# via static initializers that run at program startup.
# @lint-ignore BUCKLINT link_whole
link_whole = True,
)
def define_op_target(name, deps):
"""Possibly defines cxx_library targets for the named operator group.
Args:
name: The base name of the target; e.g., "op_add"
deps: List of deps for the targets.
"""
# When building in ATen mode, ATen-compatible (non-custom) operators will
# use the implementations provided by ATen, so we should not build the
# versions defined here.
define_op_library(
name = name,
deps = deps,
)
def is_op_disabled(name):
# TODO (gjcomer) Enable ops with sleef dependency in OSS
disabled_ops = ["op_gelu", "op_log_softmax"]
return name in disabled_ops