#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#include <algorithm>
#include <complex>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_ops {
// Maximum number of dimensions handled by the broadcasting Mul helpers in
// this header. It is used as the rank parameter of the NdArrayDesc
// descriptors and as the declared length of the `extended_output_shape_dims`
// array taken by BroadcastMulRecursiveDimensions.
constexpr int kMaxMulBroadcastDim = …;
// Element-wise Mul over flat uint8_t buffers.
//
//   size        - number of elements to process (presumably the common length
//                 of all three buffers; confirm against the body).
//   params      - arithmetic parameters for the op. NOTE(review): for uint8
//                 this presumably carries quantization offsets/multipliers and
//                 activation bounds; confirm against the body.
//   input1_data - first operand (read-only).
//   input2_data - second operand (read-only).
//   output_data - non-const output buffer.
inline void MulElementwise(int size, const ArithmeticParams& params,
const uint8_t* input1_data,
const uint8_t* input2_data, uint8_t* output_data) { … }
// Generic Mul for element type T, taking a RuntimeShape for each operand and
// the output.
//
// Non-template overloads for std::complex<float> and uint8_t are declared in
// this header. Overload resolution prefers those over this template for
// exactly matching argument types.
// NOTE(review): presumably the non-broadcasting path, i.e. input shapes are
// expected to match the output shape. The BroadcastMul* functions take the
// broadcasting cases; confirm against the body.
template <typename T>
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) { … }
// Mul overload for std::complex<float> tensors.
//
// This is a non-template function, so for std::complex<float>* arguments
// C++ overload resolution selects it over the generic `Mul<T>` template.
// NOTE(review): whether `params` (e.g. activation bounds) is consulted for
// complex values is not visible here; confirm against the body.
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const std::complex<float>* input1_data,
const RuntimeShape& input2_shape,
const std::complex<float>* input2_data,
const RuntimeShape& output_shape,
std::complex<float>* output_data) { … }
// Mul overload for uint8_t tensors.
//
// This is a non-template function, so for uint8_t* arguments it is chosen
// over `Mul<uint8_t>`. NOTE(review): presumably this is the quantized path
// and delegates to MulElementwise; confirm against the body.
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const uint8_t* input1_data,
const RuntimeShape& input2_shape, const uint8_t* input2_data,
const RuntimeShape& output_shape, uint8_t* output_data) { … }
// Helper for the broadcasting Mul variants. It walks output dimensions
// starting at `dimension` and applies `binary_func` to element data.
// NOTE(review): the name suggests it recurses one dimension at a time up to
// kMaxMulBroadcastDim; confirm against the body.
//
//   dimension                  - index of the dimension to process.
//   input1_offset_p,
//   input2_offset_p,
//   output_offset              - element offsets passed by pointer, so they
//                                are presumably updated in place across
//                                recursive calls (confirm). Note the
//                                inconsistent `_p` suffix: only the input
//                                offsets carry it.
//   desc1, desc2               - broadcast descriptors for the two inputs.
//   extended_output_shape_dims - output extents. Although it is declared
//                                with bound kMaxMulBroadcastDim, as a function
//                                parameter it decays to `const int32_t*`, so
//                                the compiler does not enforce that length.
//   binary_func                - callable of type F used to combine elements.
//                                Its exact signature is not visible here.
template <typename T, typename F>
void BroadcastMulRecursiveDimensions(
const ArithmeticParams& params, int dimension, const T* input1_data,
const T* input2_data, T* output_data, size_t* input1_offset_p,
size_t* input2_offset_p, size_t* output_offset,
const NdArrayDesc<kMaxMulBroadcastDim>& desc1,
const NdArrayDesc<kMaxMulBroadcastDim>& desc2,
const int32_t extended_output_shape_dims[kMaxMulBroadcastDim],
F binary_func) { … }
// Broadcasting Mul for uint8_t tensors.
//
// This is a non-template overload, so uint8_t* arguments resolve here rather
// than to the `BroadcastMul6DSlow<T, ...>` template. That template is also
// SFINAE-disabled when `is_small_integer<T>` is true, which presumably
// includes uint8_t (confirm).
// NOTE(review): "6D" in the name presumably refers to the maximum supported
// rank; confirm against kMaxMulBroadcastDim.
inline void BroadcastMul6DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const uint8_t* input1_data,
const RuntimeShape& input2_shape,
const uint8_t* input2_data,
const RuntimeShape& output_shape,
uint8_t* output_data) { … }
// Generic broadcasting Mul for element type T.
//
// Participates in overload resolution only when
//   !is_small_integer<T>::value || enable_for_short_integers
// holds. So for T where `is_small_integer<T>` is true this template is removed
// (SFINAE) unless the caller explicitly instantiates it with
// `enable_for_short_integers = true`. NOTE(review): `is_small_integer` is
// defined elsewhere; presumably it matches 8/16-bit integer types, which have
// dedicated quantized paths. Confirm.
//
// The shape parameters are named `unextended_*`, which suggests they are
// extended to a fixed rank internally. Confirm against the body.
template <typename T,
bool enable_for_short_integers = false>
inline typename std::enable_if<
!is_small_integer<T>::value || enable_for_short_integers, void>::type
BroadcastMul6DSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const T* input1_data,
const RuntimeShape& unextended_input2_shape,
const T* input2_data,
const RuntimeShape& unextended_output_shape,
T* output_data) { … }
// Broadcasting Mul overload for std::complex<float> tensors.
//
// This is a non-template function, so for std::complex<float>* arguments it
// is preferred over the `BroadcastMul6DSlow<T, ...>` template.
// NOTE(review): whether `params` is consulted for complex values is not
// visible here; confirm against the body.
inline void BroadcastMul6DSlow(const ArithmeticParams& params,
const RuntimeShape& unextended_input1_shape,
const std::complex<float>* input1_data,
const RuntimeShape& unextended_input2_shape,
const std::complex<float>* input2_data,
const RuntimeShape& unextended_output_shape,
std::complex<float>* output_data) { … }
// Broadcasting Mul for element type T.
// NOTE(review): "4D" in the name presumably limits shapes to at most four
// dimensions, and this may be a legacy entry point alongside the
// BroadcastMul6DSlow overloads. Confirm against the body and callers.
template <typename T>
inline void BroadcastMul4DSlow(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) { … }
}
}
#endif