#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <fp16/fp16.h>
#include "xnnpack.h"
#include "xnnpack/allocator.h"
#include "xnnpack/common.h"
#include "xnnpack/compute.h"
#include "xnnpack/config-types.h"
#include "xnnpack/config.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/params.h"
#include "pthreadpool.h"
// Shared creation path behind the public f16/f32 entry points.
// Allocates a scaled-dot-product-attention operator (NHTC layout) and stores
// the logits-cap configuration plus the microkernel configs and packed
// microparams (GEMM, softmax reduction/exp, rmax, add/mul, tanh) that the
// reshape/setup stages use later.
//
// Parameters:
//   cap_type/cap_params     - how (if at all) attention logits are capped;
//                             the tanh config below suggests a tanh-based cap
//                             is one supported mode — confirm in full body.
//   operator_type           - expected xnn_operator_type tag (f16 or f32
//                             variant) recorded on the new operator.
//   *_config                - datatype-specific microkernel configs supplied
//                             by the public wrappers.
//   *_params/*_params_size  - opaque microparams blobs with their sizes,
//                             presumably copied into the operator.
//   flags                   - operator creation flags.
//   attention_op_out        - receives the new operator handle on success.
//
// Returns an xnn_status code; on success *attention_op_out is set.
// NOTE(review): body elided in this view — description inferred from the
// signature and caller shape; verify against the full file.
static enum xnn_status create_scaled_dot_product_attention_nhtc(
enum xnn_attention_logits_cap_type cap_type,
const void* cap_params,
enum xnn_operator_type operator_type,
const struct xnn_gemm_config* gemm_config,
const struct xnn_raddstoreexpminusmax_config* raddstoreexpminusmax_config,
const struct xnn_rmax_config* rmax_config,
const struct xnn_binary_elementwise_config* vadd_config,
const struct xnn_binary_elementwise_config* vmul_config,
const struct xnn_unary_elementwise_config* vtanh_config,
const void* minmax_params,
size_t minmax_params_size,
const void* expminus_params,
size_t expminus_params_size,
const void* rmax_params,
size_t rmax_params_size,
const void* tanh_params,
size_t tanh_params_size,
uint32_t flags,
xnn_operator_t* attention_op_out)
{ … }
// Public f16 creation entry point. Presumably looks up the f16 microkernel
// configs and initialized microparams, then delegates to
// create_scaled_dot_product_attention_nhtc — body elided in this view,
// confirm against the full file.
//
//   cap_type/cap_params - logits-cap mode and its parameters.
//   flags               - operator creation flags.
//   attention_op_out    - receives the new operator handle on success.
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16(
enum xnn_attention_logits_cap_type cap_type,
const void* cap_params,
uint32_t flags,
xnn_operator_t* attention_op_out)
{ … }
// Public f32 creation entry point. Presumably looks up the f32 microkernel
// configs and initialized microparams, then delegates to
// create_scaled_dot_product_attention_nhtc — body elided in this view,
// confirm against the full file.
//
//   cap_type/cap_params - logits-cap mode and its parameters.
//   flags               - operator creation flags.
//   attention_op_out    - receives the new operator handle on success.
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32(
enum xnn_attention_logits_cap_type cap_type,
const void* cap_params,
uint32_t flags,
xnn_operator_t* attention_op_out)
{ … }
// Computes the reciprocal of a single half-precision value, with both input
// and output stored as raw IEEE fp16 bits in uint16_t (the <fp16/fp16.h>
// helpers are presumably used for the bit conversions — body elided, confirm).
// Used to precompute 1/cap so the reshape stage can multiply instead of
// divide — NOTE(review): assumed from the cap_reciprocal parameter of
// reshape_scaled_dot_product_attention_nhtc; verify in the full file.
static void compute_reciprocal_f16(
const uint16_t input[XNN_MIN_ELEMENTS(1)],
uint16_t output[XNN_MIN_ELEMENTS(1)])
{ … }
// Computes the reciprocal of a single single-precision value in place-style
// (reads input[0], writes output[0]). f32 counterpart of
// compute_reciprocal_f16; body elided in this view.
static void compute_reciprocal_f32(
const float input[XNN_MIN_ELEMENTS(1)],
float output[XNN_MIN_ELEMENTS(1)])
{ … }
// Shared reshape path behind the public f16/f32 entry points. Validates the
// operator and the attention dimensions, computes the per-invocation compute
// plan and scratch requirements, and reports them to the caller.
//
// Parameters:
//   attention_op            - operator previously created by the matching
//                             create call.
//   expected_operator_type  - guards against passing the wrong-datatype
//                             operator to this reshape.
//   batch_size, query_heads, query_tokens,
//   key_value_heads, key_value_tokens,
//   query_key_channels, value_channels
//                           - attention problem dimensions. Q is presumably
//                             [batch, query_heads, query_tokens, qk_channels]
//                             and K/V use key_value_heads/key_value_tokens —
//                             TODO confirm layout against xnnpack.h docs.
//   workspace_size/workspace_alignment
//                           - out-params: scratch buffer requirements the
//                             caller must satisfy before setup.
//   log2_element_size/element_size
//                           - element size of the datatype (redundant pair;
//                             both supplied by the wrappers).
//   compute_reciprocal      - datatype-specific helper (compute_reciprocal_f16
//                             or _f32) — presumably used to turn the cap value
//                             into a multiply; confirm in full body.
//   cap/cap_reciprocal/cap_size
//                           - logits-cap value, its reciprocal destination,
//                             and their byte size.
//   *_params/*_params_size  - microparams blobs forwarded to the kernels.
//   threadpool              - used for parallelization planning.
//
// Returns an xnn_status code; on success the out-params are populated.
// NOTE(review): body elided in this view — description inferred from the
// signature; verify against the full file.
static enum xnn_status reshape_scaled_dot_product_attention_nhtc(
xnn_operator_t attention_op,
enum xnn_operator_type expected_operator_type,
size_t batch_size,
size_t query_heads,
size_t query_tokens,
size_t key_value_heads,
size_t key_value_tokens,
size_t query_key_channels,
size_t value_channels,
size_t* workspace_size,
size_t* workspace_alignment,
size_t log2_element_size,
size_t element_size,
xnn_compute_reciprocal_fn compute_reciprocal,
void* cap,
void* cap_reciprocal,
size_t cap_size,
const void* minmax_params,
size_t minmax_params_size,
const void* expminus_params,
size_t expminus_params_size,
const void* rmax_params,
size_t rmax_params_size,
const void* tanh_params,
size_t tanh_params_size,
pthreadpool_t threadpool)
{ … }
// Public f16 reshape entry point; presumably forwards to
// reshape_scaled_dot_product_attention_nhtc with fp16 element size, the f16
// reciprocal helper, and the operator's stored f16 microparams — body elided
// in this view, confirm against the full file.
//
// NOTE(review): this variant names the second dimension `heads` while the
// f32 variant calls it `query_heads`; the public header is the authority on
// the intended name — consider unifying (rename only, no behavior change).
//
//   workspace_size/workspace_alignment - out-params for the scratch buffer
//                                        the caller must provide to setup.
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16(
xnn_operator_t attention_op,
size_t batch_size,
size_t heads,
size_t query_tokens,
size_t key_value_heads,
size_t key_value_tokens,
size_t query_key_channels,
size_t value_channels,
size_t* workspace_size,
size_t* workspace_alignment,
pthreadpool_t threadpool)
{ … }
// Public f32 reshape entry point; presumably forwards to
// reshape_scaled_dot_product_attention_nhtc with float element size, the f32
// reciprocal helper, and the operator's stored f32 microparams — body elided
// in this view, confirm against the full file.
//
//   workspace_size/workspace_alignment - out-params for the scratch buffer
//                                        the caller must provide to setup.
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32(
xnn_operator_t attention_op,
size_t batch_size,
size_t query_heads,
size_t query_tokens,
size_t key_value_heads,
size_t key_value_tokens,
size_t query_key_channels,
size_t value_channels,
size_t* workspace_size,
size_t* workspace_alignment,
pthreadpool_t threadpool)
{ … }
// Shared setup path behind the public f16/f32 entry points: binds the user
// pointers (workspace, query/key/value, per-channel scale, attention mask,
// output) into the compute plan built by the reshape stage.
//
// NOTE(review): parameters are typed `const float*` even though the f16
// wrapper passes `const void*` — presumably the pointers are treated as
// opaque byte addresses internally (element size was fixed at reshape time);
// confirm in the full body, and consider `const void*` here for clarity.
//
// Requires a prior successful reshape; returns an xnn_status code.
// Body elided in this view — description inferred from the signature.
static enum xnn_status setup_scaled_dot_product_attention_nhtc(
xnn_operator_t attention_op,
enum xnn_operator_type expected_operator_type,
void* workspace,
const float* query,
const float* key,
const float* value,
const float* scale,
const float* mask,
float* output)
{ … }
// Public f16 setup entry point; presumably casts the half-precision buffers
// and delegates to setup_scaled_dot_product_attention_nhtc with the f16
// operator-type tag — body elided in this view, confirm against the full
// file. Must be called after a successful f16 reshape, with `workspace`
// meeting the size/alignment that reshape reported.
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16(
xnn_operator_t attention_op,
void* workspace,
const void* query,
const void* key,
const void* value,
const void* scale,
const void* mask,
void* output)
{ … }
// Public f32 setup entry point; presumably delegates to
// setup_scaled_dot_product_attention_nhtc with the f32 operator-type tag —
// body elided in this view, confirm against the full file. Must be called
// after a successful f32 reshape, with `workspace` meeting the reported
// size/alignment requirements.
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32(
xnn_operator_t attention_op,
void* workspace,
const float* query,
const float* key,
const float* value,
const float* scale,
const float* mask,
float* output)
{ … }