#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <type_traits>

#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace resize_bilinear {
#ifdef USE_NEON
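// Helpers that pack and unpack pixel data for the fixed-point NEON path
// below. This one loads 8 uint8 values and widens each into the low byte of
// a signed 16-bit lane (the high byte is zero).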
inline int16x8_t Load8IntoLowerS16(const uint8_t* data_ptr) {
return vreinterpretq_s16_u16(vmovl_u8(vld1_u8(data_ptr)));
}
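// Widens 8 uint8 lanes to 16 bits and shifts each value into the high byte
// of its lane, i.e. multiplies by 256, placing pixel data in the integer
// part of an 8.8 fixed-point accumulator.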
inline uint16x8_t Move8IntoUpperU16(const uint8x8_t vec_val) {
return vshlq_n_u16(vmovl_u8(vec_val), 8);
}
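// Loads 8 uint8 values directly into the high byte of each 16-bit lane.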
inline uint16x8_t Load8IntoUpperU16(const uint8_t* data_ptr) {
return Move8IntoUpperU16(vld1_u8(data_ptr));
}
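// Extracts the high byte of each 16-bit lane for a pair of accumulators,
// truncating 8.8 fixed-point values back to uint8. Operating on a pair lets
// a single vuzpq_u8 produce both results.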
inline void PairExtractUpper(const uint16x8_t accum_0, const uint16x8_t accum_1,
uint8x8_t* res_0, uint8x8_t* res_1) {
uint8x16x2_t unzipped =
vuzpq_u8(vreinterpretq_u8_u16(accum_0), vreinterpretq_u8_u16(accum_1));
*res_0 = vget_low_u8(unzipped.val[1]);
*res_1 = vget_high_u8(unzipped.val[1]);
}
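// Thin wrapper giving int16x8_t element-wise add, subtract, and shift
// operators, so fixed-point accumulator updates read like scalar code.
// Shift amounts are restricted to 1, 4, and 8 because the NEON
// immediate-shift intrinsics require compile-time-constant counts.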
struct op_int16x8_t {
inline op_int16x8_t() = default;
inline explicit op_int16x8_t(const int16x8_t& initial_val) {
val = initial_val;
}
inline op_int16x8_t& operator=(const int16x8_t& new_val) {
val = new_val;
return *this;
}
inline op_int16x8_t& operator+=(const op_int16x8_t& add_val) {
val = vaddq_s16(val, add_val.val);
return *this;
}
inline op_int16x8_t& operator-=(const op_int16x8_t& sub_val) {
val = vsubq_s16(val, sub_val.val);
return *this;
}
inline op_int16x8_t& operator<<=(int32_t left_shift) {
switch (left_shift) {
case 1:
val = vshlq_n_s16(val, 1);
break;
case 4:
val = vshlq_n_s16(val, 4);
break;
case 8:
val = vshlq_n_s16(val, 8);
break;
default:
TFLITE_CHECK(false);
break;
}
return *this;
}
inline op_int16x8_t& operator>>=(int32_t right_shift) {
switch (right_shift) {
case 1:
val = vshrq_n_s16(val, 1);
break;
case 4:
val = vshrq_n_s16(val, 4);
break;
case 8:
val = vshrq_n_s16(val, 8);
break;
default:
TFLITE_CHECK(false);
break;
}
return *this;
}
friend inline op_int16x8_t operator+(op_int16x8_t lhs,
const op_int16x8_t& rhs) {
lhs += rhs;
return lhs;
}
friend inline op_int16x8_t operator-(op_int16x8_t lhs,
const op_int16x8_t& rhs) {
lhs -= rhs;
return lhs;
}
friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32_t left_shift) {
lhs <<= left_shift;
return lhs;
}
friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32_t right_shift) {
lhs >>= right_shift;
return lhs;
}
int16x8_t val;
};
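// Unsigned counterpart of op_int16x8_t. The mixed-signedness += and -=
// accept op_int16x8_t deltas and reinterpret their bits as unsigned; with
// two's complement, the wrap-around add and subtract are equivalent.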
struct op_uint16x8_t {
inline op_uint16x8_t() = default;
inline explicit op_uint16x8_t(const uint16x8_t initial_val) {
val = initial_val;
}
inline op_uint16x8_t& operator=(const uint16x8_t& new_val) {
val = new_val;
return *this;
}
inline op_uint16x8_t& operator+=(const op_int16x8_t& add_val) {
val = vaddq_u16(val, vreinterpretq_u16_s16(add_val.val));
return *this;
}
inline op_uint16x8_t& operator-=(const op_int16x8_t& sub_val) {
val = vsubq_u16(val, vreinterpretq_u16_s16(sub_val.val));
return *this;
}
inline op_uint16x8_t& operator<<=(int32_t left_shift) {
switch (left_shift) {
case 1:
val = vshlq_n_u16(val, 1);
break;
case 4:
val = vshlq_n_u16(val, 4);
break;
case 8:
val = vshlq_n_u16(val, 8);
break;
default:
TFLITE_CHECK(false);
break;
}
return *this;
}
inline op_uint16x8_t& operator>>=(int32_t right_shift) {
switch (right_shift) {
case 1:
val = vshrq_n_u16(val, 1);
break;
case 4:
val = vshrq_n_u16(val, 4);
break;
case 8:
val = vshrq_n_u16(val, 8);
break;
default:
TFLITE_CHECK(false);
break;
}
return *this;
}
friend inline op_uint16x8_t operator+(op_uint16x8_t lhs,
const op_int16x8_t& rhs) {
lhs += rhs;
return lhs;
}
friend inline op_uint16x8_t operator-(op_uint16x8_t lhs,
const op_int16x8_t& rhs) {
lhs -= rhs;
return lhs;
}
friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs,
int32_t left_shift) {
lhs <<= left_shift;
return lhs;
}
friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs,
int32_t right_shift) {
lhs >>= right_shift;
return lhs;
}
uint16x8_t val;
};
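// Bit-for-bit reinterpretation of the signed wrapper as the unsigned one.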
inline op_uint16x8_t VReinterpretQU16S16(const op_int16x8_t& other) {
op_uint16x8_t ret_val(vreinterpretq_u16_s16(other.val));
return ret_val;
}
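// Illustrative sketch of how the helpers above compose (not part of this
// header; `src`, `delta_src`, and `dst` are hypothetical pointers):
//
//   op_uint16x8_t accum_0(Load8IntoUpperU16(src));      // pixel * 256
//   op_uint16x8_t accum_1(Load8IntoUpperU16(src + 8));
//   op_int16x8_t delta(Load8IntoLowerS16(delta_src));
//   accum_0 += delta << 4;  // 8.8 fixed-point update
//   accum_1 += delta << 4;
//   uint8x8_t out_0, out_1;
//   PairExtractUpper(accum_0.val, accum_1.val, &out_0, &out_1);
//   vst1_u8(dst, out_0);
//   vst1_u8(dst + 8, out_1);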
#endif  // USE_NEON
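// Optimized resize-bilinear for the special case where the scaling is x8 in
// both height and width, operating on depth-8 blocks at a time, so output
// blocks are 8x8x8 in height-width-depth.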
inline void ResizeBilinear888Uint8(int32_t batches, int32_t input_height,
int32_t input_width, int32_t depth,
const uint8_t* input_data,
uint8_t* output_data) { … }
}  // namespace resize_bilinear
#ifdef USE_NEON
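// Accumulates `depth` channels of one interpolation source into the output:
// output[c] += input[c] * scale. The loop is peeled at widths 32, 16, 8, and
// 4, with a scalar tail for any remainder.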
inline void ResizeBilinearKernel(const float* input_ptr, int32_t depth,
float scale, float* output_ptr) {
int ic = 0;
for (; ic <= depth - 32; ic += 32) {
float32x4x2_t input[4];
for (int i = 0; i < 4; i++) {
input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
}
float32x4x2_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
}
for (int i = 0; i < 4; i++) {
acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
}
for (int i = 0; i < 4; i++) {
vst1q_f32(output_ptr, acc[i].val[0]);
vst1q_f32(output_ptr + 4, acc[i].val[1]);
output_ptr += 8;
}
input_ptr += 32;
}
for (; ic <= depth - 16; ic += 16) {
float32x4x2_t input[2];
for (int i = 0; i < 2; i++) {
input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
}
float32x4x2_t acc[2];
for (int i = 0; i < 2; i++) {
acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
}
for (int i = 0; i < 2; i++) {
acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
}
for (int i = 0; i < 2; i++) {
vst1q_f32(output_ptr, acc[i].val[0]);
vst1q_f32(output_ptr + 4, acc[i].val[1]);
output_ptr += 8;
}
input_ptr += 16;
}
for (; ic <= depth - 8; ic += 8) {
float32x4x2_t input;
input.val[0] = vld1q_f32(input_ptr);
input.val[1] = vld1q_f32(input_ptr + 4);
float32x4x2_t acc;
acc.val[0] = vld1q_f32(output_ptr);
acc.val[1] = vld1q_f32(output_ptr + 4);
acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
vst1q_f32(output_ptr, acc.val[0]);
vst1q_f32(output_ptr + 4, acc.val[1]);
input_ptr += 8;
output_ptr += 8;
}
for (; ic <= depth - 4; ic += 4) {
float32x4_t input = vld1q_f32(input_ptr);
float32x4_t acc = vld1q_f32(output_ptr);
acc = vmlaq_n_f32(acc, input, scale);
vst1q_f32(output_ptr, acc);
input_ptr += 4;
output_ptr += 4;
}
for (; ic < depth; ic++) {
*output_ptr += *input_ptr * scale;
output_ptr++;
input_ptr++;
}
}
#else
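// Scalar fallback with the same contract as the NEON kernel above.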
inline void ResizeBilinearKernel(const float* input_ptr, int32_t depth,
                                 float scale, float* output_ptr) { … }
#endif  // USE_NEON
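// Interpolates output position (x, y) from the four input corners (x0, y0),
// (x1, y0), (x0, y1), (x1, y1) across all `depth` channels of one batch;
// helper for the 2x2 specialization below.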
inline void ResizeBilinearKernel2x2(int32_t x0, int32_t x1, int32_t y0,
int32_t y1, int32_t x, int32_t y,
int32_t depth, int32_t batch,
const RuntimeShape& input_shape,
const float* input_data,
const RuntimeShape& output_shape,
float* output_data) { … }
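// Specialization for exact 2x upscaling of float data in height and width.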
inline void ResizeBilinear2x2(int32_t batches, int32_t input_height,
int32_t input_width, int32_t depth,
int32_t output_height, int32_t output_width,
const RuntimeShape& input_shape,
const float* input_data,
const RuntimeShape& output_shape,
float* output_data) { … }
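// General float path for arbitrary scales. With half_pixel_centers, source
// coordinates are computed as (out + 0.5) * scale - 0.5 rather than
// out * scale.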
inline void ResizeBilinearGeneric(
int32_t batches, int32_t input_height, int32_t input_width, int32_t depth,
int32_t output_height, int32_t output_width, float height_scale,
float width_scale, const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data,
const bool half_pixel_centers) { … }
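// Element-type-templated variant intended for small channel counts; serves
// the quantized uint8/int8 paths.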
template <typename T>
inline void ResizeBilinearGenericSmallChannel(
int32_t batches, int32_t input_height, int32_t input_width, int32_t depth,
int32_t output_height, int32_t output_width, float height_scale,
float width_scale, const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data,
const bool half_pixel_centers) { … }
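// Top-level float entry point: reads the requested output size from
// output_size_data, derives the height/width scales, and dispatches to the
// 2x2 kernel when the shapes allow, otherwise to ResizeBilinearGeneric.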
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const float* input_data,
const RuntimeShape& output_size_shape,
const int32_t* output_size_data,
const RuntimeShape& unextended_output_shape,
float* output_data) { … }
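// uint8 entry point; dispatches to ResizeBilinear888Uint8 when the shapes
// permit, otherwise to ResizeBilinearGenericSmallChannel.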
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const uint8_t* input_data,
const RuntimeShape& output_size_shape,
const int32_t* output_size_data,
const RuntimeShape& unextended_output_shape,
uint8_t* output_data) { … }
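// int8 entry point, mirroring the uint8 path via
// ResizeBilinearGenericSmallChannel.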
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const int8_t* input_data,
const RuntimeShape& unextended_output_size_shape,
const int32_t* output_size_data,
const RuntimeShape& unextended_output_shape,
int8_t* output_data) { … }
}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_