#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/compute.h"
#include "xnnpack/config-types.h"
#include "xnnpack/indirection.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/packq.h"
#include "xnnpack/quantization.h"
#include "pthreadpool.h"
void xnn_compute_transposec_2d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t tile_i,
size_t tile_j)
{ … }
void xnn_compute_transposec_3d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t tile_j,
size_t tile_k)
{ … }
void xnn_compute_transposec_4d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t tile_k,
size_t tile_l)
{ … }
void xnn_compute_transposec_5d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t m,
size_t tile_l,
size_t tile_m)
{ … }
void xnn_compute_transposec_6d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t m,
size_t n,
size_t tile_m,
size_t tile_n)
{ … }
void xnn_compute_transposev_2d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t tile_i,
size_t tile_j)
{ … }
void xnn_compute_transposev_3d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t tile_j,
size_t tile_k)
{ … }
void xnn_compute_transposev_4d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t tile_k,
size_t tile_l)
{ … }
void xnn_compute_transposev_5d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t m,
size_t tile_l,
size_t tile_m)
{ … }
void xnn_compute_transposev_6d(
const struct transpose_context* context,
size_t i,
size_t j,
size_t k,
size_t l,
size_t m,
size_t n,
size_t tile_m,
size_t tile_n)
{ … }
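// Weight-packing compute functions: repack GEMM weights for output channels
// [n_block_start, n_block_start + n_block_size) into the layout expected by
// the GEMM microkernels. The _gio variants read weights stored as
// groups x input channels x output channels, the _goi variants as
// groups x output channels x input channels; the batched_ variants
// additionally index one matrix of a weights batch.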
void xnn_compute_packw_gemm_gio(
const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t n_block_start,
size_t n_block_size)
{ … }
void xnn_compute_batched_packw_gemm_gio(
const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t n_block_start,
size_t n_block_size)
{ … }
void xnn_compute_packw_gemm_goi(
const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t n_block_start,
size_t n_block_size)
{ … }
void xnn_compute_batched_packw_gemm_goi(
const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t n_block_start,
size_t n_block_size)
{ … }
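// GEMM compute functions: each invocation runs a GEMM microkernel on one
// output tile of mr_block_size rows starting at mr_block_start and
// nr_block_size columns starting at nr_block_start. grouped_ variants add a
// group dimension, dq variants pass per-row dynamic-quantization parameters
// to the microkernel, qp8 variants consume activations packed by the packq
// routines, and hmp_ variants pick the microkernel matching the calling
// core's microarchitecture index.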
void xnn_compute_hmp_grouped_gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index, size_t group_index, size_t mr_block_start,
size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) { … }
void xnn_compute_grouped_gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t group_index, size_t mr_block_start, size_t nr_block_start,
size_t mr_block_size, size_t nr_block_size) { … }
void xnn_compute_gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_dqgemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_hmp_qp8gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
size_t mr_block_size, size_t nr_block_size) { … }
void xnn_compute_qp8gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
size_t nr_block_size) { … }
void xnn_compute_hmp_dqgemm_bl(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_dqgemm_bl(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
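// Sparse-times-dense matrix multiplication (SpMM) for one batch element,
// covering a block of mr_block_size output rows starting at mr_block_start.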
void xnn_compute_spmm(
const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t mr_block_start,
size_t mr_block_size)
{ … }
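// Indirect GEMM (IGEMM) compute functions used by convolution: the A matrix
// is addressed through the indirection pointer buffer context->indirect_a,
// and a_offset shifts the indirected pointers to the batch/group being
// processed. Variant naming follows the GEMM functions above.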
void xnn_compute_grouped_batch_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
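// Fill the per-batch zero (padding) buffers consumed by the dynamically
// quantized IGEMM and subconv kernels before the matching compute functions
// run.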
void xnn_compute_dq_zero_buffer_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index
) { … }
void xnn_compute_dq_zero_buffer_subconv(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index
) { … }
void xnn_compute_grouped_batch_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_grouped_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_grouped_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_batch_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_batch_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{ … }
void xnn_compute_conv2d_igemm_indirection(
const struct conv2d_igemm_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t output_tile_start,
size_t output_tile_size)
{ … }
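// Sub-GEMM and sub-convolution compute functions used by deconvolution
// (transposed convolution): the output is partitioned into subkernels by
// output-pixel phase, and each invocation processes one slice of one
// subkernel.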
void xnn_compute_grouped_subgemm2d(
const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t group_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
void xnn_compute_subgemm2d(
const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
void xnn_compute_grouped_subconv2d(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t group_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
void xnn_compute_grouped_dqsubconv2d(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t group_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
void xnn_compute_subconv2d(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
void xnn_compute_dqsubconv2d(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{ … }
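// Direct convolution (HWC input to CHW output) and depthwise convolution
// compute functions. The dwconv unipass variant covers the whole kernel in a
// single microkernel pass; the multipass variants accumulate across passes in
// a scratch buffer, with the _with_thread form indexing that buffer per
// thread.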
void xnn_compute_conv2d_hwc2chw(
const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y_start,
size_t output_y_slice)
{ … }
void xnn_compute_dwconv_indirection(
const struct dwconv_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t output_y_start,
size_t output_y_tile)
{ … }
void xnn_compute_dwconv_unipass(
const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_dwconv_multipass(
const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_dwconv_multipass_with_thread(
const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_dwconv2d_chw(
const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t channel)
{ … }
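// Pooling compute functions: argmax, max, (pixelwise) average, and global
// average pooling, plus max-unpooling. unipass variants handle pooling
// windows that fit in one microkernel pass; multipass variants accumulate
// across passes in a scratch buffer, and the _with_thread forms index that
// buffer per thread.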
void xnn_compute_argmax_pooling_unipass(
const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_argmax_pooling_multipass(
const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_argmax_pooling_multipass_with_thread(
const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_max_pooling(
const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_unpooling(
const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t input_y,
size_t input_x)
{ … }
void xnn_compute_average_pooling_unipass(
const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_average_pooling_multipass(
const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_average_pooling_multipass_with_thread(
const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_pixelwise_average_pooling_unipass(
const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_pixelwise_average_pooling_multipass(
const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_pixelwise_average_pooling_multipass_with_thread(
const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index,
size_t output_y)
{ … }
void xnn_compute_global_average_pooling_nwc_unipass(
const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_global_average_pooling_nwc_multipass(
const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_global_average_pooling_nwc_multipass_with_thread(
const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index)
{ … }
void xnn_compute_global_average_pooling_ncw(
const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t channels_start,
size_t channels_slice)
{ … }
void xnn_compute_resize_bilinear_indirection(
const struct resize_bilinear_nhwc_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t output_y_start,
size_t output_y_tile)
{ … }
void xnn_compute_resize_bilinear(
const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t pixel_start,
size_t pixel_range)
{ … }
void xnn_compute_resize_bilinear_chw(
const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t channel_start,
size_t channel_range)
{ … }
void xnn_compute_prelu(
const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_start,
size_t batch_range)
{ … }
void xnn_compute_pad_5d(
const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l, size_t m)
{ … }
void xnn_compute_scaled_dot_product_attention(
const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t head_index,
size_t tokens_start,
size_t tokens_block_size)
{ … }
void xnn_compute_scaled_dot_product_attention_with_thread(
const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t thread_index,
size_t batch_index,
size_t head_index,
size_t tokens_start,
size_t tokens_block_size)
{ … }
void xnn_compute_slice_1d(
const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i)
{ … }
void xnn_compute_slice_2d(
const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j)
{ … }
void xnn_compute_slice_3d(
const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k)
{ … }
void xnn_compute_slice_4d(
const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l)
{ … }
void xnn_compute_slice_5d(
const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l, size_t m)
{ … }
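// Elementwise binary compute functions: apply a binary microkernel over up to
// five loop dimensions, with broadcasting expressed through the per-dimension
// strides stored in the context; the _1d_tile variant processes a contiguous
// [offset, offset + size) range.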
void xnn_compute_elementwise_binary_1d_tile(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t offset,
size_t size)
{ … }
void xnn_compute_elementwise_binary_1d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i)
{ … }
void xnn_compute_elementwise_binary_2d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j)
{ … }
void xnn_compute_elementwise_binary_3d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k)
{ … }
void xnn_compute_elementwise_binary_4d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l)
{ … }
void xnn_compute_elementwise_binary_5d(
const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t i, size_t j, size_t k, size_t l, size_t m)
{ … }
void xnn_compute_channel_shuffle_fixed(
const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t index)
{ … }
void xnn_compute_channel_shuffle_variable(
const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t index)
{ … }
void xnn_compute_lut_strided(
const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_lut_contiguous(
const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t offset,
size_t size)
{ … }
void xnn_compute_univector_strided(
const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t batch_range)
{ … }
void xnn_compute_univector_contiguous(
const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t offset,
size_t size)
{ … }
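// Reduction compute functions: contiguous_reduce reduces along innermost
// (contiguous) axes and discontiguous_reduce along non-innermost axes;
// output_idx0..2 locate the output block and output1/2_block_size give its
// extents in the two inner output dimensions.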
void xnn_compute_contiguous_reduce(
const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t output_idx0,
size_t output_idx1,
size_t output_idx2,
size_t output1_block_size,
size_t output2_block_size)
{ … }
void xnn_compute_discontiguous_reduce(
const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t output_idx0,
size_t output_idx1,
size_t output_idx2,
size_t output1_block_size,
size_t output2_block_size)
{ … }
void xnn_compute_pad_qd8_params(
const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_f16_qd8_convert(
const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_f32_qd8_convert(
const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_f32_qp8_convert(
const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t m_idx_start) { … }
void xnn_compute_u8_softmax(
const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_floating_point_softmax(
const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index)
{ … }
void xnn_compute_vmulcaddc(
const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_start,
size_t batch_size)
{ … }
void xnn_compute_rope(
const struct rope_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t head_index,
size_t sequence_index)
{ … }
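// The compute functions below exist only when more than one microarchitecture
// is supported (heterogeneous multi-processing, e.g. big.LITTLE SoCs):
// pthreadpool passes the calling thread's uarch_index, which selects the
// matching entry from the per-uarch microkernel function table.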
#if XNN_MAX_UARCH_TYPES > 1
void xnn_compute_hmp_gemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
size_t mr_block_size, size_t nr_block_size) {
const size_t a_stride = context->a_stride;
const size_t cm_stride = context->cm_stride;
context->ukernel.function[uarch_index](
mr_block_size, nr_block_size, context->k_scaled,
(const void*)((uintptr_t)context->a + mr_block_start * a_stride),
a_stride,
(const void*)((uintptr_t)context->packed_w +
nr_block_start * context->w_stride),
(void*)((uintptr_t)context->c + mr_block_start * cm_stride +
(nr_block_start << context->log2_csize)),
cm_stride, context->cn_stride, context->fused_params);
}
void xnn_compute_hmp_dqgemm(
const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t a_stride = context->a_stride;
const size_t cm_stride = context->cm_stride;
context->dq_ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->k_scaled,
(const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
a_stride,
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->fused_params,
(const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
}
void xnn_compute_hmp_grouped_batch_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t batch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
context->zero,
&context->params);
}
void xnn_compute_hmp_grouped_batch_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t batch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->dq_ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
context->zero,
context->zero_buffers[batch_index],
&context->params,
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
}
void xnn_compute_hmp_grouped_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + group_index * context->ga_stride,
context->zero,
&context->params);
}
void xnn_compute_hmp_grouped_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t group_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->dq_ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + group_index * context->ga_stride,
context->zero,
context->zero_buffers[0],
&context->params,
(const void*) ((uintptr_t) context->quantization_params));
}
void xnn_compute_batch_hmp_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t batch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + batch_index * context->ba_stride,
context->zero,
&context->params);
}
void xnn_compute_batch_hmp_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t batch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->dq_ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset + batch_index * context->ba_stride,
context->zero,
context->zero_buffers[batch_index],
&context->params,
(const void*) ((uintptr_t) &context->quantization_params[batch_index]));
}
void xnn_compute_hmp_igemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset,
context->zero,
&context->params);
}
void xnn_compute_hmp_dqigemm(
const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t mr_block_start,
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size)
{
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
context->dq_ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->kc,
context->ks_scaled,
(const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->a_offset,
context->zero,
context->zero_buffers[0],
&context->params,
(const void*) ((uintptr_t) context->quantization_params));
}
void xnn_compute_hmp_scaled_dot_product_attention(
const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t batch_index,
size_t head_index,
size_t tokens_start,
size_t tokens_block_size)
{
const size_t query_key_scaled_channels = context->query_key_scaled_channels;
const size_t query_tile_offset =
batch_index * context->query_batch_stride + head_index * context->query_head_stride +
tokens_start * query_key_scaled_channels;
const size_t key_value_tokens_scaled = context->key_value_tokens_scaled;
const size_t key_value_tokens_start_scaled = tokens_start * key_value_tokens_scaled;
const size_t cn_stride = context->cn_stride;
const void* scaled_query = (void*) ((uintptr_t) context->scaled_query + query_tile_offset);
const void* minmax_params = &context->minmax_params;
{
uintptr_t query = (uintptr_t) context->query + query_tile_offset;
uintptr_t query_scaled_current = (uintptr_t) scaled_query;
size_t i = tokens_block_size;
do {
context->vmul_ukernel(
query_key_scaled_channels,
(const void*) query,
context->scale,
(void*) query_scaled_current,
minmax_params);
query += query_key_scaled_channels;
query_scaled_current += query_key_scaled_channels;
} while (--i != 0);
}
const size_t logits_batch_offset =
batch_index * context->logits_batch_stride + head_index * context->logits_head_stride;
void* const logits =
(void*) (((uintptr_t) context->logits_buffer + logits_batch_offset + key_value_tokens_start_scaled));
{
void* key = (void*) ((uintptr_t) context->key +
batch_index * context->key_batch_stride +
head_index * context->key_head_stride);
context->gemm_ukernel.function[uarch_index](
tokens_block_size,
context->key_value_tokens,
query_key_scaled_channels,
scaled_query,
query_key_scaled_channels,
(void*) key,
(void*) (uintptr_t) logits,
key_value_tokens_scaled,
cn_stride,
minmax_params);
}
{
const size_t tokens_block_size_scaled = tokens_block_size * key_value_tokens_scaled;
struct attention_logits_cap logits_cap = context->logits_cap;
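    // Soft-cap the logits: logits = cap * tanh(logits / cap), applied as a
    // multiply by 1/cap, a tanh, then a multiply by cap.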
if (logits_cap.type == xnn_attention_logits_cap_type_tanh) {
context->vmulc_ukernel(
tokens_block_size_scaled,
logits,
&logits_cap.cap_reciprocal,
logits,
minmax_params);
context->vtanh_ukernel(
tokens_block_size_scaled,
logits,
logits,
&context->tanh_params);
context->vmulc_ukernel(
tokens_block_size_scaled,
logits,
&logits_cap.cap,
logits,
minmax_params);
}
context->vadd_ukernel(
tokens_block_size_scaled,
logits,
(void*) ((uintptr_t) context->mask + key_value_tokens_start_scaled),
logits,
minmax_params);
}
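  // Row-wise softmax: for each query row, find the row maximum, compute
  // exp(x - max) in place while accumulating the row sum, then scale the row
  // by the reciprocal of that sum.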
{
void* logits_row = logits;
size_t i = tokens_block_size;
do {
float rowmax;
context->rmax_ukernel(
key_value_tokens_scaled,
logits_row,
&rowmax,
&context->rmax_params);
float rowsum;
context->raddstoreexpminusmax_ukernel(
key_value_tokens_scaled,
logits_row,
&rowmax,
logits_row,
&rowsum,
&context->expminus_params);
float rowscale;
context->compute_reciprocal(
&rowsum,
&rowscale);
context->vmulc_ukernel(
key_value_tokens_scaled,
logits_row,
&rowscale,
logits_row,
minmax_params);
logits_row = (void*) ((uintptr_t) logits_row + key_value_tokens_scaled);
} while (--i != 0);
}
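  // Multiply the normalized attention weights with V and write the output
  // tile.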
{
void* value = (void*) ((uintptr_t) context->value +
batch_index * context->value_batch_stride +
head_index * context->value_head_stride);
const size_t output_tile_offset =
batch_index * context->output_batch_stride + head_index * context->output_head_stride +
tokens_start * context->value_scaled_channels;
context->gemm_ukernel.function[uarch_index](
tokens_block_size,
context->value_channels,
key_value_tokens_scaled,
logits,
key_value_tokens_scaled,
value,
(void*) ((uintptr_t) context->output + output_tile_offset),
context->value_scaled_channels,
cn_stride,
minmax_params);
}
}
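// Identical to the function above, except that the scaled-query and logits
// scratch buffers are indexed by thread_index instead of by the tile's
// position in the output.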
void xnn_compute_hmp_scaled_dot_product_attention_with_thread(
const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)],
uint32_t uarch_index,
size_t thread_index,
size_t batch_index,
size_t head_index,
size_t tokens_start,
size_t tokens_block_size)
{
const size_t query_key_scaled_channels = context->query_key_scaled_channels;
const size_t query_tile_offset =
batch_index * context->query_batch_stride + head_index * context->query_head_stride +
tokens_start * query_key_scaled_channels;
const size_t key_value_tokens_scaled = context->key_value_tokens_scaled;
const size_t key_value_tokens_start_scaled = tokens_start * key_value_tokens_scaled;
const size_t cn_stride = context->cn_stride;
const void* scaled_query =
(void*) ((uintptr_t) context->scaled_query + thread_index * context->scaled_query_thread_stride);
const void* minmax_params = &context->minmax_params;
{
uintptr_t query = (uintptr_t) context->query + query_tile_offset;
uintptr_t query_scaled_current = (uintptr_t) scaled_query;
size_t i = tokens_block_size;
do {
context->vmul_ukernel(
query_key_scaled_channels,
(const void*) query,
context->scale,
(void*) query_scaled_current,
minmax_params);
query += query_key_scaled_channels;
query_scaled_current += query_key_scaled_channels;
} while (--i != 0);
}
void* const logits = (void*) ((uintptr_t) context->logits_buffer + thread_index * context->logits_thread_stride);
{
void* key = (void*) ((uintptr_t) context->key +
batch_index * context->key_batch_stride +
head_index * context->key_head_stride);
context->gemm_ukernel.function[uarch_index](
tokens_block_size,
context->key_value_tokens,
query_key_scaled_channels,
scaled_query,
query_key_scaled_channels,
(void*) key,
(void*) (uintptr_t) logits,
key_value_tokens_scaled,
cn_stride,
minmax_params);
}
{
const size_t tokens_block_size_scaled = tokens_block_size * key_value_tokens_scaled;
struct attention_logits_cap logits_cap = context->logits_cap;
if (logits_cap.type == xnn_attention_logits_cap_type_tanh) {
context->vmulc_ukernel(
tokens_block_size_scaled,
logits,
&logits_cap.cap_reciprocal,
logits,
minmax_params);
context->vtanh_ukernel(
tokens_block_size_scaled,
logits,
logits,
&context->tanh_params);
context->vmulc_ukernel(
tokens_block_size_scaled,
logits,
&logits_cap.cap,
logits,
minmax_params);
}
context->vadd_ukernel(
tokens_block_size_scaled,
logits,
(void*) ((uintptr_t) context->mask + key_value_tokens_start_scaled),
logits,
minmax_params);
}
{
void* logits_row = logits;
size_t i = tokens_block_size;
do {
float rowmax;
context->rmax_ukernel(
key_value_tokens_scaled,
logits_row,
&rowmax,
&context->rmax_params);
float rowsum;
context->raddstoreexpminusmax_ukernel(
key_value_tokens_scaled,
logits_row,
&rowmax,
logits_row,
&rowsum,
&context->expminus_params);
float rowscale;
context->compute_reciprocal(
&rowsum,
&rowscale);
context->vmulc_ukernel(
key_value_tokens_scaled,
logits_row,
&rowscale,
logits_row,
minmax_params);
logits_row = (void*) ((uintptr_t) logits_row + key_value_tokens_scaled);
} while (--i != 0);
}
{
void* value = (void*) ((uintptr_t) context->value +
batch_index * context->value_batch_stride +
head_index * context->value_head_stride);
const size_t output_tile_offset =
batch_index * context->output_batch_stride + head_index * context->output_head_stride +
tokens_start * context->value_scaled_channels;
context->gemm_ukernel.function[uarch_index](
tokens_block_size,
context->value_channels,
key_value_tokens_scaled,
logits,
key_value_tokens_scaled,
value,
(void*) ((uintptr_t) context->output + output_tile_offset),
context->value_scaled_channels,
cn_stride,
minmax_params);
}
}
#endif
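// Runs all pending compute invocations of an operator on the given
// threadpool. The elided body walks the operator's compute descriptors and
// hands each one to the matching pthreadpool_parallelize_* entry point
// together with the operator context and one of the compute functions above.
// A minimal sketch for a 2D-tiled compute (the compute->range / compute->tile
// field names are illustrative, not the exact struct layout):
//
//   pthreadpool_parallelize_2d_tile_2d(
//       threadpool,
//       (pthreadpool_task_2d_tile_2d_t) compute->task_2d_tile_2d,
//       &op->context,
//       compute->range[0], compute->range[1],
//       compute->tile[0], compute->tile[1],
//       flags);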
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{ … }
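// Variant used by the subgraph runtime that additionally identifies the
// operator by its opdata_index and operator_object_index within the runtime.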
enum xnn_status xnn_run_operator_with_index(
xnn_operator_t op,
size_t opdata_index,
size_t operator_object_index,
pthreadpool_t threadpool)
{ … }