// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microparams.h"

#include "pthreadpool.h"

enum xnn_parallelization_type {};

struct compute_parameters {};

struct transpose_context {};

XNN_PRIVATE void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XNN_PRIVATE void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

XNN_PRIVATE void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XNN_PRIVATE void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

// Context for Packing Weights (packw) for GEMM microkernels in Groups-OutputChannels-InputChannels layout.
// Kernel has shape GxNxK, bias has shape GxN.
struct packw_gemm_goi_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_packw_gemm_goi(
      const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t n_block_start,
      size_t n_block_size);
  XNN_PRIVATE void xnn_compute_batched_packw_gemm_goi(
      const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t n_block_start,
      size_t n_block_size);
#endif
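
// Illustrative sketch (not part of this header): the non-batched pack function
// above maps onto a 1D-tiled pthreadpool dispatch over output-channel blocks.
// The names `threadpool`, `ctx`, `output_channels` and the tile size `nr` are
// placeholders assumed for this example.
//
//   pthreadpool_parallelize_1d_tile_1d(
//       threadpool,
//       (pthreadpool_task_1d_tile_1d_t) xnn_compute_packw_gemm_goi,
//       (void*) &ctx,
//       /*range=*/output_channels, /*tile=*/nr,
//       /*flags=*/0);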

// Context for Packing Weights (packw) for GEMM microkernels in Groups-InputChannels-OutputChannels layout.
// Kernel has shape GxKxN, bias has shape GxN.
struct packw_gemm_gio_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_packw_gemm_gio(
      const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t n_block_start,
      size_t n_block_size);
  XNN_PRIVATE void xnn_compute_batched_packw_gemm_gio(
      const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t n_block_start,
      size_t n_block_size);
#endif
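
// Illustrative sketch (not part of this header): the batched pack function
// above adds a leading batch/group index, so it maps onto a 2D dispatch with a
// 1D tile over output-channel blocks. `threadpool`, `ctx`, `groups`,
// `output_channels` and `nr` are placeholders assumed for this example.
//
//   pthreadpool_parallelize_2d_tile_1d(
//       threadpool,
//       (pthreadpool_task_2d_tile_1d_t) xnn_compute_batched_packw_gemm_gio,
//       (void*) &ctx,
//       /*range_i=*/groups, /*range_j=*/output_channels,
//       /*tile_j=*/nr,
//       /*flags=*/0);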

// Context for Dense Matrix Multiplication.
// C [GxMxN] := A [GxMxK] * B[GxKxN] + bias [GxN]
// Where B and bias have been packed into packed_w.
struct gemm_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_dqgemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_dqgemm_bl(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_qp8gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_dqgemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_qp8gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
        size_t mr_block_size, size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_dqgemm_bl(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif  // XNN_MAX_UARCH_TYPES > 1
#endif
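
// Illustrative sketch (not part of this header): xnn_compute_gemm is written as
// a pthreadpool 2D-tile-2D task, so a single-group GEMM is typically dispatched
// over (M, N) tiles of the output roughly as below. `threadpool`, `ctx`, `M`,
// `N`, `mr` and `nc` are placeholders assumed for this example.
//
//   pthreadpool_parallelize_2d_tile_2d(
//       threadpool,
//       (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm,
//       (void*) &ctx,
//       /*range_i=*/M, /*range_j=*/N,
//       /*tile_i=*/mr, /*tile_j=*/nc,
//       /*flags=*/0);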

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size);
#endif
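
// Illustrative sketch (not part of this header): xnn_compute_spmm takes a batch
// index plus a 1D tile of output rows, matching a pthreadpool 2D-tile-1D
// dispatch. `threadpool`, `ctx`, `batch_size`, `output_rows` and `mr` are
// placeholders assumed for this example.
//
//   pthreadpool_parallelize_2d_tile_1d(
//       threadpool,
//       (pthreadpool_task_2d_tile_1d_t) xnn_compute_spmm,
//       (void*) &ctx,
//       /*range_i=*/batch_size, /*range_j=*/output_rows,
//       /*tile_j=*/mr,
//       /*flags=*/0);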

// Context for initializing the indirection buffer for conv2d igemm.
struct conv2d_igemm_indirection_init_context {};

// Context for Indirect Dense Matrix Multiplication.
// C [BxGxMxN] := A [BxGxMxK] * B[BxGxKxN] + bias [BxGxN]
// Where B and bias have been packed into packed_w.
struct igemm_context {};

#ifndef __cplusplus
XNN_PRIVATE void xnn_compute_grouped_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_grouped_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t group_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t group_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_dq_zero_buffer_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t size);

XNN_PRIVATE void xnn_compute_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
    size_t nr_block_size);

XNN_PRIVATE void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
    size_t nr_block_size);

XNN_PRIVATE void xnn_compute_conv2d_igemm_indirection(
    const struct conv2d_igemm_indirection_init_context
        context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_tile_start, size_t output_tile_size);

XNN_PRIVATE void xnn_compute_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

#if XNN_MAX_UARCH_TYPES > 1
XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_hmp_grouped_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_hmp_grouped_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
    size_t nr_block_size);

XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
    size_t nr_block_size);

XNN_PRIVATE void xnn_compute_hmp_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_batch_hmp_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);

XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size);
#endif  // XNN_MAX_UARCH_TYPES > 1
#endif
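
// Illustrative sketch (not part of this header): the batched, grouped variant
// above takes (batch, group) indices plus a 2D tile of the output, matching a
// pthreadpool 4D-tile-2D dispatch. `threadpool`, `ctx`, `batch_size`, `groups`,
// `output_size`, `group_output_channels`, `mr` and `nc` are placeholders
// assumed for this example.
//
//   pthreadpool_parallelize_4d_tile_2d(
//       threadpool,
//       (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm,
//       (void*) &ctx,
//       /*range_i=*/batch_size, /*range_j=*/groups,
//       /*range_k=*/output_size, /*range_l=*/group_output_channels,
//       /*tile_k=*/mr, /*tile_l=*/nc,
//       /*flags=*/0);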

struct subgemm_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size);

  XNN_PRIVATE void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size);
#endif

struct subconv_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dq_zero_buffer_subconv(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t size);

  XNN_PRIVATE void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_grouped_dqsubconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nr_block_start,
    size_t slice_x_max,
    size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_dqsubconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct conv2d_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

// Context for initializing the indirection buffer for dwconv.
struct dwconv_indirection_init_context {};

struct dwconv_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_indirection(
    const struct dwconv_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_y_start,
    size_t output_y_tile);
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
  XNN_PRIVATE void xnn_compute_dwconv_multipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
  XNN_PRIVATE void xnn_compute_dwconv_multipass_with_thread(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index,
      size_t output_y);
#endif

struct dwconv2d_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_chw(
      const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel);
#endif

struct max_pooling_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  // Workspace sized based on batch size * output height.
  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  // Workspace sized based on number of threads.
  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass_with_thread(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index,
      size_t output_y);
#endif
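
// Illustrative sketch (not part of this header): the two multipass variants
// above differ only in how their scratch workspace is indexed, so operator
// setup code can pick whichever needs the smaller allocation. The comparison
// below is an assumed example; `batch_size`, `output_height` and `num_threads`
// are placeholders.
//
//   if (batch_size * output_height <= num_threads) {
//     // One workspace slice per (batch, output row):
//     // use xnn_compute_argmax_pooling_multipass.
//   } else {
//     // One workspace slice per thread:
//     // use xnn_compute_argmax_pooling_multipass_with_thread.
//   }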

struct average_pooling_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass_with_thread(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass_with_thread(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass_with_thread(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_nhwc_indirection_init_context {};

struct resize_bilinear_context {};

struct resize_bilinear_chw_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear_indirection(
      const struct resize_bilinear_nhwc_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t output_y_start,
      size_t output_y_tile);
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif

struct elementwise_binary_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_1d_tile(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset, size_t tile);
  XNN_PRIVATE void xnn_compute_elementwise_binary_1d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i);
  XNN_PRIVATE void xnn_compute_elementwise_binary_2d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j);
  XNN_PRIVATE void xnn_compute_elementwise_binary_3d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k);
  XNN_PRIVATE void xnn_compute_elementwise_binary_4d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l);
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct channel_shuffle_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);
#endif

struct lut_strided_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct lut_contiguous_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct reduce_context {};

#ifndef __cplusplus
  // Compute contiguous reduction over the 1st, 3rd and 5th dimensions of the
  // input tensor.
  XNN_PRIVATE void xnn_compute_contiguous_reduce(
      const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t output_idx0,
      size_t output_idx1,
      size_t output_idx2,
      size_t output1_block_size,
      size_t output2_block_size);
#endif
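
// Illustrative sketch (not part of this header): xnn_compute_contiguous_reduce
// takes three output indices plus block sizes for the last two, matching a
// pthreadpool 3D-tile-2D dispatch. `threadpool`, `ctx`, the `output_dim*`
// ranges and the tile sizes are placeholders assumed for this example.
//
//   pthreadpool_parallelize_3d_tile_2d(
//       threadpool,
//       (pthreadpool_task_3d_tile_2d_t) xnn_compute_contiguous_reduce,
//       (void*) &ctx,
//       /*range_i=*/output_dim0, /*range_j=*/output_dim1,
//       /*range_k=*/output_dim2,
//       /*tile_j=*/tile1, /*tile_k=*/tile2,
//       /*flags=*/0);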

#ifndef __cplusplus
  // Compute discontiguous reduction over the 0th, 2nd and 4th dimensions of the
  // input tensor (see the shape sketch after this block).
  XNN_PRIVATE void xnn_compute_discontiguous_reduce(
      const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t output_idx0,
      size_t output_idx1,
      size_t output_idx2,
      size_t output1_block_size,
      size_t output2_block_size);
#endif
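
// Worked shape example (illustrative only): for a 6D input viewed as
// [D0, D1, D2, D3, D4, D5], the contiguous variant above reduces dimensions
// 1, 3 and 5, producing an output indexed by [D0, D2, D4], while the
// discontiguous variant reduces dimensions 0, 2 and 4, producing an output
// indexed by [D1, D3, D5]. E.g. reducing a [2, 3, 4, 5, 6, 7] input
// discontiguously yields a [3, 5, 7]-shaped output.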

struct prelu_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_size);
#endif

struct pad_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pad_5d(
      const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct slice_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_slice_1d(
      const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i);
  XNN_PRIVATE void xnn_compute_slice_2d(
      const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j);
  XNN_PRIVATE void xnn_compute_slice_3d(
      const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k);
  XNN_PRIVATE void xnn_compute_slice_4d(
      const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l);
  XNN_PRIVATE void xnn_compute_slice_5d(
      const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct f16_qd8_convert_context {};

struct f32_qd8_convert_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f16_qd8_convert(
      const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_f32_qd8_convert(
      const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_pad_qd8_params(
      const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct f32_qp8_convert_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f32_qp8_convert(
      const struct f32_qp8_convert_context
          context[restrict XNN_MIN_ELEMENTS(1)],
      size_t m_idx_start);
#endif

struct u8_softmax_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

// Function pointer type for computing a single-element reciprocal, used by
// floating_point_softmax_context to turn the sum of exponentials into a scale.
typedef void (*xnn_compute_reciprocal_fn)(
    const void* input,
    void* output);

struct floating_point_softmax_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_floating_point_softmax(
      const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct rope_context {};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_rope(
      const struct rope_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t head_index,
      size_t sequence_index);
#endif

struct attention_logits_cap {};

struct scaled_dot_product_attention_context {};

#ifndef __cplusplus
  // There are 4 variants of the scaled dot product attention compute functions,
  // along two axes:
  // 1. Micro-architecture aware (hmp) vs. not micro-architecture aware.
  // 2. Whether the workspace is sized based on the batch size or on the number
  //    of threads.
  // The workspace sizing is chosen as whichever requires the smaller
  // allocation: batch size (times query heads and query tokens) is compared to
  // the number of threads (times MR). See the selection sketch after this block.
  XNN_PRIVATE void xnn_compute_scaled_dot_product_attention(
      const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t head_index,
      size_t tokens_start,
      size_t tokens_block_size);
  XNN_PRIVATE void xnn_compute_scaled_dot_product_attention_with_thread(
      const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t thread_index,
      size_t batch_index,
      size_t head_index,
      size_t tokens_start,
      size_t tokens_block_size);
  XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention(
      const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t head_index,
      size_t tokens_start,
      size_t tokens_block_size);
  XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention_with_thread(
      const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t thread_index,
      size_t batch_index,
      size_t head_index,
      size_t tokens_start,
      size_t tokens_block_size);
#endif
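
// Illustrative sketch (not part of this header): one way the batch-indexed vs.
// thread-indexed workspace choice described above can be made. The exact
// comparison and the names `batch_size`, `query_heads`, `query_tokens`,
// `num_threads` and `mr` are assumptions for this example.
//
//   const size_t batch_sized_workspace = batch_size * query_heads * query_tokens;
//   const size_t thread_sized_workspace = num_threads * mr;
//   if (batch_sized_workspace <= thread_sized_workspace) {
//     // Workspace indexed by (batch, head, token block):
//     // use xnn_compute_scaled_dot_product_attention (or its hmp variant).
//   } else {
//     // Workspace indexed by thread:
//     // use xnn_compute_scaled_dot_product_attention_with_thread (or its hmp
//     // variant).
//   }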