chromium/third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/resize_bilinear.h

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <type_traits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace resize_bilinear {

#ifdef USE_NEON
// These utility functions are split off not just for convenience; most
// incorporate packing or unpacking of data. Keeping them separate also means
// that:
//
// (a) Optimizations can be tried experimentally.
// (b) Optimizations can be specialized for architectures, e.g. Intel vs ARM.

inline int16x8_t Load8IntoLowerS16(const uint8_t* data_ptr) {
  return vreinterpretq_s16_u16(vmovl_u8(vld1_u8(data_ptr)));
}

inline uint16x8_t Move8IntoUpperU16(const uint8x8_t vec_val) {
  // Alternatively one could zip with a zero vector.
  return vshlq_n_u16(vmovl_u8(vec_val), 8);
}

inline uint16x8_t Load8IntoUpperU16(const uint8_t* data_ptr) {
  return Move8IntoUpperU16(vld1_u8(data_ptr));
}

// Extracts the upper 8 bits from each 16-bit integer in a vector register.
// This is performed on a pair of registers because the underlying unzip
// instruction works on pairs.
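//
// Illustrative example: if a 16-bit lane of accum_0 holds 0x1234, the
// corresponding 8-bit lane of *res_0 is 0x12. Together with
// Move8IntoUpperU16 above, the accumulators can be viewed as 8.8 fixed point,
// with this extraction keeping only the (truncated) integer part.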
inline void PairExtractUpper(const uint16x8_t accum_0, const uint16x8_t accum_1,
                             uint8x8_t* res_0, uint8x8_t* res_1) {
  uint8x16x2_t unzipped =
      vuzpq_u8(vreinterpretq_u8_u16(accum_0), vreinterpretq_u8_u16(accum_1));
  *res_0 = vget_low_u8(unzipped.val[1]);
  *res_1 = vget_high_u8(unzipped.val[1]);
}

// This is an exceptional definition.
//
// Modify int16x8_t, adding operators.
//
// In *some* cases of quantized resize bilinear it is reasonable to write code
// directly on vector types, because:
//
// (a) In exact quant resize bilinear, it should be possible to guarantee that
//     arithmetic never overflows.
// (b) When the resize scaling is 2 or 4 or 8 it is possible to guarantee
//     exact accumulation and exact incrementation.
// (c) In quant resize bilinear the choice of unsigned vs signed accumulation
//     and saturated vs unsaturated arithmetic is often unimportant.
//
// This pattern simplifies the code considerably, but it should not be used
// more widely since it can hide important numerical detail.
//
// DO NOT add any "class-like" methods to this struct: only methods that do no
// more than redirect operators to specific intrinsic functions.
struct op_int16x8_t {
  inline op_int16x8_t() = default;
  inline explicit op_int16x8_t(const int16x8_t& initial_val) {
    val = initial_val;
  }
  inline op_int16x8_t& operator=(const int16x8_t& new_val) {
    val = new_val;
    return *this;
  }
  inline op_int16x8_t operator+=(const op_int16x8_t& add_val) {
    val = vaddq_s16(val, add_val.val);
    return *this;
  }
  inline op_int16x8_t operator-=(const op_int16x8_t& sub_val) {
    val = vsubq_s16(val, sub_val.val);
    return *this;
  }
  // This really selects vshlq_n_s16, but requires a longer implementation to
  // convert the shift argument back to a compile-time constant, because on
  // some compilers these intrinsics are macros requiring constant arguments.
  inline op_int16x8_t operator<<=(int32_t left_shift) {
    switch (left_shift) {
      case 1:
        val = vshlq_n_s16(val, 1);
        break;
      case 4:
        val = vshlq_n_s16(val, 4);
        break;
      case 8:
        val = vshlq_n_s16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  // This really selects vshrq_n_s16, but requires a longer implementation to
  // convert the shift argument back to a compile-time constant, because on
  // some compilers these intrinsics are macros requiring constant arguments.
  inline op_int16x8_t operator>>=(int32_t right_shift) {
    switch (right_shift) {
      case 1:
        val = vshrq_n_s16(val, 1);
        break;
      case 4:
        val = vshrq_n_s16(val, 4);
        break;
      case 8:
        val = vshrq_n_s16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  friend inline op_int16x8_t operator+(op_int16x8_t lhs,
                                       const op_int16x8_t& rhs) {
    lhs += rhs;
    return lhs;
  }
  friend inline op_int16x8_t operator-(op_int16x8_t lhs,
                                       const op_int16x8_t& rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32_t left_shift) {
    lhs <<= left_shift;
    return lhs;
  }
  friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32_t right_shift) {
    lhs >>= right_shift;
    return lhs;
  }

  int16x8_t val;
};
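
// Illustrative usage sketch (hypothetical variable names, not code from this
// file): with the operators above, an interpolation with weights 15/16 and
// 1/16, scaled up by 16, can be written as
//
//   op_int16x8_t left(Load8IntoLowerS16(left_ptr));
//   op_int16x8_t right(Load8IntoLowerS16(right_ptr));
//   op_int16x8_t accum = (left << 4) - left + right;  // 15*left + 1*right
//
// which reads like scalar code but lowers directly to vshlq_n_s16, vsubq_s16
// and vaddq_s16, with no hidden saturation or widening; since the inputs are
// 8-bit, the result stays well within int16 range.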

// This is an exceptional definition.
//
// Modify uint16x8_t, adding operators.
//
// Important: See above notes on op_int16x8_t.
struct op_uint16x8_t {
  inline op_uint16x8_t() = default;
  inline explicit op_uint16x8_t(const uint16x8_t initial_val) {
    val = initial_val;
  }
  inline op_uint16x8_t& operator=(const uint16x8_t& new_val) {
    val = new_val;
    return *this;
  }
  inline op_uint16x8_t operator+=(const op_int16x8_t& add_val) {
    val = vaddq_u16(val, vreinterpretq_u16_s16(add_val.val));
    return *this;
  }
  inline op_uint16x8_t operator-=(const op_int16x8_t& sub_val) {
    val = vsubq_u16(val, vreinterpretq_u16_s16(sub_val.val));
    return *this;
  }
  // This really selects vshlq_n_u16, but requires a longer implementation to
  // convert the shift argument back to a compile-time constant, because on
  // some compilers these intrinsics are macros requiring constant arguments.
  inline op_uint16x8_t operator<<=(int32_t left_shift) {
    switch (left_shift) {
      case 1:
        val = vshlq_n_u16(val, 1);
        break;
      case 4:
        val = vshlq_n_u16(val, 4);
        break;
      case 8:
        val = vshlq_n_u16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  // This really selects vshrq_n_u16, but requires a longer implementation to
  // convert the shift argument back to a compile-time constant, because on
  // some compilers these intrinsics are macros requiring constant arguments.
  inline op_uint16x8_t operator>>=(int32_t right_shift) {
    switch (right_shift) {
      case 1:
        val = vshrq_n_u16(val, 1);
        break;
      case 4:
        val = vshrq_n_u16(val, 4);
        break;
      case 8:
        val = vshrq_n_u16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  friend inline op_uint16x8_t operator+(op_uint16x8_t lhs,
                                        const op_int16x8_t& rhs) {
    lhs += rhs;
    return lhs;
  }
  friend inline op_uint16x8_t operator-(op_uint16x8_t lhs,
                                        const op_int16x8_t& rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs,
                                         int32_t left_shift) {
    lhs <<= left_shift;
    return lhs;
  }
  friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs,
                                         int32_t right_shift) {
    lhs >>= right_shift;
    return lhs;
  }

  uint16x8_t val;
};
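
// Note that the arithmetic operators above take a signed op_int16x8_t on the
// right-hand side: the addition and subtraction are modular and
// non-saturating, so the signedness of the increment does not change the
// resulting bits (point (c) above) and it is simply reinterpreted.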

inline op_uint16x8_t VReinterpretQU16S16(const op_int16x8_t& other) {
  op_uint16x8_t ret_val(vreinterpretq_u16_s16(other.val));
  return ret_val;
}
#endif  // USE_NEON

// Optimized resize-bilinear for the special case where the scaling is x8 in
// width and height, and where we can operate on depth-8 blocks at a time. So
// the output blocks are 8x8x8 in width-height-depth.
//
// This optimization is for the half_pixel_centers == true version, for uint8.
// There are versions for NEON and non-NEON compilation.
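//
// Illustrative note: with half_pixel_centers == true and an x8 scale, output
// coordinate y samples the input at (y + 0.5) / 8 - 0.5, so the fractional
// interpolation weights are odd multiples of 1/16; this is what makes exact
// fixed-point accumulation possible here.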
inline void ResizeBilinear888Uint8(int32_t batches, int32_t input_height,
                                   int32_t input_width, int32_t depth,
                                   const uint8_t* input_data,
                                   uint8_t* output_data) {}  // NOLINT(readability/fn_size)

}  // namespace resize_bilinear

#ifdef USE_NEON
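// Accumulates input_ptr[0..depth) * scale into output_ptr[0..depth),
// processing 32, 16, 8 and then 4 channels at a time with NEON
// multiply-accumulate (vmlaq_n_f32) and finishing with a scalar loop for any
// remaining channels.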
inline void ResizeBilinearKernel(const float* input_ptr, int32_t depth,
                                 float scale, float* output_ptr) {
  int ic = 0;
  // Handle 32 input channels at a time.
  for (; ic <= depth - 32; ic += 32) {
    float32x4x2_t input[4];
    for (int i = 0; i < 4; i++) {
      input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
      input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
    }
    float32x4x2_t acc[4];
    for (int i = 0; i < 4; i++) {
      acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
      acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
    }
    for (int i = 0; i < 4; i++) {
      acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
      acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
    }
    for (int i = 0; i < 4; i++) {
      vst1q_f32(output_ptr, acc[i].val[0]);
      vst1q_f32(output_ptr + 4, acc[i].val[1]);
      output_ptr += 8;
    }
    input_ptr += 32;
  }
  // Handle 16 input channels at a time.
  for (; ic <= depth - 16; ic += 16) {
    float32x4x2_t input[2];
    for (int i = 0; i < 2; i++) {
      input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
      input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
    }
    float32x4x2_t acc[2];
    for (int i = 0; i < 2; i++) {
      acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
      acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
    }
    for (int i = 0; i < 2; i++) {
      acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
      acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
    }
    for (int i = 0; i < 2; i++) {
      vst1q_f32(output_ptr, acc[i].val[0]);
      vst1q_f32(output_ptr + 4, acc[i].val[1]);
      output_ptr += 8;
    }
    input_ptr += 16;
  }
  // Handle 8 input channels at a time.
  for (; ic <= depth - 8; ic += 8) {
    float32x4x2_t input;
    input.val[0] = vld1q_f32(input_ptr);
    input.val[1] = vld1q_f32(input_ptr + 4);

    float32x4x2_t acc;
    acc.val[0] = vld1q_f32(output_ptr);
    acc.val[1] = vld1q_f32(output_ptr + 4);
    acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
    acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);

    vst1q_f32(output_ptr, acc.val[0]);
    vst1q_f32(output_ptr + 4, acc.val[1]);

    input_ptr += 8;
    output_ptr += 8;
  }
  // Handle 4 input channels at a time.
  for (; ic <= depth - 4; ic += 4) {
    float32x4_t input = vld1q_f32(input_ptr);
    float32x4_t acc = vld1q_f32(output_ptr);

    acc = vmlaq_n_f32(acc, input, scale);
    vst1q_f32(output_ptr, acc);

    input_ptr += 4;
    output_ptr += 4;
  }
  // Handle 1 input channel at a time.
  for (; ic < depth; ic++) {
    *output_ptr += *input_ptr * scale;
    output_ptr++;
    input_ptr++;
  }
}
#else
inline void ResizeBilinearKernel(const float* input_ptr, int32_t depth,
                                 float scale, float* output_ptr) {}
#endif

inline void ResizeBilinearKernel2x2(int32_t x0, int32_t x1, int32_t y0,
                                    int32_t y1, int32_t x, int32_t y,
                                    int32_t depth, int32_t batch,
                                    const RuntimeShape& input_shape,
                                    const float* input_data,
                                    const RuntimeShape& output_shape,
                                    float* output_data) {}

inline void ResizeBilinear2x2(int32_t batches, int32_t input_height,
                              int32_t input_width, int32_t depth,
                              int32_t output_height, int32_t output_width,
                              const RuntimeShape& input_shape,
                              const float* input_data,
                              const RuntimeShape& output_shape,
                              float* output_data) {}

inline void ResizeBilinearGeneric(
    int32_t batches, int32_t input_height, int32_t input_width, int32_t depth,
    int32_t output_height, int32_t output_width, float height_scale,
    float width_scale, const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& output_shape, float* output_data,
    const bool half_pixel_centers) {}

template <typename T>
inline void ResizeBilinearGenericSmallChannel(
    int32_t batches, int32_t input_height, int32_t input_width, int32_t depth,
    int32_t output_height, int32_t output_width, float height_scale,
    float width_scale, const RuntimeShape& input_shape, const T* input_data,
    const RuntimeShape& output_shape, T* output_data,
    const bool half_pixel_centers) {}

inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const float* input_data,
                           const RuntimeShape& output_size_shape,
                           const int32_t* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           float* output_data) {}

// Note: This is not a universal quantized bilinear. It does not use int8
// or int16 arithmetic.
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const uint8_t* input_data,
                           const RuntimeShape& output_size_shape,
                           const int32_t* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           uint8_t* output_data) {}

// TODO(b/180609127) Create optimized int8 version from uint8. Call from here.
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const int8_t* input_data,
                           const RuntimeShape& unextended_output_size_shape,
                           const int32_t* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           int8_t* output_data) {}

}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_