#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct FloatDepthwiseConvKernel { … };
#ifdef USE_NEON
template <>
struct FloatDepthwiseConvKernel<false, 8, 1> {
// Fixed input depth 8, depth multiplier 1, unit stride only
// (kAllowStrided == false; input_ptr_increment is unused here).
// Each output pixel consumes 8 consecutive input floats and the same
// 8 filter floats, accumulating into 8 consecutive acc_buffer floats.
static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
const float* input_ptr, int input_ptr_increment,
const float* filter_ptr, float* acc_buffer_ptr) {
// Load the 8 filter coefficients once; they are reused for every pixel.
float32x4_t filter[2];
for (int i = 0; i < 2; i++) {
filter[i] = vld1q_f32(filter_ptr + 4 * i);
}
int outp = 0;
// Main loop: process 2 output pixels (2 x 8 floats) per iteration.
for (; outp <= num_output_pixels - 2; outp += 2) {
// Load the inputs for both pixels.
float32x4_t input[4];
for (int i = 0; i < 4; i++) {
input[i] = vld1q_f32(input_ptr + 4 * i);
}
input_ptr += 16;
// Load the current accumulator values.
float32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
// Multiply-accumulate; filter[0]/filter[1] are reused for both pixels.
acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
// Store the updated accumulators back.
for (int i = 0; i < 4; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 16;
}
// Tail loop: one output pixel (8 floats) at a time.
for (; outp < num_output_pixels; outp++) {
float32x4_t input[2];
for (int i = 0; i < 2; i++) {
input[i] = vld1q_f32(input_ptr + 4 * i);
}
input_ptr += 8;
float32x4_t acc[2];
for (int i = 0; i < 2; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
for (int i = 0; i < 2; i++) {
acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
}
for (int i = 0; i < 2; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 8;
}
}
};
template <>
struct FloatDepthwiseConvKernel<false, 2, 1> {
// Fixed input depth 2, depth multiplier 1, unit stride only.
// Output pixels are 2 floats each, so vector registers span multiple
// pixels; four loop tiers handle 8/4/2/1 pixels per iteration.
static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
const float* input_ptr, int input_ptr_increment,
const float* filter_ptr, float* acc_buffer_ptr) {
// Load the 2 filter coefficients and duplicate them across a q register
// so one 4-lane MLA covers two output pixels at once.
const float32x2_t filters = vld1_f32(filter_ptr);
const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
int outp = 0;
// Handle 8 output pixels (16 floats) at a time.
for (; outp <= num_output_pixels - 8; outp += 8) {
float32x4_t input[4];
for (int i = 0; i < 4; i++) {
input[i] = vld1q_f32(input_ptr + 4 * i);
}
input_ptr += 16;
float32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
for (int i = 0; i < 4; i++) {
acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
}
for (int i = 0; i < 4; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 16;
}
// Handle 4 output pixels (8 floats) at a time.
for (; outp <= num_output_pixels - 4; outp += 4) {
float32x4_t input[2];
for (int i = 0; i < 2; i++) {
input[i] = vld1q_f32(input_ptr + 4 * i);
}
input_ptr += 8;
float32x4_t acc[2];
for (int i = 0; i < 2; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
for (int i = 0; i < 2; i++) {
acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
}
for (int i = 0; i < 2; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 8;
}
// Handle 2 output pixels (4 floats) at a time.
for (; outp <= num_output_pixels - 2; outp += 2) {
const float32x4_t input = vld1q_f32(input_ptr);
input_ptr += 4;
float32x4_t acc = vld1q_f32(acc_buffer_ptr);
acc = vmlaq_f32(acc, input, filters_dup2);
vst1q_f32(acc_buffer_ptr, acc);
acc_buffer_ptr += 4;
}
// Tail: one output pixel (2 floats) at a time, using d registers.
for (; outp < num_output_pixels; outp++) {
const float32x2_t input = vld1_f32(input_ptr);
input_ptr += 2;
float32x2_t acc = vld1_f32(acc_buffer_ptr);
acc = vmla_f32(acc, input, filters);
vst1_f32(acc_buffer_ptr, acc);
acc_buffer_ptr += 2;
}
}
};
template <>
struct FloatDepthwiseConvKernel<true, 0, 1> {
// Runtime input depth (kFixedInputDepth == 0), depth multiplier 1,
// strided input supported via input_ptr_increment. Per output pixel,
// accumulates one filter value per input channel.
static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
const float* input_ptr, int input_ptr_increment,
const float* filter_ptr, float* acc_buffer_ptr) {
for (int outp = 0; outp < num_output_pixels; outp++) {
// Filter is re-scanned from the start for each pixel; input is scanned
// from the pixel's base position.
const float* local_filter_ptr = filter_ptr;
const float* local_input_ptr = input_ptr;
int ic = 0;
// Handle 16 input channels at a time.
for (; ic <= input_depth - 16; ic += 16) {
float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0);
float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1);
float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2);
float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3);
local_filter_ptr += 16;
float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0);
float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1);
float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2);
float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3);
local_input_ptr += 16;
float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
acc_0 = vmlaq_f32(acc_0, input_0, filter_0);
acc_1 = vmlaq_f32(acc_1, input_1, filter_1);
acc_2 = vmlaq_f32(acc_2, input_2, filter_2);
acc_3 = vmlaq_f32(acc_3, input_3, filter_3);
vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
acc_buffer_ptr += 16;
}
// Handle 4 input channels at a time.
for (; ic <= input_depth - 4; ic += 4) {
float32x4_t filter;
filter = vld1q_f32(local_filter_ptr);
local_filter_ptr += 4;
float32x4_t input;
input = vld1q_f32(local_input_ptr);
local_input_ptr += 4;
float32x4_t acc;
acc = vld1q_f32(acc_buffer_ptr);
acc = vmlaq_f32(acc, input, filter);
vst1q_f32(acc_buffer_ptr, acc);
acc_buffer_ptr += 4;
}
// Scalar tail: one input channel at a time.
for (; ic < input_depth; ic++) {
const float input_val = *local_input_ptr++;
const float filter_val = *local_filter_ptr++;
*acc_buffer_ptr++ += filter_val * input_val;
}
// Advance to the next input pixel (stride is baked into the increment).
input_ptr += input_ptr_increment;
}
}
};
template <>
struct FloatDepthwiseConvKernel<true, 0, 8> {
// Runtime input depth, depth multiplier 8: each input channel drives 8
// consecutive filter values and 8 consecutive accumulator values.
static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
const float* input_ptr, int input_ptr_increment,
const float* filter_ptr, float* acc_buffer_ptr) {
for (int outp = 0; outp < num_output_pixels; outp++) {
const float* local_filter_ptr = filter_ptr;
const float* local_input_ptr = input_ptr;
int ic = 0;
// Handle 2 input channels at a time (2 x 8 = 16 accumulators).
for (; ic <= input_depth - 2; ic += 2) {
float32x4_t filter[4];
for (int i = 0; i < 4; i++) {
filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
}
local_filter_ptr += 16;
const float32x2_t input = vld1_f32(local_input_ptr);
local_input_ptr += 2;
float32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
// Lane 0 is channel ic (filter[0..1]); lane 1 is channel ic+1
// (filter[2..3]).
acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
for (int i = 0; i < 4; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 16;
}
// Tail: one input channel (8 accumulators) at a time.
for (; ic < input_depth; ic++) {
float32x4_t filter[2];
for (int i = 0; i < 2; i++) {
filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
}
local_filter_ptr += 8;
const float input_val = *local_input_ptr++;
float32x4_t acc[2];
for (int i = 0; i < 2; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
for (int i = 0; i < 2; i++) {
acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
}
for (int i = 0; i < 2; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 8;
}
// Advance to the next input pixel.
input_ptr += input_ptr_increment;
}
}
};
template <>
struct FloatDepthwiseConvKernel<true, 0, 2> {
// Runtime input depth, depth multiplier 2: each input channel drives 2
// consecutive filter values and 2 consecutive accumulator values.
// Four loop tiers handle 8/4/2/1 input channels per iteration.
static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
const float* input_ptr, int input_ptr_increment,
const float* filter_ptr, float* acc_buffer_ptr) {
for (int outp = 0; outp < num_output_pixels; outp++) {
const float* local_filter_ptr = filter_ptr;
const float* local_input_ptr = input_ptr;
int ic = 0;
// Handle 8 input channels at a time (16 accumulators).
for (; ic <= input_depth - 8; ic += 8) {
float32x4_t filter[4];
for (int i = 0; i < 4; i++) {
filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
}
local_filter_ptr += 16;
// vzipq duplicates each input value into two adjacent lanes so it
// lines up with its channel's pair of filter values:
// val[0] = {i0,i0,i1,i1}, val[1] = {i2,i2,i3,i3}.
float32x4x2_t input_dup2[2];
for (int i = 0; i < 2; i++) {
const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
input_dup2[i] = vzipq_f32(input, input);
}
local_input_ptr += 8;
float32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
}
acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
for (int i = 0; i < 4; i++) {
vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
}
acc_buffer_ptr += 16;
}
// Handle 4 input channels at a time (8 accumulators), using d registers
// and lane-indexed MLAs instead of the zip trick.
for (; ic <= input_depth - 4; ic += 4) {
float32x2_t filter[4];
for (int i = 0; i < 4; i++) {
filter[i] = vld1_f32(local_filter_ptr + 2 * i);
}
local_filter_ptr += 8;
const float32x4_t input = vld1q_f32(local_input_ptr);
local_input_ptr += 4;
float32x2_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
}
acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
for (int i = 0; i < 4; i++) {
vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
}
acc_buffer_ptr += 8;
}
// Handle 2 input channels at a time (4 accumulators).
for (; ic <= input_depth - 2; ic += 2) {
const float32x4_t filter = vld1q_f32(local_filter_ptr);
local_filter_ptr += 4;
const float32x2_t input = vld1_f32(local_input_ptr);
local_input_ptr += 2;
float32x2_t acc[2];
for (int i = 0; i < 2; i++) {
acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
}
// Channel ic uses the low filter half with input lane 0; channel
// ic+1 uses the high half with lane 1.
acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
for (int i = 0; i < 2; i++) {
vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
}
acc_buffer_ptr += 4;
}
// Scalar tail: one input channel (2 accumulators) at a time.
for (; ic < input_depth; ic++) {
const float input_val = *local_input_ptr++;
for (int i = 0; i < 2; i++) {
acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
}
local_filter_ptr += 2;
acc_buffer_ptr += 2;
}
// Advance to the next input pixel.
input_ptr += input_ptr_increment;
}
}
};
template <>
struct FloatDepthwiseConvKernel<true, 3, 2> {
  // Fixed input depth 3, depth multiplier 2: six accumulators per pixel.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // The 3 * 2 = 6 filter coefficients are loop-invariant: load them once.
    const float32x2_t filter_c0 = vld1_f32(filter_ptr);
    const float32x2_t filter_c1 = vld1_f32(filter_ptr + 2);
    const float32x2_t filter_c2 = vld1_f32(filter_ptr + 4);
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      // Channels 0 and 1 share a d-register; channel 2 is broadcast into
      // both lanes of its own register.
      const float32x2_t in_01 = vld1_f32(input_ptr);
      const float32x2_t in_2 = vld1_dup_f32(input_ptr + 2);
      float32x2_t sum_c0 = vld1_f32(acc_buffer_ptr);
      float32x2_t sum_c1 = vld1_f32(acc_buffer_ptr + 2);
      float32x2_t sum_c2 = vld1_f32(acc_buffer_ptr + 4);
      sum_c0 = vmla_lane_f32(sum_c0, filter_c0, in_01, 0);
      sum_c1 = vmla_lane_f32(sum_c1, filter_c1, in_01, 1);
      sum_c2 = vmla_lane_f32(sum_c2, filter_c2, in_2, 0);
      vst1_f32(acc_buffer_ptr, sum_c0);
      vst1_f32(acc_buffer_ptr + 2, sum_c1);
      vst1_f32(acc_buffer_ptr + 4, sum_c2);
      acc_buffer_ptr += 6;
      input_ptr += input_ptr_increment;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 3, 4> {
  // Fixed input depth 3, depth multiplier 4: twelve accumulators per pixel.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // The 3 * 4 = 12 filter coefficients are loop-invariant: load them once.
    const float32x4_t filter_c0 = vld1q_f32(filter_ptr);
    const float32x4_t filter_c1 = vld1q_f32(filter_ptr + 4);
    const float32x4_t filter_c2 = vld1q_f32(filter_ptr + 8);
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      // Channels 0 and 1 share a d-register; channel 2 is broadcast.
      const float32x2_t in_01 = vld1_f32(input_ptr);
      const float32x2_t in_2 = vld1_dup_f32(input_ptr + 2);
      float32x4_t sum_c0 = vld1q_f32(acc_buffer_ptr);
      float32x4_t sum_c1 = vld1q_f32(acc_buffer_ptr + 4);
      float32x4_t sum_c2 = vld1q_f32(acc_buffer_ptr + 8);
      sum_c0 = vmlaq_lane_f32(sum_c0, filter_c0, in_01, 0);
      sum_c1 = vmlaq_lane_f32(sum_c1, filter_c1, in_01, 1);
      sum_c2 = vmlaq_lane_f32(sum_c2, filter_c2, in_2, 0);
      vst1q_f32(acc_buffer_ptr, sum_c0);
      vst1q_f32(acc_buffer_ptr + 4, sum_c1);
      vst1q_f32(acc_buffer_ptr + 8, sum_c2);
      acc_buffer_ptr += 12;
      input_ptr += input_ptr_increment;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 1, 8> {
  // Fixed input depth 1, depth multiplier 8: one scalar input per pixel
  // fans out into 8 consecutive accumulators.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Cache the 8 loop-invariant filter coefficients in two q registers.
    const float32x4_t filter_lo = vld1q_f32(filter_ptr);
    const float32x4_t filter_hi = vld1q_f32(filter_ptr + 4);
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      const float in_val = *input_ptr;
      input_ptr += input_ptr_increment;
      float32x4_t sum_lo = vld1q_f32(acc_buffer_ptr);
      float32x4_t sum_hi = vld1q_f32(acc_buffer_ptr + 4);
      sum_lo = vmlaq_n_f32(sum_lo, filter_lo, in_val);
      sum_hi = vmlaq_n_f32(sum_hi, filter_hi, in_val);
      vst1q_f32(acc_buffer_ptr, sum_lo);
      vst1q_f32(acc_buffer_ptr + 4, sum_hi);
      acc_buffer_ptr += 8;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 1, 32> {
  // Fixed input depth 1, depth multiplier 32: one scalar input per pixel
  // fans out into 32 consecutive accumulators.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Cache the 32 loop-invariant filter coefficients (8 q registers).
    float32x4_t filter[8];
    for (int i = 0; i < 8; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      const float in_val = *input_ptr;
      input_ptr += input_ptr_increment;
      float32x4_t sum[8];
      for (int i = 0; i < 8; i++) {
        sum[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Broadcast the scalar input against all 32 filter coefficients.
      for (int i = 0; i < 8; i++) {
        sum[i] = vmlaq_n_f32(sum[i], filter[i], in_val);
      }
      for (int i = 0; i < 8; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, sum[i]);
      }
      acc_buffer_ptr += 32;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 1, 20> {
  // Fixed input depth 1, depth multiplier 20: one scalar input per pixel
  // fans out into 20 consecutive accumulators.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Cache the 20 loop-invariant filter coefficients (5 q registers).
    float32x4_t filter[5];
    for (int i = 0; i < 5; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      const float in_val = *input_ptr;
      input_ptr += input_ptr_increment;
      float32x4_t sum[5];
      for (int i = 0; i < 5; i++) {
        sum[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Broadcast the scalar input against all 20 filter coefficients.
      for (int i = 0; i < 5; i++) {
        sum[i] = vmlaq_n_f32(sum[i], filter[i], in_val);
      }
      for (int i = 0; i < 5; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, sum[i]);
      }
      acc_buffer_ptr += 20;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 0, 16> {
  // Runtime input depth, depth multiplier 16: each input channel fans out
  // into 16 consecutive filter values and 16 consecutive accumulators.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      // Rewind the filter scan for each pixel; scan input from the
      // pixel's base position.
      const float* filter_scan = filter_ptr;
      const float* input_scan = input_ptr;
      for (int channel = 0; channel < input_depth; channel++) {
        const float32x4_t coeff_0 = vld1q_f32(filter_scan);
        const float32x4_t coeff_1 = vld1q_f32(filter_scan + 4);
        const float32x4_t coeff_2 = vld1q_f32(filter_scan + 8);
        const float32x4_t coeff_3 = vld1q_f32(filter_scan + 12);
        filter_scan += 16;
        const float in_val = *input_scan++;
        float32x4_t sum_0 = vld1q_f32(acc_buffer_ptr);
        float32x4_t sum_1 = vld1q_f32(acc_buffer_ptr + 4);
        float32x4_t sum_2 = vld1q_f32(acc_buffer_ptr + 8);
        float32x4_t sum_3 = vld1q_f32(acc_buffer_ptr + 12);
        sum_0 = vmlaq_n_f32(sum_0, coeff_0, in_val);
        sum_1 = vmlaq_n_f32(sum_1, coeff_1, in_val);
        sum_2 = vmlaq_n_f32(sum_2, coeff_2, in_val);
        sum_3 = vmlaq_n_f32(sum_3, coeff_3, in_val);
        vst1q_f32(acc_buffer_ptr, sum_0);
        vst1q_f32(acc_buffer_ptr + 4, sum_1);
        vst1q_f32(acc_buffer_ptr + 8, sum_2);
        vst1q_f32(acc_buffer_ptr + 12, sum_3);
        acc_buffer_ptr += 16;
      }
      input_ptr += input_ptr_increment;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 8, 1> {
  // Fixed input depth 8, depth multiplier 1, strided input supported.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Cache the 8 loop-invariant filter coefficients in two q registers.
    const float32x4_t filter_lo = vld1q_f32(filter_ptr);
    const float32x4_t filter_hi = vld1q_f32(filter_ptr + 4);
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      const float32x4_t in_lo = vld1q_f32(input_ptr);
      const float32x4_t in_hi = vld1q_f32(input_ptr + 4);
      float32x4_t sum_lo = vld1q_f32(acc_buffer_ptr);
      float32x4_t sum_hi = vld1q_f32(acc_buffer_ptr + 4);
      sum_lo = vmlaq_f32(sum_lo, in_lo, filter_lo);
      sum_hi = vmlaq_f32(sum_hi, in_hi, filter_hi);
      vst1q_f32(acc_buffer_ptr, sum_lo);
      vst1q_f32(acc_buffer_ptr + 4, sum_hi);
      acc_buffer_ptr += 8;
      input_ptr += input_ptr_increment;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 2, 1> {
  // Fixed input depth 2, depth multiplier 1, strided input supported.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Two filter coefficients; also duplicated across a q register so a
    // single 4-lane MLA can cover two output pixels.
    const float32x2_t coeffs = vld1_f32(filter_ptr);
    const float32x4_t coeffs_dup = vcombine_f32(coeffs, coeffs);
    int pixel = 0;
    // Pair up output pixels: two strided 2-float loads per iteration.
    for (; pixel + 2 <= num_output_pixels; pixel += 2) {
      const float32x2_t first = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;
      const float32x2_t second = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;
      float32x4_t sum = vld1q_f32(acc_buffer_ptr);
      sum = vmlaq_f32(sum, vcombine_f32(first, second), coeffs_dup);
      vst1q_f32(acc_buffer_ptr, sum);
      acc_buffer_ptr += 4;
    }
    // Tail: one output pixel at a time, using d registers.
    for (; pixel < num_output_pixels; pixel++) {
      const float32x2_t in_vals = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;
      float32x2_t sum = vld1_f32(acc_buffer_ptr);
      sum = vmla_f32(sum, in_vals, coeffs);
      vst1_f32(acc_buffer_ptr, sum);
      acc_buffer_ptr += 2;
    }
  }
};
template <>
struct FloatDepthwiseConvKernel<true, 4, 1> {
  // Fixed input depth 4, depth multiplier 1, strided input supported.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Exactly one q register of loop-invariant filter data.
    const float32x4_t coeffs = vld1q_f32(filter_ptr);
    for (int pixel = 0; pixel < num_output_pixels; pixel++) {
      const float32x4_t in_vals = vld1q_f32(input_ptr);
      input_ptr += input_ptr_increment;
      float32x4_t sum = vld1q_f32(acc_buffer_ptr);
      sum = vmlaq_f32(sum, in_vals, coeffs);
      vst1q_f32(acc_buffer_ptr, sum);
      acc_buffer_ptr += 4;
    }
  }
};
#endif
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
int out_x_buffer_start, int out_x_buffer_end,
int output_depth, float* acc_buffer) { … }
inline void FloatDepthwiseConvAccumRowGeneric(
int stride, int dilation_factor, int input_depth, int input_width,
const float* input_data, int pad_width, int depth_multiplier,
int filter_width, const float* filter_data, int out_x_buffer_start,
int out_x_buffer_end, int output_depth, float* acc_buffer) { … }
inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
const float* bias_data,
float* acc_buffer) { … }
inline void DepthwiseConvImpl(
const DepthwiseParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* output_data, const CpuFlags& , int thread_start,
int thread_end, int thread_dim) { … }
}
}
#endif