// chromium/third_party/xnnpack/src/src/operator-run.c

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/compute.h"
#include "xnnpack/config-types.h"
#include "xnnpack/indirection.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/packq.h"
#include "xnnpack/quantization.h"
#include "pthreadpool.h"

void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{}

void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{}

void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{}

void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{}

void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{}

void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{}

void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{}

void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{}

void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{}

void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{}

void xnn_compute_packw_gemm_gio(
    const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t n_block_start,
    size_t n_block_size)
{}

void xnn_compute_batched_packw_gemm_gio(
    const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t n_block_start,
    size_t n_block_size)
{}

void xnn_compute_packw_gemm_goi(
    const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t n_block_start,
    size_t n_block_size)
{}

void xnn_compute_batched_packw_gemm_goi(
    const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t n_block_start,
    size_t n_block_size)
{}

void xnn_compute_hmp_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) {}

void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size) {}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_dqgemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_hmp_qp8gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size) {}

void xnn_compute_qp8gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
    size_t nr_block_size) {}

void xnn_compute_hmp_dqgemm_bl(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_dqgemm_bl(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{}

void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_dq_zero_buffer_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_dq_zero_buffer_subconv(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_grouped_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_grouped_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

void xnn_compute_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{}

// `output_tile_start` should be a multiple of igemm.mr (tile size).
void xnn_compute_conv2d_igemm_indirection(
    const struct conv2d_igemm_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_tile_start,
    size_t output_tile_size)
{}

void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_grouped_dqsubconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_dqsubconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{}

void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice)
{}

void xnn_compute_dwconv_indirection(
  const struct dwconv_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
  size_t output_y_start,
  size_t output_y_tile)
{}

void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_dwconv_multipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_dwconv_multipass_with_thread(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t thread_index,
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{}

void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_argmax_pooling_multipass_with_thread(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t thread_index,
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_average_pooling_multipass_with_thread(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t thread_index,
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_pixelwise_average_pooling_multipass_with_thread(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t thread_index,
    size_t batch_index,
    size_t output_y)
{}

void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_global_average_pooling_nwc_multipass_with_thread(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t thread_index,
    size_t batch_index)
{}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{}

void xnn_compute_resize_bilinear_indirection(
    const struct resize_bilinear_nhwc_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_y_start,
    size_t output_y_tile)
{}

void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{}

void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{}

void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{}

void xnn_compute_scaled_dot_product_attention(
  const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
  size_t batch_index,
  size_t head_index,
  size_t tokens_start,
  size_t tokens_block_size)
{}

void xnn_compute_scaled_dot_product_attention_with_thread(
  const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
  size_t thread_index,
  size_t batch_index,
  size_t head_index,
  size_t tokens_start,
  size_t tokens_block_size)
{}

void xnn_compute_slice_1d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{}

void xnn_compute_slice_2d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{}

void xnn_compute_slice_3d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{}

void xnn_compute_slice_4d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{}

void xnn_compute_slice_5d(
    const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{}

void xnn_compute_elementwise_binary_1d_tile(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{}

void xnn_compute_elementwise_binary_1d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{}

void xnn_compute_elementwise_binary_2d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{}

void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{}

void xnn_compute_elementwise_binary_4d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range)
{}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{}

void xnn_compute_contiguous_reduce(
    const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_idx0,
    size_t output_idx1,
    size_t output_idx2,
    size_t output1_block_size,
    size_t output2_block_size)
{}

void xnn_compute_discontiguous_reduce(
    const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t output_idx0,
    size_t output_idx1,
    size_t output_idx2,
    size_t output1_block_size,
    size_t output2_block_size)
{}

void xnn_compute_pad_qd8_params(
    const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_f16_qd8_convert(
    const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_f32_qd8_convert(
    const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_f32_qp8_convert(
    const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t m_idx_start) {}

void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{}

void xnn_compute_rope(
    const struct rope_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t head_index,
    size_t sequence_index)
{}

#if XNN_MAX_UARCH_TYPES > 1
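// The "hmp" (heterogeneous multi-processing) variants below take an extra
// uarch_index argument and dispatch context->ukernel.function[uarch_index],
// so each worker thread runs the microkernel variant that matches the
// microarchitecture of the core it executes on. They are compiled only when
// more than one microarchitecture is supported (XNN_MAX_UARCH_TYPES > 1).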
void xnn_compute_hmp_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size) {
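  // Compute the [mr_block_size x nr_block_size] tile of C that starts at row
  // mr_block_start and column nr_block_start, using the GEMM microkernel
  // variant selected for this thread's microarchitecture.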
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size, context->k_scaled,
      (const void*)((uintptr_t)context->a + mr_block_start * a_stride),
      a_stride,
      (const void*)((uintptr_t)context->packed_w +
                    nr_block_start * context->w_stride),
      (void*)((uintptr_t)context->c + mr_block_start * cm_stride +
              (nr_block_start << context->log2_csize)),
      cm_stride, context->cn_stride, context->fused_params);
}

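// Same as xnn_compute_hmp_gemm, but for dynamically quantized ("dq") input:
// the quantization parameters for the rows starting at mr_block_start are
// passed to the microkernel as an extra trailing argument.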
void xnn_compute_hmp_dqgemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride  = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->dq_ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params,
      (const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
}

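// IGEMM (indirect GEMM) variants read the input through an indirection buffer:
// each output row consumes context->ks pointers from indirect_a rather than
// addressing a dense A matrix. The grouped/batch variants additionally offset
// the weights, output, and input offset by per-group (g*_stride) and per-batch
// (b*_stride) strides.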
void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

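// The dq IGEMM variants additionally pass a per-batch zero buffer and the
// per-batch quantization parameters through to the microkernel.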
void xnn_compute_hmp_grouped_batch_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->dq_ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      context->zero_buffers[batch_index],
      &context->params,
      (const void*) ((uintptr_t) &context->quantization_params[batch_index]));
}

void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_hmp_grouped_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->dq_ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      context->zero_buffers[0],
      &context->params,
      (const void*) ((uintptr_t) context->quantization_params));
}

void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_hmp_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->dq_ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      context->zero_buffers[batch_index],
      &context->params,
      (const void*) ((uintptr_t) &context->quantization_params[batch_index]));
}

void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}

void xnn_compute_hmp_dqigemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->dq_ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      context->zero_buffers[0],
      &context->params,
      (const void*) ((uintptr_t) context->quantization_params));
}

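// Scaled dot-product attention for one (batch, head, token-block) tile:
//   Q_scaled = Q * scale; S = Q_scaled * K^T; optional soft cap S = tanh(S / cap) * cap;
//   S += mask; P = row-wise softmax(S); O = P * V.
// Individual steps are commented inline below.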
void xnn_compute_hmp_scaled_dot_product_attention(
  const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
  uint32_t uarch_index,
  size_t batch_index,
  size_t head_index,
  size_t tokens_start,
  size_t tokens_block_size)
{
  const size_t query_key_scaled_channels = context->query_key_scaled_channels;
  const size_t query_tile_offset =
    batch_index * context->query_batch_stride + head_index * context->query_head_stride +
    tokens_start * query_key_scaled_channels;
  const size_t key_value_tokens_scaled = context->key_value_tokens_scaled;
  const size_t key_value_tokens_start_scaled = tokens_start * key_value_tokens_scaled;
  const size_t cn_stride = context->cn_stride;
  const void* scaled_query = (void*) ((uintptr_t) context->scaled_query + query_tile_offset);
  const void* minmax_params = &context->minmax_params;

  {
    uintptr_t query = (uintptr_t) context->query + query_tile_offset;
    uintptr_t query_scaled_current = (uintptr_t) scaled_query;
    // Q_scaled = Q * Scale (along channels). Q and Q_scaled have dimensions [tokens_block_size, query_key_channels].
    size_t i = tokens_block_size;
    do {
      context->vmul_ukernel(
        /*batch=*/query_key_scaled_channels,
        /*input_x=*/(const void*) query,
        /*input_y=*/context->scale,
        /*output=*/(void*) query_scaled_current,
        /*params=*/minmax_params);
      query += query_key_scaled_channels;
      query_scaled_current += query_key_scaled_channels;
    } while (--i != 0);
  }

  const size_t logits_batch_offset =
    batch_index * context->logits_batch_stride + head_index * context->logits_head_stride;
  void* const logits =
    (void*) (((uintptr_t) context->logits_buffer + logits_batch_offset + key_value_tokens_start_scaled));

  {
    void* key = (void*) ((uintptr_t) context->key +
                         batch_index * context->key_batch_stride +
                         head_index * context->key_head_stride);
    // S = GEMM(Q_scaled, K^t). S is [tokens_block_size, key_value_tokens].
    context->gemm_ukernel.function[uarch_index](
      /*mr=*/tokens_block_size,
      /*nc=*/context->key_value_tokens,
      /*kc=*/query_key_scaled_channels,
      /*a=*/scaled_query,
      /*a_stride=*/query_key_scaled_channels,
      /*w=*/(void*) key,
      /*c=*/(void*) (uintptr_t) logits,
      /*cm_stride=*/key_value_tokens_scaled,
      /*cn_stride=*/cn_stride,
      /*params=*/minmax_params);
  }

  {
    const size_t tokens_block_size_scaled = tokens_block_size * key_value_tokens_scaled;
    struct attention_logits_cap logits_cap = context->logits_cap;
    if (logits_cap.type == xnn_attention_logits_cap_type_tanh) {
      // (Optional) S = TanH(S/Cap) * Cap. Overwrites buffer.
      context->vmulc_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input_x=*/logits,
        /*input_y=*/&logits_cap.cap_reciprocal,
        /*output=*/logits,
        /*params=*/minmax_params);
      context->vtanh_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input=*/logits,
        /*output=*/logits,
        /*params=*/&context->tanh_params);
      context->vmulc_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input_x=*/logits,
        /*input_y=*/&logits_cap.cap,
        /*output=*/logits,
        /*params=*/minmax_params);
    }

    // S = S + Mask, overwriting the logits buffer. Mask has dimensions
    // [query_tokens, key_value_tokens].
    context->vadd_ukernel(
      /*batch=*/tokens_block_size_scaled,
      /*input_x=*/logits,
      /*input_y=*/(void*) ((uintptr_t) context->mask + key_value_tokens_start_scaled),
      /*output=*/logits,
      /*params=*/minmax_params);
  }

  // P = Softmax(S). P has dimensions [tokens_block_size, key_value_tokens].
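  // Each row is normalized with the usual max-subtraction trick: first the row
  // maximum, then exp(s - rowmax) with its running sum, then scaling by 1/sum.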
  {
    void* logits_row = logits;
    size_t i = tokens_block_size;
    do {
      // Skip initialization of locals as they will be written to immediately.
      float rowmax;
      context->rmax_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input=*/logits_row,
        /*output=*/&rowmax,
        /*params=*/&context->rmax_params);

      float rowsum;
      context->raddstoreexpminusmax_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input=*/logits_row,
        /*max=*/&rowmax,
        /*output=*/logits_row,
        /*sum=*/&rowsum,
        /*params=*/&context->expminus_params);

      float rowscale;
      context->compute_reciprocal(
        /*input=*/&rowsum,
        /*output=*/&rowscale);

      context->vmulc_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input_x=*/logits_row,
        /*input_y=*/&rowscale,
        /*output=*/logits_row,
        /*params=*/minmax_params);

      logits_row = (void*) ((uintptr_t) logits_row + key_value_tokens_scaled);
    } while (--i != 0);
  }

  {
    void* value = (void*) ((uintptr_t) context->value +
                           batch_index * context->value_batch_stride +
                           head_index * context->value_head_stride);
    const size_t output_tile_offset =
      batch_index * context->output_batch_stride + head_index * context->output_head_stride +
      tokens_start * context->value_scaled_channels;
    // O = GEMM(P, V). O has dimension [tokens_block_size, value_channels].
    context->gemm_ukernel.function[uarch_index](
        /*mr=*/tokens_block_size,
        /*nc=*/context->value_channels,
        /*kc=*/key_value_tokens_scaled,
        /*a=*/logits,
        /*a_stride=*/key_value_tokens_scaled,
        /*w=*/value,
        /*c=*/(void*) ((uintptr_t) context->output + output_tile_offset),
        /*cm_stride=*/context->value_scaled_channels,
        /*cn_stride=*/cn_stride,
        /*params=*/minmax_params);
  }
}

void xnn_compute_hmp_scaled_dot_product_attention_with_thread(
  const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
  uint32_t uarch_index,
  size_t thread_index,
  size_t batch_index,
  size_t head_index,
  size_t tokens_start,
  size_t tokens_block_size)
{
  const size_t query_key_scaled_channels = context->query_key_scaled_channels;
  const size_t query_tile_offset =
    batch_index * context->query_batch_stride + head_index * context->query_head_stride +
    tokens_start * query_key_scaled_channels;
  const size_t key_value_tokens_scaled = context->key_value_tokens_scaled;
  const size_t key_value_tokens_start_scaled = tokens_start * key_value_tokens_scaled;
  const size_t cn_stride = context->cn_stride;
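  // Per-thread scratch: scaled_query and the logits buffer are addressed by
  // thread_index so that tiles processed concurrently never share scratch memory.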
  const void* scaled_query =
    (void*) ((uintptr_t) context->scaled_query + thread_index * context->scaled_query_thread_stride);
  const void* minmax_params = &context->minmax_params;

  {
    uintptr_t query = (uintptr_t) context->query + query_tile_offset;
    uintptr_t query_scaled_current = (uintptr_t) scaled_query;
    // Q_scaled = Q * Scale (along channels). Q and Q_scaled have dimensions [tokens_block_size, query_key_channels].
    size_t i = tokens_block_size;
    do {
      context->vmul_ukernel(
        /*batch=*/query_key_scaled_channels,
        /*input_x=*/(const void*) query,
        /*input_y=*/context->scale,
        /*output=*/(void*) query_scaled_current,
        /*params=*/minmax_params);
      query += query_key_scaled_channels;
      query_scaled_current += query_key_scaled_channels;
    } while (--i != 0);
  }

  void* const logits = (void*) ((uintptr_t) context->logits_buffer + thread_index * context->logits_thread_stride);

  {
    void* key = (void*) ((uintptr_t) context->key +
                         batch_index * context->key_batch_stride +
                         head_index * context->key_head_stride);
    // S = GEMM(Q_scaled, K^t). S is [tokens_block_size, key_value_tokens].
    context->gemm_ukernel.function[uarch_index](
      /*mr=*/tokens_block_size,
      /*nc=*/context->key_value_tokens,
      /*kc=*/query_key_scaled_channels,
      /*a=*/scaled_query,
      /*a_stride=*/query_key_scaled_channels,
      /*w=*/(void*) key,
      /*c=*/(void*) (uintptr_t) logits,
      /*cm_stride=*/key_value_tokens_scaled,
      /*cn_stride=*/cn_stride,
      /*params=*/minmax_params);
  }

  {
    const size_t tokens_block_size_scaled = tokens_block_size * key_value_tokens_scaled;
    struct attention_logits_cap logits_cap = context->logits_cap;
    if (logits_cap.type == xnn_attention_logits_cap_type_tanh) {
      // (Optional) S = TanH(S/Cap) * Cap. Overwrites buffer.
      context->vmulc_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input_x=*/logits,
        /*input_y=*/&logits_cap.cap_reciprocal,
        /*output=*/logits,
        /*params=*/minmax_params);
      context->vtanh_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input=*/logits,
        /*output=*/logits,
        /*params=*/&context->tanh_params);
      context->vmulc_ukernel(
        /*batch=*/tokens_block_size_scaled,
        /*input_x=*/logits,
        /*input_y=*/&logits_cap.cap,
        /*output=*/logits,
        /*params=*/minmax_params);
    }

    // S = S + Mask, overwriting the logits buffer. Mask has dimensions
    // [query_tokens, key_value_tokens].
    context->vadd_ukernel(
      /*batch=*/tokens_block_size_scaled,
      /*input_x=*/logits,
      /*input_y=*/(void*) ((uintptr_t) context->mask + key_value_tokens_start_scaled),
      /*output=*/logits,
      /*params=*/minmax_params);
  }

  // P = Softmax(S). P has dimensions [tokens_block_size, key_value_tokens].
  {
    void* logits_row = logits;
    size_t i = tokens_block_size;
    do {
      // Skip initialization of locals as they will be written to immediately.
      float rowmax;
      context->rmax_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input=*/logits_row,
        /*output=*/&rowmax,
        /*params=*/&context->rmax_params);

      float rowsum;
      context->raddstoreexpminusmax_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input=*/logits_row,
        /*max=*/&rowmax,
        /*output=*/logits_row,
        /*sum=*/&rowsum,
        /*params=*/&context->expminus_params);

      float rowscale;
      context->compute_reciprocal(
        /*input=*/&rowsum,
        /*output=*/&rowscale);

      context->vmulc_ukernel(
        /*batch=*/key_value_tokens_scaled,
        /*input_x=*/logits_row,
        /*input_y=*/&rowscale,
        /*output=*/logits_row,
        /*params=*/minmax_params);

      logits_row = (void*) ((uintptr_t) logits_row + key_value_tokens_scaled);
    } while (--i != 0);
  }

  {
    void* value = (void*) ((uintptr_t) context->value +
                           batch_index * context->value_batch_stride +
                           head_index * context->value_head_stride);
    const size_t output_tile_offset =
      batch_index * context->output_batch_stride + head_index * context->output_head_stride +
      tokens_start * context->value_scaled_channels;
    // O = GEMM(P, V). O has dimension [tokens_block_size, value_channels].
    context->gemm_ukernel.function[uarch_index](
        /*mr=*/tokens_block_size,
        /*nc=*/context->value_channels,
        /*kc=*/key_value_tokens_scaled,
        /*a=*/logits,
        /*a_stride=*/key_value_tokens_scaled,
        /*w=*/value,
        /*c=*/(void*) ((uintptr_t) context->output + output_tile_offset),
        /*cm_stride=*/context->value_scaled_channels,
        /*cn_stride=*/cn_stride,
        /*params=*/minmax_params);
  }
}
#endif  // XNN_MAX_UARCH_TYPES > 1

enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  return xnn_run_operator_with_index(op, 0, 0, threadpool);
}

enum xnn_status xnn_run_operator_with_index(
  xnn_operator_t op,
  size_t opdata_index,
  size_t operator_object_index,
  pthreadpool_t threadpool)
{}