// chromium/third_party/xnnpack/src/src/operators/scaled-dot-product-attention-nhtc.c

// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <fp16/fp16.h>
#include "xnnpack.h"
#include "xnnpack/allocator.h"
#include "xnnpack/common.h"
#include "xnnpack/compute.h"
#include "xnnpack/config-types.h"
#include "xnnpack/config.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/params.h"
#include "pthreadpool.h"

// Shared creation helper for the F16/F32 scaled-dot-product-attention
// operators. Presumably allocates the operator object and captures the
// microkernel configs and packed params used later by reshape/setup —
// TODO(review): confirm against upstream; the body appears to have been
// stripped from this copy of the file.
//
// The stripped body fell off the end of a value-returning function, which is
// undefined behavior once the caller reads the status. Until the real
// implementation is restored, return a defined error status and never hand
// back an uninitialized operator pointer.
static enum xnn_status create_scaled_dot_product_attention_nhtc(
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  enum xnn_operator_type operator_type,
  const struct xnn_gemm_config* gemm_config,
  const struct xnn_raddstoreexpminusmax_config* raddstoreexpminusmax_config,
  const struct xnn_rmax_config* rmax_config,
  const struct xnn_binary_elementwise_config* vadd_config,
  const struct xnn_binary_elementwise_config* vmul_config,
  const struct xnn_unary_elementwise_config* vtanh_config,
  const void* minmax_params,
  size_t minmax_params_size,
  const void* expminus_params,
  size_t expminus_params_size,
  const void* rmax_params,
  size_t rmax_params_size,
  const void* tanh_params,
  size_t tanh_params_size,
  uint32_t flags,
  xnn_operator_t* attention_op_out)
{
  if (attention_op_out != NULL) {
    // Defensive: make sure a caller that ignores the status does not
    // dereference an uninitialized pointer.
    *attention_op_out = NULL;
  }
  xnn_log_error(
    "failed to create %s operator: implementation missing",
    xnn_operator_type_to_string(operator_type));
  return xnn_status_uninitialized;
}

// Public F16 creation entry point. Upstream this gathers the F16 GEMM,
// softmax (raddstoreexpminusmax/rmax), vadd/vmul and tanh configs and
// delegates to create_scaled_dot_product_attention_nhtc — TODO(review):
// restore the stripped implementation.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16(
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  uint32_t flags,
  xnn_operator_t* attention_op_out)
{
  if (attention_op_out != NULL) {
    // Never leave the out-pointer uninitialized on failure.
    *attention_op_out = NULL;
  }
  xnn_log_error(
    "failed to create Scaled Dot-Product Attention (NHTC, F16) operator: implementation missing");
  return xnn_status_uninitialized;
}

// Public F32 creation entry point. Upstream this gathers the F32 microkernel
// configs and delegates to create_scaled_dot_product_attention_nhtc —
// TODO(review): restore the stripped implementation.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32(
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  uint32_t flags,
  xnn_operator_t* attention_op_out)
{
  if (attention_op_out != NULL) {
    // Never leave the out-pointer uninitialized on failure.
    *attention_op_out = NULL;
  }
  xnn_log_error(
    "failed to create Scaled Dot-Product Attention (NHTC, F32) operator: implementation missing");
  return xnn_status_uninitialized;
}

// Computes output[0] = 1 / input[0] for IEEE half-precision (binary16)
// values stored as raw uint16_t bits. The arithmetic is done in single
// precision and rounded back to half precision via the fp16 helpers already
// included by this file.
//
// The stripped body never wrote *output, so callers would have read
// uninitialized memory.
static void compute_reciprocal_f16(
    const uint16_t input[XNN_MIN_ELEMENTS(1)],
    uint16_t output[XNN_MIN_ELEMENTS(1)])
{
  const float x = fp16_ieee_to_fp32_value(*input);
  *output = fp16_ieee_from_fp32_value(1.0f / x);
}

// Computes output[0] = 1 / input[0] in single precision. Used to turn the
// attention scale into a multiplier so the hot path can multiply instead of
// divide.
//
// The stripped body never wrote *output, so callers would have read
// uninitialized memory.
static void compute_reciprocal_f32(
  const float input[XNN_MIN_ELEMENTS(1)],
  float output[XNN_MIN_ELEMENTS(1)])
{
  *output = 1.0f / *input;
}

// Shared reshape helper: presumably validates the operator type and the
// batch/head/token/channel dimensions, sizes the scratch workspace, and fills
// in the compute plan — TODO(review): confirm against upstream; the body
// appears to have been stripped from this copy of the file.
//
// The stripped body fell off the end of a value-returning function (UB when
// the caller reads the status). Return a defined error status and zero the
// workspace out-params so callers cannot act on garbage sizes.
static enum xnn_status reshape_scaled_dot_product_attention_nhtc(
  xnn_operator_t attention_op,
  enum xnn_operator_type expected_operator_type,
  size_t batch_size,
  size_t query_heads,
  size_t query_tokens,
  size_t key_value_heads,
  size_t key_value_tokens,
  size_t query_key_channels,
  size_t value_channels,
  size_t* workspace_size,
  size_t* workspace_alignment,
  size_t log2_element_size,
  size_t element_size,
  xnn_compute_reciprocal_fn compute_reciprocal,
  void* cap,
  void* cap_reciprocal,
  size_t cap_size,
  const void* minmax_params,
  size_t minmax_params_size,
  const void* expminus_params,
  size_t expminus_params_size,
  const void* rmax_params,
  size_t rmax_params_size,
  const void* tanh_params,
  size_t tanh_params_size,
  pthreadpool_t threadpool)
{
  // Defensive: do not leave the workspace query results uninitialized.
  if (workspace_size != NULL) {
    *workspace_size = 0;
  }
  if (workspace_alignment != NULL) {
    *workspace_alignment = 0;
  }
  xnn_log_error(
    "failed to reshape %s operator: implementation missing",
    xnn_operator_type_to_string(expected_operator_type));
  return xnn_status_uninitialized;
}

// Public F16 reshape entry point. Upstream this delegates to
// reshape_scaled_dot_product_attention_nhtc with the F16 element size,
// reciprocal helper and packed params — TODO(review): restore the stripped
// implementation. Kept self-contained here so this fix does not depend on the
// (also stripped) shared helper.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16(
  xnn_operator_t attention_op,
  size_t batch_size,
  size_t heads,
  size_t query_tokens,
  size_t key_value_heads,
  size_t key_value_tokens,
  size_t query_key_channels,
  size_t value_channels,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool)
{
  // Defensive: do not leave the workspace query results uninitialized.
  if (workspace_size != NULL) {
    *workspace_size = 0;
  }
  if (workspace_alignment != NULL) {
    *workspace_alignment = 0;
  }
  xnn_log_error(
    "failed to reshape Scaled Dot-Product Attention (NHTC, F16) operator: implementation missing");
  return xnn_status_uninitialized;
}

// Public F32 reshape entry point. Upstream this delegates to
// reshape_scaled_dot_product_attention_nhtc with the F32 element size,
// reciprocal helper and packed params — TODO(review): restore the stripped
// implementation. Kept self-contained here so this fix does not depend on the
// (also stripped) shared helper.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32(
  xnn_operator_t attention_op,
  size_t batch_size,
  size_t query_heads,
  size_t query_tokens,
  size_t key_value_heads,
  size_t key_value_tokens,
  size_t query_key_channels,
  size_t value_channels,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool)
{
  // Defensive: do not leave the workspace query results uninitialized.
  if (workspace_size != NULL) {
    *workspace_size = 0;
  }
  if (workspace_alignment != NULL) {
    *workspace_alignment = 0;
  }
  xnn_log_error(
    "failed to reshape Scaled Dot-Product Attention (NHTC, F32) operator: implementation missing");
  return xnn_status_uninitialized;
}

// Shared setup helper: presumably validates the operator type/state and wires
// the query/key/value/scale/mask/output pointers and workspace into the
// compute plan built by reshape — TODO(review): confirm against upstream; the
// body appears to have been stripped from this copy of the file.
//
// The stripped body fell off the end of a value-returning function (UB when
// the caller reads the status); return a defined error status instead.
static enum xnn_status setup_scaled_dot_product_attention_nhtc(
  xnn_operator_t attention_op,
  enum xnn_operator_type expected_operator_type,
  void* workspace,
  const float* query,
  const float* key,
  const float* value,
  const float* scale,
  const float* mask,
  float* output)
{
  xnn_log_error(
    "failed to setup %s operator: implementation missing",
    xnn_operator_type_to_string(expected_operator_type));
  return xnn_status_uninitialized;
}

// Public F16 setup entry point. Upstream this delegates to
// setup_scaled_dot_product_attention_nhtc with the F16 operator type —
// TODO(review): restore the stripped implementation. Kept self-contained here
// so this fix does not depend on the (also stripped) shared helper.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16(
  xnn_operator_t attention_op,
  void* workspace,
  const void* query,
  const void* key,
  const void* value,
  const void* scale,
  const void* mask,
  void* output)
{
  xnn_log_error(
    "failed to setup Scaled Dot-Product Attention (NHTC, F16) operator: implementation missing");
  return xnn_status_uninitialized;
}

// Public F32 setup entry point. Upstream this delegates to
// setup_scaled_dot_product_attention_nhtc with the F32 operator type —
// TODO(review): restore the stripped implementation. Kept self-contained here
// so this fix does not depend on the (also stripped) shared helper.
//
// The stripped body invoked undefined behavior by falling off the end of a
// value-returning function; return a defined error status instead.
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32(
  xnn_operator_t attention_op,
  void* workspace,
  const float* query,
  const float* key,
  const float* value,
  const float* scale,
  const float* mask,
  float* output)
{
  xnn_log_error(
    "failed to setup Scaled Dot-Product Attention (NHTC, F32) operator: implementation missing");
  return xnn_status_uninitialized;
}