#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
namespace depthwise_conv {
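// Workspace sizing constants for the fast 3x3 depthwise-conv kernels:
// kDepthwiseConvScratchWorkspaceSize bounds the intermediate scratch buffer,
// kDepthwiseConvAdjustedBiasLimit bounds bias values after per-depth
// adjustment, and kWorkspaceExtension pads the workspace so macro-block loads
// may safely read slightly past the valid data range.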
constexpr int kDepthwiseConvScratchWorkspaceSize = …;
constexpr int kDepthwiseConvAdjustedBiasLimit = …;
constexpr int kWorkspaceExtension = …;
#ifdef USE_NEON
#ifndef __aarch64__
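// Emulation of the AArch64 vqtbl4q_s8 table lookup for 32-bit ARM NEON, which
// only provides the 8-byte vtbl4_s8 lookup. The 64-byte table is repacked into
// two 32-byte halves, both are looked up, and the results are blended using
// bit 3 of each index.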
inline int8x16_t vqtbl4q_s8(int8x16x4_t a, int8x16_t b) {
  const uint8x16_t mask = vtstq_s8(b, vdupq_n_s8(8));

  // Delete bit 3 from each index: the remaining bits index into a repacked
  // 32-byte table built from either the low or the high halves of the four
  // source registers.
  const int8x16_t high_bits = vshrq_n_s8(b, 4);
  int8x16_t deleted_bit_3 = b;
  deleted_bit_3 = vsliq_n_s8(deleted_bit_3, high_bits, 3);

  int8x8x4_t repacked_data;

  // Lookup in the table built from the low halves.
  repacked_data.val[0] = vget_low_s8(a.val[0]);
  repacked_data.val[1] = vget_low_s8(a.val[1]);
  repacked_data.val[2] = vget_low_s8(a.val[2]);
  repacked_data.val[3] = vget_low_s8(a.val[3]);
  const int8x16_t output_for_lower =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Lookup in the table built from the high halves.
  repacked_data.val[0] = vget_high_s8(a.val[0]);
  repacked_data.val[1] = vget_high_s8(a.val[1]);
  repacked_data.val[2] = vget_high_s8(a.val[2]);
  repacked_data.val[3] = vget_high_s8(a.val[3]);
  const int8x16_t output_for_higher =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Merge the two lookups according to bit 3 of the original indices.
  int8x16_t output = vbslq_s8(mask, output_for_higher, output_for_lower);
  return output;
}
#endif
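// In-place byte interleave of two q-registers: afterwards *a holds the
// interleaved low halves of the two inputs and *b the interleaved high halves.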
inline void vzipq_s8_in_place(int8x16_t* a, int8x16_t* b) {
  int8x16x2_t r8x16;
  r8x16 = vzipq_s8(*a, *b);
  *a = r8x16.val[0];
  *b = r8x16.val[1];
}
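// Same as above, but interleaving 16-bit (byte-pair) units.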
inline void vzipq_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vzipq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
  *b = vreinterpretq_s8_s16(r16x8.val[1]);
}
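// 16-bit transpose that keeps only the even-numbered lanes: *a receives the
// even 16-bit lanes of *a interleaved with those of *b; *b is left unchanged.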
inline void vtrn1_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vtrnq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
}
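// vzip1q_s8 returns the interleaved low halves of a and b; vzip2q_s8 returns
// the interleaved high halves.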
inline int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[0];
}
inline int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[1];
}
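// Shift the register pair (*left, *right) down by 8 bits within each 32-bit
// lane: each lane of *left keeps its upper 24 bits and gains the low 8 bits of
// the corresponding lane of *right; *right is shifted down by 8 bits.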
inline void biregister_rotate_8(int8x16_t* left, int8x16_t* right) {
  *left = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*left), 8));
  *left = vreinterpretq_s8_u32(vsliq_n_u32(vreinterpretq_u32_s8(*left),
                                           vreinterpretq_u32_s8(*right), 24));
  *right = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*right), 8));
}
#ifndef __aarch64__
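// Emulation of the AArch64 vpaddq_s32 pairwise add for 32-bit ARM, using a
// saturating add for the final combine.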
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
  int32x4x2_t deinterleaved = vuzpq_s32(a, b);
  return vqaddq_s32(deinterleaved.val[0], deinterleaved.val[1]);
}
#endif
#ifdef __ARM_FEATURE_DOTPROD
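// vdotq_lane_s32 requires a compile-time lane index, so the runtime `lane`
// argument is dispatched through a switch over the four possible lanes.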
inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, const int lane) {
  switch (lane) {
    case 0:
      return vdotq_lane_s32(acc, lhs, vget_low_s8(rhs), 0);
    case 1:
      return vdotq_lane_s32(acc, lhs, vget_low_s8(rhs), 1);
    case 2:
      return vdotq_lane_s32(acc, lhs, vget_high_s8(rhs), 0);
    case 3:
    default:
      return vdotq_lane_s32(acc, lhs, vget_high_s8(rhs), 1);
  }
}
#else
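// Fallback for targets without the dot-product extension: each lane of the
// accumulator gathers four int8 products via widening multiplies followed by
// pairwise additions.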
inline int32x4_t vdotq_s32(int32x4_t acc, int8x16_t lhs, int8x16_t rhs) {
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), vget_low_s8(rhs)));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), vget_high_s8(rhs)));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}
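// Fallback four-lane dot product: broadcast the selected 32-bit group of rhs
// across an 8-byte vector, then reduce the products as in vdotq_s32 above.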
inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, int lane) {
  int8x8_t lane_rhs;
  if (lane == 0) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 0));
  } else if (lane == 1) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 1));
  } else if (lane == 2) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 0));
  } else {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 1));
  }
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), lane_rhs));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), lane_rhs));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}
#endif
#endif
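// Rounding divide by a power of two, parameterized on the requested output
// rounding mode. The NEON specialization below implements kUpward rounding via
// saturating rounding shifts.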
template <DepthwiseConvOutputRounding output_rounding>
struct DivideByPOT { … };
template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kAwayFromZero> { … };
#ifdef USE_NEON
template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kUpward> {
  template <typename IntegerType>
  static inline IntegerType Run(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32_t>(-exponent)));
  }

  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, IntegerType exponent) {
    return vqrshlq_s32(x, exponent);
  }

  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32_t>(exponent)));
  }
};
#endif
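// Taxonomy of the dot-product kernel variants; CategorizeDotProductKernel
// below maps a given convolution configuration onto one of these.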
enum class DotProduct3x3KernelType { … };
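// Distinguishes the legacy non-per-channel uint8 quantization from per-channel
// int8; QuantizationTypeImpl supplies the corresponding storage types.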
enum class QuantizationType { … };
template <QuantizationType quantization_type>
struct QuantizationTypeImpl { … };
template <>
struct QuantizationTypeImpl<QuantizationType::kNonPerChannelUint8> { … };
template <>
struct QuantizationTypeImpl<QuantizationType::kPerChannelInt8> { … };
template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline DotProduct3x3KernelType CategorizeDotProductKernel(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    const RuntimeShape& output_shape, const DepthwiseParams& params,
    const int32_t* output_shift_ptr = nullptr) { … }
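// Parameter blocks consumed by the specialized kernels: the legacy 3x3-filter
// params and the dot-product kernel params, respectively.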
struct DepthwiseConvParams { … };
struct DepthwiseConvDotProdParams { … };
template <DepthwiseConvOutputRounding output_rounding, int32_t kDepth,
          int32_t kStrideWidth, int32_t kStrideHeight>
struct DepthwiseConvWindow { … };

template <DepthwiseConvOutputRounding output_rounding, int32_t kDepth,
          int32_t kStrideWidth, int32_t kStrideHeight>
struct DepthwiseConvWindowPerChannel { … };

enum class EdgeType { … };

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartial { … };

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartialPerChannel { … };
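// Copies a block of input into the compact layout consumed by the 3x3 filter
// kernels; get_shuffle_input_size and ShuffleParams describe the workspace it
// fills.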
template <typename T>
inline void ShuffleInput(const T* input_ptr, int64_t input_depth,
                         int32_t input_width, int32_t input_height,
                         int64_t output_depth, int32_t output_width,
                         int32_t output_height, T* output_ptr) { … }
inline int32_t get_shuffle_input_size(int32_t stride, int32_t output) { … }
struct ShuffleParams { … };
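// Returns true when the shape, stride, dilation, and padding configuration can
// be handled by the fast 3x3 filter path.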
template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline bool Fast3x3FilterKernelSupported(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    int32_t stride_width, int32_t stride_height, int32_t dilation_width_factor,
    int32_t dilation_height_factor, int32_t pad_width, int32_t pad_height,
    int32_t depth_multiplier, const RuntimeShape& output_shape,
    int32_t output_shift, const int32_t* output_shift_ptr = nullptr) { … }
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type>
struct ProcessPerDepth { … };

template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication,
          int32_t max_padding>
struct PackMacroBlock { … };

template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication, int32_t stride>
struct KernelMacroBlock { … };
#if defined(__aarch64__)
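// Prefetches the input macro block into L1 cache ahead of the compute kernels.
// Only compiled on AArch64.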
template <typename T>
inline void PreloadInputBlock(
    const T* input_block_data,
    const DepthwiseConvDotProdParams* function_params) {
  const int input_width_micro_repeats =
      function_params->input_width_micro_repeats;
  const int block_height = function_params->inbound_block_height;
  const int residual_width = function_params->residual_width;
  const int input_height_stride = function_params->input_height_stride;
  const int input_depth = function_params->input_depth;

  const int total_width = 4 * input_width_micro_repeats + residual_width;
  const T* row_ptr = input_block_data;
  for (int k_height = 0; k_height < block_height; ++k_height) {
    const T* ptr = row_ptr;
    for (int j = 0; j < total_width; ++j) {
      optimized_ops_preload_l1_keep(ptr);
      ptr += input_depth;
    }
    row_ptr += input_height_stride;
  }
}
#endif
}  // namespace depthwise_conv
}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_