chromium/third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_

#include <algorithm>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

// Implementation of float DepthwiseConv

template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct FloatDepthwiseConvKernel {};

#ifdef USE_NEON

template <>
struct FloatDepthwiseConvKernel<false, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
      acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
      acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
      acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<false, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    const float32x2_t filters = vld1_f32(filter_ptr);
    const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 4; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      const float32x4_t input = vld1q_f32(input_ptr);
      input_ptr += 4;
      // Load the accumulators from acc_buffer
      float32x4_t acc = vld1q_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmlaq_f32(acc, input, filters_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float32x2_t input = vld1_f32(input_ptr);
      input_ptr += 2;
      // Load the accumulators from acc_buffer
      float32x2_t acc = vld1_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmla_f32(acc, input, filters);
      // Store the accumulators back to acc_buffer
      vst1_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
        // Load the filters
        float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0);
        float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1);
        float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2);
        float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3);
        local_filter_ptr += 16;
        // Load the inputs
        float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0);
        float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1);
        float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2);
        float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3);
        local_input_ptr += 16;
        // Load the accumulators from acc_buffer
        float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
        float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
        float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
        float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
        // Multiply-accumulate
        acc_0 = vmlaq_f32(acc_0, input_0, filter_0);
        acc_1 = vmlaq_f32(acc_1, input_1, filter_1);
        acc_2 = vmlaq_f32(acc_2, input_2, filter_2);
        acc_3 = vmlaq_f32(acc_3, input_3, filter_3);
        // Store the accumulators back to acc_buffer
        vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
        vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
        vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
        vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters
        float32x4_t filter;
        filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        float32x4_t input;
        input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x4_t acc;
        acc = vld1q_f32(acc_buffer_ptr);
        // Multiply-accumulate
        acc = vmlaq_f32(acc, input, filter);
        // Store the accumulators back to acc_buffer
        vst1q_f32(acc_buffer_ptr, acc);
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const float input_val = *local_input_ptr++;
        const float filter_val = *local_filter_ptr++;
        *acc_buffer_ptr++ += filter_val * input_val;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 2 input channels at a time.
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
        acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
        acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
        acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the filters
        float32x4_t filter[2];
        for (int i = 0; i < 2; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Load the accumulators from acc_buffer
        float32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

// Note this implementation is very slow for input_depths < 8
// (e.g. comparable to reference implementation) see, specializations for
// input_depth=3 below.
template <>
struct FloatDepthwiseConvKernel<true, 0, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        float32x4x2_t input_dup2[2];
        for (int i = 0; i < 2; i++) {
          const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
          input_dup2[i] = vzipq_f32(input, input);
        }
        local_input_ptr += 8;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
        acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
        acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
        acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters
        float32x2_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1_f32(local_filter_ptr + 2 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float32x4_t input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x2_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate
        acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
        acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
        acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
        acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle 2 input channels at a time.
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        const float32x4_t filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate
        acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
        acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
        }
        local_filter_ptr += 2;
        acc_buffer_ptr += 2;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 3, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x2_t filter[3];
    for (int i = 0; i < 3; i++) {
      filter[i] = vld1_f32(filter_ptr + 2 * i);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float32x2_t input01 = vld1_f32(input_ptr);
      const float32x2_t input2 = vld1_dup_f32(input_ptr + 2);
      // Load the accumulators from acc_buffer
      float32x2_t acc[3];
      for (int i = 0; i < 3; i++) {
        acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
      }
      // Multiply-accumulate for each input channel there 2 outputs
      acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0);
      acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1);
      acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 3; i++) {
        vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
      }
      acc_buffer_ptr += 6;
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 3, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter[3];
    for (int i = 0; i < 3; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // NOTE: we only want 3 values, so we read it as two ops where
      // the second op just duplicates the lane
      const float32x2_t input01 = vld1_f32(input_ptr);
      const float32x2_t input2 = vld1_dup_f32(input_ptr + 2);
      // Load the accumulators from acc_buffer
      float32x4_t acc[3];
      for (int i = 0; i < 3; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate all outputs.
      acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0);
      acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1);
      acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 3; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 12;
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 1, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float input_val = *input_ptr;
      input_ptr += input_ptr_increment;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 1, 32> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
    float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
    float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
    float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
    float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);
    float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5);
    float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6);
    float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7);

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float input_val = *input_ptr;
      input_ptr += input_ptr_increment;
      // Load the accumulators from acc_buffer
      float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
      float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
      float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
      float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
      float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
      float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5);
      float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6);
      float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7);
      // Multiply-accumulate
      acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
      acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
      acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
      acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
      acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
      acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val);
      acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val);
      acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
      vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
      vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
      vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
      vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
      vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5);
      vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6);
      vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7);
      acc_buffer_ptr += 32;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 1, 20> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
    float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
    float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
    float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
    float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float input_val = *input_ptr;
      input_ptr += input_ptr_increment;
      // Load the accumulators from acc_buffer
      float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
      float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
      float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
      float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
      float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
      // Multiply-accumulate
      acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
      acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
      acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
      acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
      acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
      vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
      vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
      vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
      vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
      acc_buffer_ptr += 20;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 16> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      for (int ic = 0; ic < input_depth; ic++) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 4; i++) {
          acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    float32x2_t filter = vld1_f32(filter_ptr);
    float32x4_t filter_x4 = vcombine_f32(filter, filter);
    int outp = 0;

    // Handle two output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      float32x2_t input_1 = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;
      float32x2_t input_2 = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;
      float32x4_t input = vcombine_f32(input_1, input_2);

      // Load the accumulators from acc_buffer
      float32x4_t acc = vld1q_f32(acc_buffer_ptr);

      // Multiply-accumulate
      acc = vmlaq_f32(acc, input, filter_x4);

      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x2_t input = vld1_f32(input_ptr);
      input_ptr += input_ptr_increment;

      // Load the accumulators from acc_buffer
      float32x2_t acc = vld1_f32(acc_buffer_ptr);

      // Multiply-accumulate
      acc = vmla_f32(acc, input, filter);

      // Store the accumulators back to acc_buffer
      vst1_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 4, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    float32x4_t filter = vld1q_f32(filter_ptr);

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x4_t input = vld1q_f32(input_ptr);
      // Load the accumulators from acc_buffer
      float32x4_t acc = vld1q_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmlaq_f32(acc, input, filter);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
      input_ptr += input_ptr_increment;
    }
  }
};
#endif

// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
                                int input_depth, int input_width,
                                const float* input_data, int pad_width,
                                int depth_multiplier, int filter_width,
                                const float* filter_data,
                                int out_x_buffer_start, int out_x_buffer_end,
                                int output_depth, float* acc_buffer) {}

// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
    int stride, int dilation_factor, int input_depth, int input_width,
    const float* input_data, int pad_width, int depth_multiplier,
    int filter_width, const float* filter_data, int out_x_buffer_start,
    int out_x_buffer_end, int output_depth, float* acc_buffer) {}

// Initializes the accumulator buffer with bias values.
inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
                                       const float* bias_data,
                                       float* acc_buffer) {}

// DepthwiseConv can run with multi threads on the dim specified by thread_dim.
// Each thread processes output elements on dim, thread_dim, in the range of
// [thread_start, thread_end).
// For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
// means that it will calculate DepthwiseConv for output_data[:, 2:5, :, :].
//
// The cpu_flags is currently unused. This
// parameter is included so that the signature matches that required by a
// templated function. Other versions, such as quantized, need this parameter.
inline void DepthwiseConvImpl(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* output_data, const CpuFlags& /* cpu_flags */, int thread_start,
    int thread_end, int thread_dim) {}


}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_