#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
#include <stdint.h>
#include <algorithm>
#include <type_traits>
#include <vector>
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
namespace tflite {
namespace cpu_backend_gemm {
namespace detail {
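// CustomGemvImpl is the customization point for fast GEMV kernels. Each
// specialization (see the NEON-guarded ones below) provides:
//   - kKernelRows: number of destination rows handled per kernel iteration.
//   - IsSupportedGivenSufficientlyManyRows: cheap support check on shapes.
//   - Run: the single-threaded kernel computing rows [row_start, row_end).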
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct CustomGemvImpl { … };
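// Wraps a CustomGemvImpl::Run call over a row range as a threadpool task,
// so that CustomGemv can split the destination rows across worker threads.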
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
class CustomGemvTask : public cpu_backend_threadpool::Task { … };
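// Entry point of this file. Attempts to run the multiplication through one
// of the CustomGemvImpl specializations, splitting rows across the context's
// threadpool when that is worthwhile, and returns true if it did so. A false
// return means the caller must fall back to the general GEMM path.
//
// Hedged sketch of the expected call pattern (the real dispatch lives in the
// cpu_backend_gemm implementation; variable names are illustrative only):
//
//   if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
//                          dst_params, dst_data, params, context)) {
//     return;  // Handled by a fast GEMV kernel.
//   }
//   // ... otherwise continue with the general GEMM implementation.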
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
bool CustomGemv(
const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
CpuBackendContext* context) { … }
#ifdef USE_NEON
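// Load-and-widen helpers: load 16 (resp. 8) 8-bit values, widen them to
// int16 and subtract the quantization zero point. The uint8 and int8
// overloads differ only in how the values are widened before the subtraction.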
inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src,
std::uint8_t zero_point) {
uint8x16_t src_u8 = vld1q_u8(src);
int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8)));
int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8)));
int16x8x2_t result;
int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
return result;
}
inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src,
std::int8_t zero_point) {
int8x16_t src_s8 = vld1q_s8(src);
int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8));
int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8));
int16x8x2_t result;
int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
return result;
}
inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src,
std::uint8_t zero_point) {
uint8x8_t src_u8 = vld1_u8(src);
int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8));
int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
return vsubq_s16(src_s16, zero_point_vec);
}
inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src,
std::int8_t zero_point) {
int8x8_t src_s8 = vld1_s8(src);
int16x8_t src_s16 = vmovl_s8(src_s8);
int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
return vsubq_s16(src_s16, zero_point_vec);
}
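// Narrow four int32 accumulators to the destination type with saturation,
// clamp to [clamp_min, clamp_max], and store 4 values at dst.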
inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min,
std::uint8_t clamp_max, std::uint8_t* dst) {
const int16x4_t res16 = vqmovn_s32(src);
uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
res8 = vmax_u8(res8, vdup_n_u8(clamp_min));
res8 = vmin_u8(res8, vdup_n_u8(clamp_max));
vst1_lane_u8(dst + 0, res8, 0);
vst1_lane_u8(dst + 1, res8, 1);
vst1_lane_u8(dst + 2, res8, 2);
vst1_lane_u8(dst + 3, res8, 3);
}
inline void ClampAndStore(int32x4_t src, std::int8_t clamp_min,
std::int8_t clamp_max, std::int8_t* dst) {
const int16x4_t res16 = vqmovn_s32(src);
int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
res8 = vmax_s8(res8, vdup_n_s8(clamp_min));
res8 = vmin_s8(res8, vdup_n_s8(clamp_max));
vst1_lane_s8(dst + 0, res8, 0);
vst1_lane_s8(dst + 1, res8, 1);
vst1_lane_s8(dst + 2, res8, 2);
vst1_lane_s8(dst + 3, res8, 3);
}
inline void ClampAndStore(int32x4_t src, std::int16_t clamp_min,
std::int16_t clamp_max, std::int16_t* dst) {
int16x4_t res16 = vqmovn_s32(src);
res16 = vmax_s16(res16, vdup_n_s16(clamp_min));
res16 = vmin_s16(res16, vdup_n_s16(clamp_max));
vst1_lane_s16(dst + 0, res16, 0);
vst1_lane_s16(dst + 1, res16, 1);
vst1_lane_s16(dst + 2, res16, 2);
vst1_lane_s16(dst + 3, res16, 3);
}
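// Quantized GEMV kernel: 8-bit LHS/RHS (uint8 or int8), int32 accumulation,
// 8-bit or 16-bit destination, with either a uniform or a per-row
// (per-channel) quantized multiplier. Processes kKernelRows = 4 destination
// rows per iteration and requires a depth (lhs_params.cols) of at least 8.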
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
quantization_flavor> {
static_assert(std::is_same<LhsScalar, std::uint8_t>::value ||
std::is_same<LhsScalar, std::int8_t>::value,
"");
static_assert(std::is_same<RhsScalar, std::uint8_t>::value ||
std::is_same<RhsScalar, std::int8_t>::value,
"");
static_assert(std::is_same<DstScalar, std::uint8_t>::value ||
std::is_same<DstScalar, std::int8_t>::value ||
std::is_same<DstScalar, std::int16_t>::value,
"");
static_assert(quantization_flavor ==
QuantizationFlavor::kIntegerWithUniformMultiplier ||
quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier,
"");
static constexpr int kKernelRows = 4;
static bool IsSupportedGivenSufficientlyManyRows(
const MatrixParams<LhsScalar>& lhs_params,
const MatrixParams<RhsScalar>& rhs_params,
const MatrixParams<DstScalar>& dst_params,
const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) {
return lhs_params.cols >= 8;
}
static void Run(
const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
int row_start, int row_end) {
TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
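// Process kKernelRows rows per iteration. Since row_end - row_start is only
// guaranteed to be >= kKernelRows, the last iteration may wind row back so
// that the final block overlaps the previous one; recomputing a few rows is
// harmless.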
for (int row = row_start; row < row_end; row += kKernelRows) {
row = std::min(row, row_end - kKernelRows);
const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols;
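// Prefetch the whole RHS vector (the GEMV input) into L1: it is reused by
// every row of the LHS.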
static constexpr int kCacheLineSize = 64;
for (int k = 0; k < rhs_params.rows;
k += kCacheLineSize / sizeof(RhsScalar)) {
optimized_ops_preload_l1_keep(rhs_data + k);
}
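// Distance, in bytes, ahead of the current filter read position at which the
// streaming prefetches below are issued.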
static constexpr int kPreloadAhead = 256;
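// Four int32x4 accumulators, one per destination row in this block.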
int32x4_t acc0 = vdupq_n_s32(0);
int32x4_t acc1 = acc0;
int32x4_t acc2 = acc0;
int32x4_t acc3 = acc0;
int in = 0;
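// Main loop: consume 16 levels of depth (LHS columns) per iteration, for all
// 4 rows at once.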
for (; in <= lhs_params.cols - 16; in += 16) {
const LhsScalar* local_filter_ptr = filter_ptr;
int16x8x2_t input_val =
Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
int16x8x2_t filter_val_0 =
Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(LhsScalar));
local_filter_ptr += lhs_params.cols;
int16x8x2_t filter_val_1 =
Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(LhsScalar));
local_filter_ptr += lhs_params.cols;
int16x8x2_t filter_val_2 =
Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(LhsScalar));
local_filter_ptr += lhs_params.cols;
int16x8x2_t filter_val_3 =
Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(LhsScalar));
filter_ptr += 16;
acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]),
vget_low_s16(input_val.val[0]));
acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[0]),
vget_low_s16(input_val.val[0]));
acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[0]),
vget_low_s16(input_val.val[0]));
acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[0]),
vget_low_s16(input_val.val[0]));
acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[1]),
vget_low_s16(input_val.val[1]));
acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[1]),
vget_low_s16(input_val.val[1]));
acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[1]),
vget_low_s16(input_val.val[1]));
acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[1]),
vget_low_s16(input_val.val[1]));
acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[0]),
vget_high_s16(input_val.val[0]));
acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[0]),
vget_high_s16(input_val.val[0]));
acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[0]),
vget_high_s16(input_val.val[0]));
acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[0]),
vget_high_s16(input_val.val[0]));
acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[1]),
vget_high_s16(input_val.val[1]));
acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[1]),
vget_high_s16(input_val.val[1]));
acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[1]),
vget_high_s16(input_val.val[1]));
acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]),
vget_high_s16(input_val.val[1]));
}
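// Handle 8 more levels of depth, if at least 8 remain.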
if (in <= lhs_params.cols - 8) {
int16x8_t input_val =
Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
int16x8_t filter_val_0 = Load8AndSubtractZeroPoint(
filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
int16x8_t filter_val_1 = Load8AndSubtractZeroPoint(
filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
int16x8_t filter_val_2 = Load8AndSubtractZeroPoint(
filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
int16x8_t filter_val_3 = Load8AndSubtractZeroPoint(
filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
filter_ptr += 8;
acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
vget_low_s16(input_val));
acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
vget_low_s16(input_val));
acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
vget_low_s16(input_val));
acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
vget_low_s16(input_val));
acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
vget_high_s16(input_val));
acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
vget_high_s16(input_val));
acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
vget_high_s16(input_val));
acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
vget_high_s16(input_val));
in += 8;
}
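// Handle the final levels of depth, fewer than 8. Rather than switching to
// scalar code, reload the last 8 entries (overlapping data that was already
// consumed) and zero the overlapping input lanes so the already-consumed
// entries contribute nothing to the accumulators.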
if (in < lhs_params.cols) {
const int back = in + 8 - lhs_params.cols;
TFLITE_DCHECK_GE(back, 1);
TFLITE_DCHECK_LE(back, 7);
int16x8_t input_val = Load8AndSubtractZeroPoint(
rhs_data + lhs_params.cols - 8, rhs_params.zero_point);
const LhsScalar* local_filter_ptr = filter_ptr - back;
filter_ptr += lhs_params.cols - in;
int16x8_t filter_val_0 =
Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
local_filter_ptr += lhs_params.cols;
int16x8_t filter_val_1 =
Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
local_filter_ptr += lhs_params.cols;
int16x8_t filter_val_2 =
Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
local_filter_ptr += lhs_params.cols;
int16x8_t filter_val_3 =
Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
switch (back) {
case 7:
input_val = vsetq_lane_s16(0, input_val, 6);
[[clang::fallthrough]];
case 6:
input_val = vsetq_lane_s16(0, input_val, 5);
[[clang::fallthrough]];
case 5:
input_val = vsetq_lane_s16(0, input_val, 4);
[[clang::fallthrough]];
case 4:
input_val = vsetq_lane_s16(0, input_val, 3);
[[clang::fallthrough]];
case 3:
input_val = vsetq_lane_s16(0, input_val, 2);
[[clang::fallthrough]];
case 2:
input_val = vsetq_lane_s16(0, input_val, 1);
[[clang::fallthrough]];
default:
input_val = vsetq_lane_s16(0, input_val, 0);
}
acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
vget_low_s16(input_val));
acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
vget_low_s16(input_val));
acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
vget_low_s16(input_val));
acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
vget_low_s16(input_val));
acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
vget_high_s16(input_val));
acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
vget_high_s16(input_val));
acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
vget_high_s16(input_val));
acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
vget_high_s16(input_val));
}
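// Horizontally reduce each of the 4 accumulators to a single int32 row sum,
// and combine the 4 sums into one int32x4.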
int32x2_t pairwise_reduced_acc_0 =
vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
int32x2_t pairwise_reduced_acc_1 =
vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
int32x2_t pairwise_reduced_acc_2 =
vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
int32x2_t pairwise_reduced_acc_3 =
vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
const int32x2_t reduced_lo =
vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
const int32x2_t reduced_hi =
vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
if (params.bias) {
int32x4_t bias_vec = vld1q_s32(params.bias + row);
reduced = vaddq_s32(reduced, bias_vec);
}
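// Apply the quantized multiplier: the effective scale is
// (multiplier_fixedpoint * 2^-31) * 2^multiplier_exponent, computed as
//   acc = RoundingRightShift(
//             SaturatingRoundingDoublingHighMul(acc << max(exponent, 0),
//                                               multiplier_fixedpoint),
//             -min(exponent, 0)).
// With per-row quantization the multipliers are loaded per output row,
// otherwise a single multiplier is broadcast to all 4 rows.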
int32x4_t multiplier_fixedpoint;
int32x4_t multiplier_exponent;
if (quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
multiplier_exponent =
vld1q_s32(params.multiplier_exponent_perchannel + row);
multiplier_fixedpoint =
vld1q_s32(params.multiplier_fixedpoint_perchannel + row);
} else {
multiplier_exponent = vdupq_n_s32(params.multiplier_exponent);
multiplier_fixedpoint = vdupq_n_s32(params.multiplier_fixedpoint);
}
int32x4_t exponent_positive_part =
vmaxq_s32(multiplier_exponent, vdupq_n_s32(0));
reduced = vshlq_s32(reduced, exponent_positive_part);
reduced = vqrdmulhq_s32(reduced, multiplier_fixedpoint);
int32x4_t exponent_negative_part =
vminq_s32(multiplier_exponent, vdupq_n_s32(0));
reduced = vrshlq_s32(reduced, exponent_negative_part);
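// Add the destination zero point, then saturate-narrow, clamp and store the
// 4 output values for this block of rows.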
const int32x4_t output_offset_vec = vdupq_n_s32(dst_params.zero_point);
reduced = vaddq_s32(reduced, output_offset_vec);
ClampAndStore(reduced, params.clamp_min, params.clamp_max,
dst_data + row);
}
}
};
#ifdef TFLITE_WITH_RUY
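// This float path is compiled only when ruy is the GEMM backend
// (TFLITE_WITH_RUY); with the Eigen-based default, Eigen's own float GEMV
// paths are used instead.

// Fused multiply-add: uses an FMA instruction when the target supports it,
// otherwise the non-fused multiply-accumulate.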
inline float32x4_t mul_add(float32x4_t acc, float32x4_t lhs, float32x4_t rhs) {
#ifdef __ARM_FEATURE_FMA
return vfmaq_f32(acc, lhs, rhs);
#else
return vmlaq_f32(acc, lhs, rhs);
#endif
}
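// Float GEMV kernel, mirroring the structure of the quantized kernel above:
// 4 destination rows per iteration, float32 accumulation, requiring a depth
// (lhs_params.cols) of at least 4.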
template <>
struct CustomGemvImpl<float, float, float, float,
QuantizationFlavor::kFloatingPoint> {
static constexpr int kKernelRows = 4;
static bool IsSupportedGivenSufficientlyManyRows(
const MatrixParams<float>& lhs_params,
const MatrixParams<float>& rhs_params,
const MatrixParams<float>& dst_params,
const GemmParams<float, float>& params) {
return lhs_params.cols >= 4;
}
static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
const MatrixParams<float>& rhs_params, const float* rhs_data,
const MatrixParams<float>& dst_params, float* dst_data,
const GemmParams<float, float>& params, int row_start,
int row_end) {
TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
for (int row = row_start; row < row_end; row += kKernelRows) {
row = std::min(row, row_end - kKernelRows);
const float* filter_ptr = lhs_data + row * lhs_params.cols;
static constexpr int kCacheLineSize = 64;
for (int k = 0; k < rhs_params.rows;
k += kCacheLineSize / sizeof(float)) {
optimized_ops_preload_l1_keep(rhs_data + k);
}
static constexpr int kPreloadAhead = 256;
float32x4_t acc0 = vdupq_n_f32(0);
float32x4_t acc1 = acc0;
float32x4_t acc2 = acc0;
float32x4_t acc3 = acc0;
int in = 0;
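// Main loop: 4 levels of depth per iteration, for all 4 rows at once.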
for (; in <= lhs_params.cols - 4; in += 4) {
float32x4_t input_val = vld1q_f32(rhs_data + in);
const float* local_filter_ptr = filter_ptr;
float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(float));
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(float));
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(float));
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
optimized_ops_preload_l1_stream(local_filter_ptr +
kPreloadAhead / sizeof(float));
filter_ptr += 4;
acc0 = mul_add(acc0, filter_val_0, input_val);
acc1 = mul_add(acc1, filter_val_1, input_val);
acc2 = mul_add(acc2, filter_val_2, input_val);
acc3 = mul_add(acc3, filter_val_3, input_val);
}
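// Fewer than 4 levels of depth remain: same trick as in the quantized kernel
// above, reloading the last 4 entries and zeroing the input lanes that
// overlap already-consumed data.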
if (in < lhs_params.cols) {
const int back = in + 4 - lhs_params.cols;
TFLITE_DCHECK_GE(back, 1);
TFLITE_DCHECK_LE(back, 3);
float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4);
const float* local_filter_ptr = filter_ptr - back;
filter_ptr += lhs_params.cols - in;
float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
local_filter_ptr += lhs_params.cols;
float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
switch (back) {
case 3:
input_val = vsetq_lane_f32(0, input_val, 2);
[[clang::fallthrough]];
case 2:
input_val = vsetq_lane_f32(0, input_val, 1);
[[clang::fallthrough]];
default:
input_val = vsetq_lane_f32(0, input_val, 0);
}
acc0 = mul_add(acc0, filter_val_0, input_val);
acc1 = mul_add(acc1, filter_val_1, input_val);
acc2 = mul_add(acc2, filter_val_2, input_val);
acc3 = mul_add(acc3, filter_val_3, input_val);
}
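// Horizontally reduce the 4 accumulators into one float32x4 of row sums,
// then add the bias (if any), clamp, and store.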
float32x2_t pairwise_reduced_acc_0 =
vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
float32x2_t pairwise_reduced_acc_1 =
vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
float32x2_t pairwise_reduced_acc_2 =
vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
float32x2_t pairwise_reduced_acc_3 =
vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
float32x2_t reduced_lo =
vpadd_f32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
float32x2_t reduced_hi =
vpadd_f32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
float32x4_t reduced = vcombine_f32(reduced_lo, reduced_hi);
if (params.bias) {
reduced = vaddq_f32(reduced, vld1q_f32(params.bias + row));
}
reduced = vminq_f32(reduced, vdupq_n_f32(params.clamp_max));
reduced = vmaxq_f32(reduced, vdupq_n_f32(params.clamp_min));
vst1q_f32(dst_data + row, reduced);
}
}
};
#endif  // TFLITE_WITH_RUY
#endif  // USE_NEON
}  // namespace detail
}  // namespace cpu_backend_gemm
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_