#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
#include <algorithm>
#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
// Scoped enumeration naming depthwise-convolution implementations.
// NOTE(review): nothing in the reference kernel code of this header refers to
// it; presumably it is consumed by callers/tests that pick a kernel variant —
// confirm against its users.
enum class DepthwiseConvImplementation { … };
// Rounding mode applied when scaling accumulators down to the output type.
// It is the non-type template argument of
// reference_ops::depthwise_conv::DepthwiseConvRound and
// DepthwiseConvBasicKernel; the enumerators kAwayFromZero and kUpward each
// have a dedicated DepthwiseConvRound specialization in this header.
enum class DepthwiseConvOutputRounding { … };
// Scoped enumeration describing depth-multiplication variants.
// NOTE(review): not referenced by the reference kernel code of this header;
// presumably used by specialised kernels to distinguish depth-multiplier
// cases — confirm against its users.
enum class DepthwiseConvDepthMultiplication { … };
namespace reference_ops {
namespace depthwise_conv {
// Primary template of the output-rounding helper, selected at compile time by
// `output_rounding`.
//
// Takes a 32-bit value `x`, a 32-bit `quantized_multiplier` and an integer
// `shift`, and returns a 32-bit result. Explicit specializations for
// DepthwiseConvOutputRounding::kAwayFromZero and ::kUpward follow; which
// definitions are compiled depends on TFLITE_SINGLE_ROUNDING.
// NOTE(review): this generic definition is presumably only reached for
// rounding modes that have no specialization — confirm.
template <DepthwiseConvOutputRounding output_rounding>
inline int32_t DepthwiseConvRound(int32_t x, int32_t quantized_multiplier,
int shift) { … }
#if TFLITE_SINGLE_ROUNDING
// kAwayFromZero specialization used when TFLITE_SINGLE_ROUNDING is enabled.
//
// Scales `x` by `quantized_multiplier` and by two to the power of `shift`:
//   * shift > 0  : `x` is first multiplied by (1 << shift), then passed through
//                  gemmlowp::SaturatingRoundingDoublingHighMul;
//   * shift <= 0 : the high-mul result is divided by 2^(-shift) with
//                  gemmlowp::RoundingDivideByPOT.
// Returns the rescaled 32-bit value.
template <>
inline int32_t DepthwiseConvRound<DepthwiseConvOutputRounding::kAwayFromZero>(
    int32_t x, int32_t quantized_multiplier, int shift) {
  // Split the signed exponent into two non-negative parts; at most one of them
  // ends up non-zero.
  int pre_multiply_shift = 0;
  int post_multiply_shift = 0;
  if (shift > 0) {
    pre_multiply_shift = shift;
  } else {
    post_multiply_shift = -shift;
  }

  // Apply the left part of the exponent before the fixed-point multiply.
  const int32_t pre_scaled = x * (1 << pre_multiply_shift);
  const int32_t high_mul = gemmlowp::SaturatingRoundingDoublingHighMul(
      pre_scaled, quantized_multiplier);

  // Apply the right part of the exponent afterwards.
  return gemmlowp::RoundingDivideByPOT(high_mul, post_multiply_shift);
}
// kUpward specialization used when TFLITE_SINGLE_ROUNDING is enabled.
//
// Forwards `x`, `quantized_multiplier` and `shift` unchanged to
// MultiplyByQuantizedMultiplier and returns its result.
template <>
inline int32_t DepthwiseConvRound<DepthwiseConvOutputRounding::kUpward>(
    int32_t x, int32_t quantized_multiplier, int shift) {
  const int32_t rescaled =
      MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
  return rescaled;
}
#else
// kAwayFromZero specialization compiled when TFLITE_SINGLE_ROUNDING is zero or
// undefined (the #else branch of the conditional above). Same signature as the
// single-rounding variant: 32-bit `x`, 32-bit `quantized_multiplier`, integer
// `shift`, 32-bit result.
template <>
inline int32_t DepthwiseConvRound<DepthwiseConvOutputRounding::kAwayFromZero>(
int32_t x, int32_t quantized_multiplier, int shift) { … }
// kUpward specialization compiled when TFLITE_SINGLE_ROUNDING is zero or
// undefined (the #else branch of the conditional above). Same signature as the
// single-rounding variant: 32-bit `x`, 32-bit `quantized_multiplier`, integer
// `shift`, 32-bit result.
template <>
inline int32_t DepthwiseConvRound<DepthwiseConvOutputRounding::kUpward>(
int32_t x, int32_t quantized_multiplier, int shift) { … }
#endif
// Kernel helper type parameterized on the output rounding mode.
// NOTE(review): presumably holds the reference depthwise-convolution loop and
// uses DepthwiseConvRound<output_rounding> to scale accumulators — confirm
// against its members.
template <DepthwiseConvOutputRounding output_rounding>
struct DepthwiseConvBasicKernel { … };
}
// Entry point of the reference uint8 depthwise convolution.
//
// Parameters:
//   params        - operator parameters (DepthwiseParams).
//   input_shape   - shape describing `input_data`.
//   input_data    - read-only uint8 input buffer.
//   filter_shape  - shape describing `filter_data`.
//   filter_data   - read-only uint8 filter buffer.
//   bias_shape    - shape describing `bias_data`.
//   bias_data     - read-only int32 bias buffer.
//   output_shape  - shape describing `output_data`.
//   output_data   - mutable uint8 buffer; the only non-const pointer argument,
//                   so this is where results are written.
// NOTE(review): presumably dispatches to
// depthwise_conv::DepthwiseConvBasicKernel with a fixed rounding mode —
// confirm against the body.
inline void DepthwiseConv(
const DepthwiseParams& params, const RuntimeShape& input_shape,
const uint8_t* input_data, const RuntimeShape& filter_shape,
const uint8_t* filter_data, const RuntimeShape& bias_shape,
const int32_t* bias_data, const RuntimeShape& output_shape,
uint8_t* output_data) { … }
}
}
#endif