#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include <algorithm>
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
namespace tflite {
namespace kernel_utils {

// NOTE(review): the function bodies in this file are elided in the reviewed
// view. The descriptions below are derived from the signatures only and
// should be confirmed against the implementations.
//
// RnnBatchStep has four overloads: float weights, float weights with an
// auxiliary input, int8 weights, and int8 weights with an auxiliary input.
// They share this basic parameter set:
//   input_ptr_batch          - float input activations for the batch
//                              (presumably batch_size rows of input_size
//                              values; confirm).
//   input_weights_ptr        - weights applied to the input.
//   recurrent_weights_ptr    - weights applied to the hidden state.
//   bias_ptr                 - float bias (presumably num_units values;
//                              confirm).
//   input_size, num_units, batch_size - problem dimensions.
//   output_batch_leading_dim - presumably the stride between consecutive
//                              batch rows of output_ptr_batch, which would
//                              let the output be written into a wider
//                              buffer. Confirm.
//   activation               - fused activation to apply.
//   hidden_state_ptr_batch   - mutable (non-const) hidden state. Presumably
//                              read as the previous step's state and
//                              overwritten with the new one. Confirm.
//   output_ptr_batch         - mutable (non-const) output buffer.

// Float RNN step: float input, float weights, no auxiliary input.
void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
const float* recurrent_weights_ptr, const float* bias_ptr,
int input_size, int num_units, int batch_size,
int output_batch_leading_dim,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) { … }

// Float RNN step with an auxiliary input. Compared with the overload above,
// it additionally takes aux_input_ptr_batch (aux_input_size features per batch
// entry, presumably), weighted by aux_input_weights_ptr.
// NOTE(review): whether a null aux_input_ptr_batch is accepted (e.g. when
// aux_input_size == 0) is not visible here. Confirm before relying on it.
void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
const float* aux_input_ptr_batch,
const float* aux_input_weights_ptr,
const float* recurrent_weights_ptr, const float* bias_ptr,
int input_size, int aux_input_size, int num_units,
int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) { … }

// Hybrid RNN step: float activations with int8 weights. Each weight matrix
// carries a single float scale (input_weights_scale, recurrent_weights_scale).
// The remaining pointer parameters look like caller-provided scratch buffers;
// their roles are inferred from names only and must be confirmed:
//   quantized_input_ptr_batch,
//   quantized_hidden_state_ptr_batch - int8 buffers, presumably holding
//                                      quantized copies of the float input
//                                      and hidden state.
//   scaling_factors                  - float buffer, presumably per-batch
//                                      input quantization scales.
//   asymmetric_quantize_inputs,
//   zero_points                      - presumably zero_points is only
//                                      meaningful when
//                                      asymmetric_quantize_inputs is true.
//   accum_scratch                    - int32 accumulator scratch.
//   row_sums, compute_row_sums       - int32 buffer plus a mutable flag.
//                                      Presumably the weight row sums are
//                                      cached across calls and the flag is
//                                      cleared once they are computed.
void RnnBatchStep(
const float* input_ptr_batch, const int8_t* input_weights_ptr,
float input_weights_scale, const int8_t* recurrent_weights_ptr,
float recurrent_weights_scale, const float* bias_ptr, int input_size,
int num_units, int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { … }

// Hybrid RNN step with an auxiliary input. This is the int8 overload above
// plus three extra parameters:
//   aux_input_ptr_batch           - float auxiliary input (aux_input_size).
//   aux_input_weights_ptr         - int8 weights with a single float scale,
//                                   aux_input_weights_scale.
//   aux_quantized_input_ptr_batch - an extra int8 scratch buffer, presumably
//                                   for the quantized auxiliary input.
// NOTE(review): the handling of a null auxiliary input is not visible here.
void RnnBatchStep(
const float* input_ptr_batch, const int8_t* input_weights_ptr,
float input_weights_scale, const float* aux_input_ptr_batch,
const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
const float* bias_ptr, int input_size, int aux_input_size, int num_units,
int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* aux_quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { … }
}  // namespace kernel_utils
}  // namespace tflite