#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdint.h>
#include <string.h>
#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/log.h"
#include "xnnpack/node-type.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/subgraph-validation.h"
#include "xnnpack/subgraph.h"
#include "pthreadpool.h"
// Creates the runtime operator backing a scaled-dot-product-attention node,
// storing the result in `opdata`. Part of the standard subgraph node vtable
// (create/reshape/setup) wired up in xnn_define_scaled_dot_product_attention.
// NOTE(review): body elided in this view — the exact operator-creation call
// and which node/value fields are consulted cannot be confirmed from here.
static enum xnn_status create_scaled_dot_product_attention_operator(
const struct xnn_node* node,
const struct xnn_value* values,
size_t num_values,
struct xnn_operator_data* opdata,
struct xnn_code_cache* code_cache,
xnn_weights_cache_t weights_cache)
{ … }
// Resizes the node's output tensor after a reshape; `old_workspace_size` is
// presumably used to detect whether the workspace grew — TODO confirm, body
// elided in this view.
static enum xnn_status resize_scaled_dot_product_attention_output_tensor(
const struct xnn_operator_data* opdata, struct xnn_value* values, size_t num_values, size_t old_workspace_size)
{ … }
// Reshapes the underlying operator to the current input tensor shapes and
// propagates the new shape to the output value (likely via
// resize_scaled_dot_product_attention_output_tensor — body elided, unverified).
// `values` is non-const because output shapes may be updated in place.
static enum xnn_status reshape_scaled_dot_product_attention_operator(
struct xnn_operator_data* opdata,
struct xnn_value* values,
size_t num_values,
pthreadpool_t threadpool)
{ … }
// Binds the node's input/output tensor data pointers to the previously
// created-and-reshaped operator so it is ready to run. NOTE(review): body
// elided in this view — which value IDs map to which operator inputs cannot
// be confirmed from here.
static enum xnn_status setup_scaled_dot_product_attention_operator(
const struct xnn_operator_data* opdata,
const struct xnn_value* values,
size_t num_values,
pthreadpool_t threadpool)
{ … }
// Validates a single input tensor ID against the subgraph (shared helper for
// the query/key/value/scale/mask inputs of the define function below).
// Returns xnn_status_success on a valid input; body elided, so the exact
// checks performed (ID range, datatype, density) are unconfirmed.
static enum xnn_status check_inputs(
xnn_subgraph_t subgraph,
uint32_t input_id)
{ … }
// Public API: defines a scaled-dot-product-attention node in `subgraph`.
//
//   cap_type/cap_params - logits-cap configuration (see xnn_attention_logits_cap_type)
//   query_id, key_id, value_id - IDs of the Q/K/V input tensors
//   scale_id  - ID of the per-head scale tensor
//   mask_id   - ID of the attention-mask tensor
//   output_id - ID of the output tensor
//   flags     - node creation flags
//
// Returns xnn_status_success on success, or an error status if validation of
// the subgraph or any tensor ID fails (presumably via check_inputs above —
// body elided in this view, so the exact validation order is unconfirmed).
enum xnn_status xnn_define_scaled_dot_product_attention(
xnn_subgraph_t subgraph,
enum xnn_attention_logits_cap_type cap_type,
const void* cap_params,
uint32_t query_id,
uint32_t key_id,
uint32_t value_id,
uint32_t scale_id,
uint32_t mask_id,
uint32_t output_id,
uint32_t flags)
{ … }