chromium/third_party/tflite/src/tensorflow/lite/kernels/internal/reference/mul.h

/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_

#include <algorithm>
#include <complex>

#include "tensorflow/lite/kernels/internal/common.h"

namespace tflite {

namespace reference_ops {

// Maximum number of dimensions supported by the broadcast Mul kernels in this
// file. It sizes the NdArrayDesc broadcast descriptors and the extended
// output-dimension arrays passed to BroadcastMulRecursiveDimensions. It must
// therefore be at least the rank handled by the widest kernel here,
// BroadcastMul6DSlow (6 dimensions, per its name).
constexpr int kMaxMulBroadcastDim = 6;

// Element-wise mul for uint8_t data. It is meant to serve both as the inner
// loop of the broadcast Mul and as the non-broadcast Mul.
//
// Parameters:
//   size:        number of elements to process; presumably the element
//                count of each of the three buffers (TODO confirm).
//   params:      arithmetic parameters; presumably the uint8 quantization
//                offsets/multiplier and activation clamp (TODO confirm
//                against the ArithmeticParams definition).
//   input1_data: first operand buffer.
//   input2_data: second operand buffer.
//   output_data: destination buffer.
//
// NOTE(review): the body is empty in this copy, so as written nothing is
// written to output_data. The implementation was presumably elided. Confirm
// against upstream before relying on this function.
inline void MulElementwise(int size, const ArithmeticParams& params,
                           const uint8_t* input1_data,
                           const uint8_t* input2_data, uint8_t* output_data) {}

// Element-wise multiplication for a generic element type T. Judging by its
// name (not a Broadcast* kernel), this is presumably the non-broadcast
// variant. The three shapes would then describe the same number of elements,
// and `params` would supply any activation clamp. TODO confirm; neither
// behavior is visible here.
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
template <typename T>
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const T* input1_data,
                const RuntimeShape& input2_shape, const T* input2_data,
                const RuntimeShape& output_shape, T* output_data) {}

// Overload of Mul for std::complex<float> operands and output. It is
// presumably the non-broadcast element-wise complex product. Whether
// `params` (e.g. an activation clamp) is consulted for complex values is not
// visible here; TODO confirm.
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape,
                const std::complex<float>* input1_data,
                const RuntimeShape& input2_shape,
                const std::complex<float>* input2_data,
                const RuntimeShape& output_shape,
                std::complex<float>* output_data) {}

// Overload of Mul for uint8_t operands and output. This non-template overload
// takes precedence over Mul<T> for uint8_t arguments. It presumably handles
// the quantized uint8 path, e.g. by delegating to MulElementwise over the
// flattened buffers (TODO confirm).
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const uint8_t* input1_data,
                const RuntimeShape& input2_shape, const uint8_t* input2_data,
                const RuntimeShape& output_shape, uint8_t* output_data) {}

// Helper for the broadcast Mul kernels, templated on the element type T and a
// binary functor F. Its role below is inferred from the signature only:
// presumably it walks the output one dimension at a time, advancing the
// running offsets and applying `binary_func` to pairs of input elements.
// TODO confirm; the body is not visible here.
//
// Parameters:
//   params:            arithmetic parameters (use not visible here).
//   dimension:         index of the dimension currently being visited,
//                      presumably in [0, kMaxMulBroadcastDim).
//   input1_data / input2_data / output_data: base pointers of the buffers.
//   input1_offset_p / input2_offset_p / output_offset: running element
//                      offsets into those buffers, passed by pointer so the
//                      helper can update them.
//   desc1 / desc2:     broadcast descriptors of the two inputs, sized by
//                      kMaxMulBroadcastDim.
//   extended_output_shape_dims: output extents, kMaxMulBroadcastDim entries.
//   binary_func:       element-wise operation to apply.
//
// NOTE(review): the body is empty in this copy, so as written it neither
// recurses nor writes output. The implementation was presumably elided.
template <typename T, typename F>
void BroadcastMulRecursiveDimensions(
    const ArithmeticParams& params, int dimension, const T* input1_data,
    const T* input2_data, T* output_data, size_t* input1_offset_p,
    size_t* input2_offset_p, size_t* output_offset,
    const NdArrayDesc<kMaxMulBroadcastDim>& desc1,
    const NdArrayDesc<kMaxMulBroadcastDim>& desc2,
    const int32_t extended_output_shape_dims[kMaxMulBroadcastDim],
    F binary_func) {}

// Broadcasting Mul for uint8_t operands and output. Per its name it supports
// shapes of up to 6 dimensions (kMaxMulBroadcastDim), and it is the "slow",
// general reference path. This non-template overload takes precedence over
// the templated BroadcastMul6DSlow for uint8_t arguments, presumably to apply
// uint8 quantization (TODO confirm).
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
inline void BroadcastMul6DSlow(const ArithmeticParams& params,
                               const RuntimeShape& input1_shape,
                               const uint8_t* input1_data,
                               const RuntimeShape& input2_shape,
                               const uint8_t* input2_data,
                               const RuntimeShape& output_shape,
                               uint8_t* output_data) {}

// Generic broadcasting Mul for element type T, for shapes of up to 6
// dimensions (per its name and kMaxMulBroadcastDim). The `unextended_*`
// shapes are presumably allowed to have lower rank and are padded to
// kMaxMulBroadcastDim internally (TODO confirm).
//
// Overload gating: the enable_if below removes this template from overload
// resolution when is_small_integer<T>::value is true, unless the caller
// explicitly passes enable_for_short_integers = true.
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
template <typename T,
          // Set to true explicitly to use this generic (unquantized) mul on
          // small integer types; otherwise they are excluded below.
          bool enable_for_short_integers = false>
inline typename std::enable_if<
    !is_small_integer<T>::value || enable_for_short_integers, void>::type
BroadcastMul6DSlow(const ArithmeticParams& params,
                   const RuntimeShape& unextended_input1_shape,
                   const T* input1_data,
                   const RuntimeShape& unextended_input2_shape,
                   const T* input2_data,
                   const RuntimeShape& unextended_output_shape,
                   T* output_data) {}

// Broadcasting Mul for std::complex<float> operands and output, for shapes of
// up to 6 dimensions (per its name). The `unextended_*` shapes are presumably
// allowed to have lower rank and are padded internally (TODO confirm).
// Whether `params` is consulted for complex values is not visible here.
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
inline void BroadcastMul6DSlow(const ArithmeticParams& params,
                               const RuntimeShape& unextended_input1_shape,
                               const std::complex<float>* input1_data,
                               const RuntimeShape& unextended_input2_shape,
                               const std::complex<float>* input2_data,
                               const RuntimeShape& unextended_output_shape,
                               std::complex<float>* output_data) {}

// Broadcasting Mul for element type T. Per its name it targets shapes of up
// to 4 dimensions. It presumably exists for callers that still use the 4D
// entry point and may forward to BroadcastMul6DSlow (TODO confirm; not
// visible here).
//
// NOTE(review): the body is empty in this copy, so as written output_data is
// never written. The implementation was presumably elided. Confirm upstream.
template <typename T>
inline void BroadcastMul4DSlow(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_