#include "tensorflow/lite/kernels/lstm_eval.h"
#include <math.h>
#include <string.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/ruy.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {
namespace {
// File-local helper taking a float matrix, float vector data and a backend
// context. The body is not shown here, so the notes below are read off the
// signature only.
//
// Parameters:
//   matrix, vector, result - read-only float buffers.
//     NOTE(review): `result` is const while `output` is writable, so
//     presumably `result` carries the values being accumulated onto and the
//     sum is written to `output` - confirm against the body.
//   output  - writable float destination buffer.
//   m_rows, m_cols, n_batch - integer dimensions.
//     NOTE(review): presumably `matrix` is m_rows x m_cols and `vector` holds
//     n_batch vectors of m_cols elements - TODO confirm.
//   cpu_backend_context - backend context pointer; how it is used is not
//     visible from the signature.
void MatrixBatchVectorMultiplyAccumulate(
const float* matrix, const float* vector, const float* result,
float* output, int m_rows, int m_cols, int n_batch,
CpuBackendContext* cpu_backend_context) { … }
// File-local helper that receives writable int32 row-sum buffers together
// with the matching read-only int8 weight buffers. The body is not shown
// here; the notes below come from the signature only.
//
// Parameters:
//   *_row_sums (int32_t*) - one writable buffer per weight matrix
//     (input_to_*, aux_input_to_*, recurrent_to_* for the input, forget, cell
//     and output gates, plus projection_weights_row_sums), followed by a
//     separate `row_sums` pointer.
//     NOTE(review): presumably the per-gate pointers alias regions of
//     `row_sums` - confirm with the caller.
//   n_cell, n_input, n_aux_input, n_output - integer dimensions.
//   *_weights_ptr (const int8_t*) - read-only quantized weights, one per
//     row-sum buffer above.
//   use_cifg - boolean flag.
//     NOTE(review): presumably input-gate row sums are skipped when set -
//     TODO confirm.
//   aux_input_ptr - read-only float pointer.
//     NOTE(review): likely only tested for null to decide whether aux row
//     sums are needed - confirm.
//   recurrent_to_{input,forget,cell,output}_is_diag - booleans that default
//     to false, so existing callers may omit them.
void ComputeRowSums(
int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
int n_input, int n_aux_input, int n_output,
const int8_t* input_to_input_weights_ptr,
const int8_t* input_to_forget_weights_ptr,
const int8_t* input_to_cell_weights_ptr,
const int8_t* input_to_output_weights_ptr,
const int8_t* aux_input_to_input_weights_ptr,
const int8_t* aux_input_to_forget_weights_ptr,
const int8_t* aux_input_to_cell_weights_ptr,
const int8_t* aux_input_to_output_weights_ptr,
const int8_t* recurrent_to_input_weights_ptr,
const int8_t* recurrent_to_forget_weights_ptr,
const int8_t* recurrent_to_cell_weights_ptr,
const int8_t* recurrent_to_output_weights_ptr,
const int8_t* projection_weights_ptr, bool use_cifg,
const float* aux_input_ptr, bool recurrent_to_input_is_diag = false,
bool recurrent_to_forget_is_diag = false,
bool recurrent_to_cell_is_diag = false,
bool recurrent_to_output_is_diag = false) { … }
// Returns a float derived from the read-only `tensor`.
// NOTE(review): the name suggests this is the tensor's quantization scale;
// the body is not shown here, so null handling and the field actually read
// are unconfirmed.
inline float GetTensorScale(const TfLiteTensor* tensor) { … }
// File-local float helper for a single LSTM gate (judging by its name). The
// body is not shown here; the notes below come from the signature only.
//
// Parameters:
//   input / aux_input / output_state / cell_state - read-only float
//     activations, each paired with a read-only weight buffer
//     (input_to_gate_weights, aux_input_to_gate_weights,
//     recurrent_to_gate_weights, cell_to_gate_weights).
//   layer_norm_coefficients, gate_bias - read-only float buffers.
//     NOTE(review): some of these pointers are presumably optional (null
//     when peephole / layer norm / aux input are unused) - confirm.
//   n_batch, n_input, n_aux_input, n_output, n_cell - integer dimensions.
//   activation - fused activation enum value.
//   gate   - writable float buffer.
//   is_input_all_zeros, is_aux_input_all_zeros - boolean hints supplied by
//     the caller.
//   output - second writable float buffer.
//     NOTE(review): its role relative to `gate` cannot be determined from
//     the signature - TODO confirm.
//   recurrent_is_diag - boolean flag for the recurrent weights.
//   context - backend context pointer.
inline void CalculateLstmGateFloat(
const float* input, const float* input_to_gate_weights,
const float* aux_input, const float* aux_input_to_gate_weights,
const float* output_state, const float* recurrent_to_gate_weights,
const float* cell_state, const float* cell_to_gate_weights,
const float* layer_norm_coefficients, const float* gate_bias,
const int n_batch, const int n_input, const int n_aux_input,
const int n_output, const int n_cell,
const TfLiteFusedActivation activation, float* gate,
const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
float* output, bool recurrent_is_diag, CpuBackendContext* context) { … }
// File-local float helper that takes the cell state and three gate buffers.
// The body is not shown here; the notes below come from the signature only.
//
// Parameters:
//   n_batch, n_cell - integer dimensions.
//   cell_state  - writable float buffer.
//   input_gate, cell_gate - read-only float buffers.
//   forget_gate - writable (non-const) float buffer.
//     NOTE(review): being non-const, it may be overwritten / reused as
//     scratch by the body - callers should not rely on its contents
//     afterwards until this is confirmed.
//   use_cifg - boolean flag.
//     NOTE(review): presumably `input_gate` is ignored when set - confirm.
//   clip - float clipping value.
//     NOTE(review): the convention for "no clipping" is not visible here.
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
const float* input_gate, float* forget_gate,
const float* cell_gate, bool use_cifg, float clip) { … }
// File-local float helper that takes the cell state, the output gate and
// optional-looking projection data. The body is not shown here; the notes
// below come from the signature only.
//
// Parameters:
//   n_batch, n_cell, n_output - integer dimensions.
//   cell_state, output_gate - read-only float buffers.
//   activation - fused activation enum value.
//   projection_weights, projection_bias - read-only float buffers.
//     NOTE(review): presumably null when the model has no projection layer -
//     confirm.
//   proj_clip - float clipping value for the projected output (by name).
//   output_state - writable float buffer.
//   scratch, projection_bias_scratch - writable float work buffers; required
//     sizes are not visible from the signature.
//   context - backend context pointer.
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
const float* cell_state, const float* output_gate,
TfLiteFusedActivation activation,
const float* projection_weights,
const float* projection_bias,
const float proj_clip, float* output_state,
float* scratch, float* projection_bias_scratch,
CpuBackendContext* context) { … }
// File-local hybrid (int8 data with float scales) helper for a single LSTM
// gate, judging by its name and parameter types. The body is not shown here;
// the notes below come from the signature only.
//
// Parameter groups:
//   input / aux_input / output_state - each supplied as read-only int8 data
//     plus a read-only float `*_sf` array and a read-only int32 `*_zp`
//     array. NOTE(review): presumably per-batch scaling factors and zero
//     points - confirm.
//   output_state_float - read-only float view of the output state.
//     NOTE(review): likely used together with `recurrent_to_gate_diag` when
//     `recurrent_is_diag` is set - TODO confirm.
//   *_to_gate_weights (const int8_t*) with a float `*_weights_scale` each;
//     the input and recurrent weights additionally take a read-only uint8
//     `*_ledger` pointer. NOTE(review): the ledger presumably describes
//     sparse weights and may be null - confirm.
//   *_to_gate_row_sums (int32_t*) - writable row-sum buffers.
//   cell_state (float), cell_to_gate_weights (int8) with its float scale.
//   layer_norm_coefficients, gate_bias - read-only float buffers.
//   n_batch, n_input, n_aux_input, n_output, n_cell - integer dimensions.
//   activation - fused activation enum value.
//   gate - writable float buffer.
//   is_input_all_zeros, is_aux_input_all_zeros, is_output_state_all_zeros -
//     boolean hints supplied by the caller.
//   compute_row_sums - writable bool pointer.
//     NOTE(review): may be read and/or updated by the body - confirm.
//   context - backend context pointer.
//   scratch0, scratch1 (float*), accum_scratch (int32_t*) - writable work
//     buffers; required sizes are not visible from the signature.
//   recurrent_is_diag - boolean flag for the recurrent weights.
void CalculateLstmGateHybrid(
const int8_t* input, const float* input_sf, const int32_t* input_zp,
const int8_t* input_to_gate_weights,
const uint8_t* input_to_gate_weights_ledger,
const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
const int8_t* aux_input, const float* aux_input_sf,
const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
const float aux_input_to_gate_weights_scale,
int32_t* aux_input_to_gate_row_sums,
const int8_t* output_state, const float* output_state_float,
const float* output_state_sf, const int32_t* output_state_zp,
const int8_t* recurrent_to_gate_weights,
const float* recurrent_to_gate_diag,
const uint8_t* recurrent_to_gate_weights_ledger,
const float recurrent_to_gate_weights_scale,
int32_t* recurrent_to_gate_row_sums,
const float* cell_state, const int8_t* cell_to_gate_weights,
const float cell_to_gate_weights_scale,
const float* layer_norm_coefficients, const float* gate_bias,
const int n_batch, const int n_input, const int n_aux_input,
const int n_output, const int n_cell,
const TfLiteFusedActivation activation,
float* gate,
const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
const bool is_output_state_all_zeros, bool* compute_row_sums,
CpuBackendContext* context,
float* scratch0,
float* scratch1,
int32_t* accum_scratch,
bool recurrent_is_diag) { … }
// File-local hybrid helper that takes the float cell state and output gate
// plus int8 projection weights. The body is not shown here; the notes below
// come from the signature only.
//
// Parameters:
//   n_batch, n_cell, n_output - integer dimensions.
//   cell_state, output_gate - read-only float buffers.
//   activation - fused activation enum value.
//   projection_weights (const int8_t*), projection_weights_ledger
//     (const uint8_t*), projection_weights_scale (float), projection_bias
//     (const float*). NOTE(review): presumably null / unused when there is
//     no projection layer - confirm.
//   proj_clip - float clipping value.
//   output_state - writable float buffer.
//   asymmetric_quantize_inputs - boolean flag.
//   projection_weights_row_sums - writable int32 buffer.
//   compute_row_sums - writable bool pointer.
//     NOTE(review): may be read and/or updated by the body - confirm.
//   context - backend context pointer.
//   scratch0 (float*), scratch1 (int8_t*), scratch2 (float*),
//   scratch3 (int32_t*), scratch4 (int32_t*) - writable work buffers. Note
//     the differing element types: they are not interchangeable. Required
//     sizes are not visible from the signature.
void CalculateLstmOutputHybrid(
int n_batch, int n_cell, int n_output, const float* cell_state,
const float* output_gate, TfLiteFusedActivation activation,
const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
float projection_weights_scale, const float* projection_bias,
const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
int32_t* projection_weights_row_sums, bool* compute_row_sums,
CpuBackendContext* context, float* scratch0, int8_t* scratch1,
float* scratch2, int32_t* scratch3, int32_t* scratch4) { … }
// File-local fully-integer helper for a single LSTM gate, with int8 inputs
// and an int16 gate buffer (per the parameter types). The body is not shown
// here; the notes below come from the signature only.
//
// Parameter groups:
//   input, output_state (const int8_t*) - each paired with int8 weights, an
//     int32 bias array and an int32 `*_scale_a` / `*_scale_b` pair.
//     NOTE(review): the a/b pairs are presumably a fixed-point multiplier and
//     shift - TODO confirm.
//   cell_state, cell_to_gate_weights (const int16_t*) with their own
//     scale_a / scale_b pair.
//   layer_norm_coefficients (const int16_t*), layer_norm_bias
//     (const int32_t*), layer_norm_input_scale_a / _b and
//     layer_norm_variance_guard (int32).
//   n_batch, n_input, n_output, n_cell - integer dimensions.
//   activation - fused activation enum value.
//   gate - writable int16 buffer.
//   context - backend context pointer.
//   scratch5 - writable int32 work buffer; required size is not visible from
//     the signature.
void CalculateLstmGateInteger8x8_16(
const int8_t* input, const int8_t* input_to_gate_weights,
const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
const int32_t input_to_gate_scale_b,
const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
const int32_t* recurrent_to_gate_bias,
const int32_t recurrent_to_gate_scale_a,
const int32_t recurrent_to_gate_scale_b,
const int16_t* cell_state, const int16_t* cell_to_gate_weights,
const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
const int32_t layer_norm_input_scale_a,
const int32_t layer_norm_input_scale_b,
const int32_t layer_norm_variance_guard,
const int n_batch, const int n_input, const int n_output, const int n_cell,
const TfLiteFusedActivation activation,
int16_t* gate,
CpuBackendContext* context,
int32_t* scratch5) { … }
// File-local integer helper that takes the int16 cell state and three int16
// gate buffers. The body is not shown here; the notes below come from the
// signature only.
//
// Parameters:
//   n_batch, n_cell - integer dimensions.
//   cell_state - writable int16 buffer.
//   cell_state_scale - int32 scale parameter.
//     NOTE(review): presumably a power-of-two exponent - confirm.
//   input_gate, cell_gate - read-only int16 buffers.
//   forget_gate - writable (non-const) int16 buffer.
//     NOTE(review): being non-const, it may be overwritten / reused as
//     scratch by the body - confirm before relying on its contents.
//   use_cifg - boolean flag.
//   clip - int16 clipping value.
//     NOTE(review): the convention for "no clipping" is not visible here.
void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
int32_t cell_state_scale, const int16_t* input_gate,
int16_t* forget_gate, const int16_t* cell_gate,
bool use_cifg, int16_t clip) { … }
// File-local integer helper that takes the int16 cell state and output gate
// and writes an int8 output state. The body is not shown here; the notes
// below come from the signature only.
//
// Parameters:
//   n_batch, n_cell, n_output - integer dimensions.
//   cell_state (const int16_t*), cell_state_scale (int32).
//   output_gate - read-only int16 buffer.
//   hidden_scale_a, hidden_scale_b, hidden_zp - int32 quantization
//     parameters for the hidden value (by name).
//   projection_weights (const int8_t*), proj_scale_a / proj_scale_b (int32),
//     projection_bias (const int32_t*). NOTE(review): presumably optional
//     when there is no projection layer - confirm.
//   output_state_zp - int32 zero point (by name).
//   quantized_proj_clip - int8 clipping value.
//   output_state - writable int8 buffer.
//   context - backend context pointer.
//   scratch0 (int16_t*), scratch1 (int8_t*), scratch2 (int32_t*) - writable
//     work buffers of differing element types; required sizes are not
//     visible from the signature.
void CalculateLstmOutputInteger8x8_16(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
int32_t cell_state_scale, const int16_t* output_gate,
int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
const int8_t* projection_weights, int32_t proj_scale_a,
int32_t proj_scale_b, const int32_t* projection_bias,
int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
int32_t* scratch2) { … }
// File-local fully-integer helper for a single LSTM gate in the 8x8_8
// variant (by name). The body is not shown here; the notes below come from
// the signature only.
//
// Compared with CalculateLstmGateInteger8x8_16, this signature takes explicit
// zero points, `*_times_weights_*` quantization parameters and int8 scratch
// buffers, and has no cell-state (peephole) inputs and no backend context.
//
// Parameter groups:
//   input (const int8_t*) with input_zp, int8 weights, an
//     input_to_gate_scale_a / _b pair and input_times_weights_scale_a / _b /
//     _zp.
//   output_state (const int8_t*) with output_state_zp, int8 recurrent
//     weights, a recurrent_to_gate_scale_a / _b pair and
//     output_state_times_weights_scale_a / _b / _zp.
//     NOTE(review): the `*_times_weights_*` values presumably describe the
//     quantization of the intermediate matmul results - TODO confirm.
//   layer_norm_gate_weight (const int16_t*) with its scale_a / scale_b pair;
//     gate_bias (const int32_t*).
//   n_batch, n_input, n_output, n_cell - integer dimensions.
//   activation - fused activation enum value.
//   gate - writable int16 buffer.
//   scratch0, scratch1 - writable int8 work buffers; required sizes are not
//     visible from the signature.
void CalculateLstmGateInteger8x8_8(
const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
const int32_t input_times_weights_scale_a,
const int32_t input_times_weights_scale_b,
const int32_t input_times_weights_zp,
const int8_t* output_state, const int32_t output_state_zp,
const int8_t* recurrent_to_gate_weight,
const int32_t recurrent_to_gate_scale_a,
const int32_t recurrent_to_gate_scale_b,
const int32_t output_state_times_weights_scale_a,
const int32_t output_state_times_weights_scale_b,
const int32_t output_state_times_weights_zp,
const int16_t* layer_norm_gate_weight,
const int32_t layer_norm_gate_scale_a,
const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
const int n_batch, const int n_input, const int n_output, const int n_cell,
const TfLiteFusedActivation activation,
int16_t* gate,
int8_t* scratch0, int8_t* scratch1) { … }
// File-local integer helper for the 8x8_8 variant that takes the int16 cell
// state and output gate and writes an int8 output state. The body is not
// shown here; the notes below come from the signature only.
//
// Parameters:
//   n_batch, n_cell, n_output - integer dimensions.
//   cell_state, output_gate - read-only int16 buffers.
//   projection_weights (const int8_t*), proj_scale_a / proj_scale_b (int32),
//     projection_bias (const int32_t*).
//   output_state_zp - int32 zero point (by name).
//   quantized_proj_clip - clipping value, declared int32_t here. Note that
//     CalculateLstmOutputInteger8x8_16 declares the same-named parameter as
//     int8_t.
//   output_state - writable int8 buffer.
//   scratch - writable int16 work buffer; required size is not visible from
//     the signature.
void CalculateLstmOutputInteger8x8_8(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
const int16_t* output_gate, const int8_t* projection_weights,
int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
int16_t* scratch) { … }
// File-local float LSTM step (by name: one step over a batch of inputs).
// The body is not shown here; the notes below come from the signature only.
//
// Parameter groups (all weights/biases are read-only float pointers):
//   input_ptr and input_to_{input,forget,cell,output}_weights_ptr.
//   aux_input_ptr and aux_input_to_{input,forget,cell,output}_weights_ptr.
//   recurrent_to_{input,forget,cell,output}_weights_ptr.
//   cell_to_{input,forget,output}_weights_ptr.
//   {input,forget,cell,output}_layer_norm_coefficients_ptr.
//   {input,forget,cell,output}_gate_bias_ptr.
//   projection_weights_ptr, projection_bias_ptr.
//     NOTE(review): several of these are presumably optional and passed as
//     null (CIFG, peephole, layer norm, aux input, projection) - the exact
//     null-handling rules must be confirmed against the body.
//   params - read-only TfLiteLSTMParams.
//   n_batch, n_cell, n_input, n_aux_input, n_output - integer dimensions.
//   output_batch_leading_dim - integer.
//     NOTE(review): presumably the per-batch stride of `output_ptr`, which
//     may exceed n_output - TODO confirm.
//   output_state_ptr, cell_state_ptr - writable float state buffers.
//   scratch0..scratch4 - writable float work buffers; required sizes are not
//     visible from the signature.
//   output_ptr - writable float output buffer.
//   recurrent_to_{input,forget,cell,output}_is_diag - boolean flags for the
//     recurrent weights.
//   context - backend context pointer.
inline void LstmStepFloat(
const float* input_ptr, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* scratch0,
float* scratch1, float* scratch2, float* scratch3, float* scratch4,
float* output_ptr, bool recurrent_to_input_is_diag,
bool recurrent_to_forget_is_diag, bool recurrent_to_cell_is_diag,
bool recurrent_to_output_is_diag, CpuBackendContext* context) { … }
// File-local hybrid LSTM step (by name): float activations with int8
// weights. The body is not shown here; the notes below come from the
// signature only.
//
// Parameter groups:
//   input_ptr, aux_input_ptr - read-only float activations.
//   input_to_* weights - const int8_t* data, a const uint8_t* ledger and a
//     float scale per gate.
//   aux_input_to_* weights - const int8_t* data and a float scale per gate
//     (no ledger parameter).
//   recurrent_to_* weights - const int8_t* data, a const float* `*_diag`
//     pointer, a const uint8_t* ledger and a float scale per gate.
//     NOTE(review): the `*_diag` pointer is presumably used instead of the
//     int8 data when the matching `recurrent_to_*_is_diag` flag is set -
//     TODO confirm.
//   cell_to_{input,forget,output} weights - const int8_t* data with a float
//     scale each.
//   layer-norm coefficients and gate biases - read-only float pointers.
//   projection weights - const int8_t* data, const uint8_t* ledger, float
//     scale; projection_bias_ptr is a read-only float pointer.
//   params - read-only TfLiteLSTMParams.
//   n_batch, n_cell, n_input, n_aux_input, n_output,
//   output_batch_leading_dim - integers.
//   scratch0..scratch3, input_sf, aux_input_sf, output_state_sf,
//   scaling_factors_scratch, recovered_cell_weights - writable float
//     buffers.
//   quantized_input_ptr, quantized_aux_input_ptr,
//   quantized_output_state_ptr, quantized_output_scratch - writable int8
//     buffers.
//   output_state_ptr, cell_state_ptr - writable float state buffers.
//   accum_scratch_ptr - writable int32 buffer.
//   output_ptr - writable float output buffer.
//   input_zp, aux_input_zp, output_state_zp - writable int32 buffers.
//   row_sums (int32_t*) with row_sums_size (int).
//   compute_row_sums - writable bool pointer.
//     NOTE(review): may be read and/or updated by the body - confirm.
//   asymmetric_quantize_inputs - boolean flag.
//   recurrent_to_{input,forget,cell,output}_is_diag - boolean flags.
//   context - backend context pointer.
//
// NOTE(review): which pointers may legally be null, and the required sizes
// of the scratch buffers, cannot be determined from the signature.
inline void LstmStepHybrid(
const float* input_ptr, const int8_t* input_to_input_weights_ptr,
const uint8_t* input_to_input_weights_ledger_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
const uint8_t* input_to_forget_weights_ledger_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr,
const uint8_t* input_to_cell_weights_ledger_ptr,
float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
const uint8_t* input_to_output_weights_ledger_ptr,
float input_to_output_weights_scale, const float* aux_input_ptr,
const int8_t* aux_input_to_input_weights_ptr,
float aux_input_to_input_weights_scale,
const int8_t* aux_input_to_forget_weights_ptr,
float aux_input_to_forget_weights_scale,
const int8_t* aux_input_to_cell_weights_ptr,
float aux_input_to_cell_weights_scale,
const int8_t* aux_input_to_output_weights_ptr,
float aux_input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
const float* recurrent_to_input_diag,
const uint8_t* recurrent_to_input_weights_ledger_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
const float* recurrent_to_forget_diag,
const uint8_t* recurrent_to_forget_weights_ledger_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
const float* recurrent_to_cell_diag,
const uint8_t* recurrent_to_cell_weights_ledger_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
const float* recurrent_to_output_diag,
const uint8_t* recurrent_to_output_weights_ledger_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const int8_t* projection_weights_ptr,
const uint8_t* projection_weights_ledger_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* scratch0, float* scratch1, float* scratch2, float* scratch3,
float* input_sf, float* aux_input_sf, float* output_state_sf,
float* scaling_factors_scratch, float* recovered_cell_weights,
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
bool* compute_row_sums, bool asymmetric_quantize_inputs,
bool recurrent_to_input_is_diag, bool recurrent_to_forget_is_diag,
bool recurrent_to_cell_is_diag, bool recurrent_to_output_is_diag,
CpuBackendContext* context) { … }
// File-local fully-integer LSTM step for the 8x8_16 variant (by name): int8
// activations/weights with an int16 cell state, per the parameter types. The
// body is not shown here; the notes below come from the signature only.
//
// Parameter groups:
//   input_ptr (const int8_t*).
//   input_to_* and recurrent_to_* weights - const int8_t* data with an int32
//     effective_*_scale_a / _b pair per gate.
//     NOTE(review): the a/b pairs are presumably a fixed-point multiplier
//     and shift - TODO confirm.
//   cell_to_{input,forget,output} weights - const int16_t* data with an
//     effective scale_a / scale_b pair each.
//   projection_weight_ptr (const int8_t*) with effective_proj_scale_a / _b.
//   hidden_zp and effective_hidden_scale_a / _b - int32 parameters.
//   layer_norm_{input,forget,cell,output}_weight_ptr (const int16_t*) with a
//     scale_a / scale_b pair each.
//   {input,forget,cell,output}_gate_bias_ptr - read-only int32 buffers.
//   quantized_cell_clip (int16), quantized_proj_clip (int8).
//   cell_state_scale and the four *_variance_guard values - int32.
//   *_effective_bias - read-only int32 buffers for each input_to_* /
//     recurrent_to_* matrix and for the projection. Note that they are
//     listed in forget, cell, output, input order, which differs from the
//     weight ordering above; keep call sites matched by name.
//   n_batch, n_cell, n_input, n_output - integer dimensions.
//   output_state_ptr (int8_t*), output_state_zp (int32),
//   cell_state_ptr (int16_t*), output_ptr (int8_t*) - state and output.
//   scratch0..scratch3 (int16_t*), scratch4 (int8_t*), scratch5 (int32_t*) -
//     writable work buffers of differing element types; required sizes are
//     not visible from the signature.
//   context - backend context pointer.
inline void LstmStepInteger8x8_16(
const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
int32_t effective_input_to_input_scale_a,
int32_t effective_input_to_input_scale_b,
const int8_t* input_to_forget_weight_ptr,
int32_t effective_input_to_forget_scale_a,
int32_t effective_input_to_forget_scale_b,
const int8_t* input_to_cell_weight_ptr,
int32_t effective_input_to_cell_scale_a,
int32_t effective_input_to_cell_scale_b,
const int8_t* input_to_output_weight_ptr,
int32_t effective_input_to_output_scale_a,
int32_t effective_input_to_output_scale_b,
const int8_t* recurrent_to_input_weight_ptr,
int32_t effective_recurrent_to_input_scale_a,
int32_t effective_recurrent_to_input_scale_b,
const int8_t* recurrent_to_forget_weight_ptr,
int32_t effective_recurrent_to_forget_scale_a,
int32_t effective_recurrent_to_forget_scale_b,
const int8_t* recurrent_to_cell_weight_ptr,
int32_t effective_recurrent_to_cell_scale_a,
int32_t effective_recurrent_to_cell_scale_b,
const int8_t* recurrent_to_output_weight_ptr,
int32_t effective_recurrent_to_output_scale_a,
int32_t effective_recurrent_to_output_scale_b,
const int16_t* cell_to_input_weight_ptr,
int32_t effective_cell_to_input_scale_a,
int32_t effective_cell_to_input_scale_b,
const int16_t* cell_to_forget_weight_ptr,
int32_t effective_cell_to_forget_scale_a,
int32_t effective_cell_to_forget_scale_b,
const int16_t* cell_to_output_weight_ptr,
int32_t effective_cell_to_output_scale_a,
int32_t effective_cell_to_output_scale_b,
const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
int32_t effective_proj_scale_b, int32_t hidden_zp,
int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
const int16_t* layer_norm_input_weight_ptr,
int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
const int16_t* layer_norm_forget_weight_ptr,
int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
int16_t quantized_cell_clip, int8_t quantized_proj_clip,
int32_t cell_state_scale, int32_t input_variance_guard,
int32_t forget_variance_guard, int32_t cell_variance_guard,
int32_t output_variance_guard,
const int32_t* input_to_forget_effective_bias,
const int32_t* recurrent_to_forget_effective_bias,
const int32_t* input_to_cell_effective_bias,
const int32_t* recurrent_to_cell_effective_bias,
const int32_t* input_to_output_effective_bias,
const int32_t* recurrent_to_output_effective_bias,
const int32_t* input_to_input_effective_bias,
const int32_t* recurrent_to_input_effective_bias,
const int32_t* projection_effective_bias, int n_batch, int n_cell,
int n_input, int n_output, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) { … }
// File-local fully-integer LSTM step for the 8x8_8 variant (by name). The
// body is not shown here; the notes below come from the signature only.
//
// Differences from LstmStepInteger8x8_16 that are visible in the signature:
// an explicit `input_zp`, int8 (not int16) cell_to_* weights, a
// `projection_bias_ptr` and `params` argument, `intermediate_*` arrays, an
// `output_batch_leading_dim`, eight scratch buffers, and no backend context
// or `*_effective_bias` / `*_variance_guard` arguments.
//
// Parameter groups:
//   input_ptr (const int8_t*) with input_zp (int32).
//   input_to_* and recurrent_to_* weights - const int8_t* data with an int32
//     effective_*_scale_a / _b pair per gate.
//   cell_to_{input,forget,output} weights - const int8_t* data with an
//     effective scale_a / scale_b pair each.
//   projection_weight_ptr (const int8_t*) with effective_proj_scale_a / _b.
//   layer_norm_{input,forget,cell,output}_weight_ptr (const int16_t*) with a
//     scale_a / scale_b pair each.
//   {input,forget,cell,output}_gate_bias_ptr, projection_bias_ptr -
//     read-only int32 buffers.
//   params - read-only TfLiteLSTMParams.
//   intermediate_scale_a, intermediate_scale_b, intermediate_zp - read-only
//     int32 arrays.
//     NOTE(review): array length and index-to-gate mapping are not visible
//     here - TODO confirm against the body.
//   quantized_cell_clip (int16), quantized_proj_clip (int8).
//   n_batch, n_cell, n_input, n_output, output_batch_leading_dim - integers.
//   output_state_ptr (int8_t*), output_state_zp (int32),
//   cell_state_ptr (int16_t*), output_ptr (int8_t*) - state and output.
//   scratch0, scratch1 (int8_t*), scratch2..scratch7 (int16_t*) - writable
//     work buffers; required sizes are not visible from the signature.
inline void LstmStepInteger8x8_8(
const int8_t* input_ptr, int32_t input_zp,
const int8_t* input_to_input_weight_ptr,
int32_t effective_input_to_input_scale_a,
int32_t effective_input_to_input_scale_b,
const int8_t* input_to_forget_weight_ptr,
int32_t effective_input_to_forget_scale_a,
int32_t effective_input_to_forget_scale_b,
const int8_t* input_to_cell_weight_ptr,
int32_t effective_input_to_cell_scale_a,
int32_t effective_input_to_cell_scale_b,
const int8_t* input_to_output_weight_ptr,
int32_t effective_input_to_output_scale_a,
int32_t effective_input_to_output_scale_b,
const int8_t* recurrent_to_input_weight_ptr,
int32_t effective_recurrent_to_input_scale_a,
int32_t effective_recurrent_to_input_scale_b,
const int8_t* recurrent_to_forget_weight_ptr,
int32_t effective_recurrent_to_forget_scale_a,
int32_t effective_recurrent_to_forget_scale_b,
const int8_t* recurrent_to_cell_weight_ptr,
int32_t effective_recurrent_to_cell_scale_a,
int32_t effective_recurrent_to_cell_scale_b,
const int8_t* recurrent_to_output_weight_ptr,
int32_t effective_recurrent_to_output_scale_a,
int32_t effective_recurrent_to_output_scale_b,
const int8_t* cell_to_input_weight_ptr,
int32_t effective_cell_to_input_scale_a,
int32_t effective_cell_to_input_scale_b,
const int8_t* cell_to_forget_weight_ptr,
int32_t effective_cell_to_forget_scale_a,
int32_t effective_cell_to_forget_scale_b,
const int8_t* cell_to_output_weight_ptr,
int32_t effective_cell_to_output_scale_a,
int32_t effective_cell_to_output_scale_b,
const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
const int16_t* layer_norm_forget_weight_ptr,
int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
const int32_t* intermediate_zp, int16_t quantized_cell_clip,
int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
int16_t* scratch7) { … }
}
// Float LSTM evaluation entry point. Defined at lstm_eval namespace scope
// (after the file-local anonymous namespace closes), so it has external
// linkage. The body is not shown here; the notes below come from the
// signature only.
//
// Parameters:
//   input and the input_to_* / recurrent_to_* / cell_to_* weight tensors,
//   the four layer-norm coefficient tensors, aux_input and its
//   aux_input_to_* weight tensors, the four gate-bias tensors and the
//   projection weights/bias - all read-only TfLiteTensor pointers.
//     NOTE(review): several are presumably optional and passed as null
//     (CIFG, peephole, layer norm, aux input, projection) - the exact rules
//     must be confirmed against the body.
//   params - read-only TfLiteLSTMParams.
//   forward_sequence, time_major - boolean flags.
//     NOTE(review): presumably select iteration direction and input layout -
//     TODO confirm.
//   output_offset - integer offset applied to the output (by name; exact
//     units not visible here).
//   scratch_buffer, output_state, cell_state, output - writable tensors.
//   recurrent_to_{input,forget,cell,output}_is_diag - boolean flags for the
//     recurrent weights.
//   context - backend context pointer.
//
// Returns a TfLiteStatus; the conditions under which a non-OK status is
// returned are not visible from the signature.
TfLiteStatus EvalFloat(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
bool recurrent_to_input_is_diag, bool recurrent_to_forget_is_diag,
bool recurrent_to_cell_is_diag, bool recurrent_to_output_is_diag,
CpuBackendContext* context) { … }
// Hybrid LSTM evaluation entry point. Defined at lstm_eval namespace scope
// (after the file-local anonymous namespace closes), so it has external
// linkage. The body is not shown here; the notes below come from the
// signature only.
//
// Parameters:
//   input - read-only tensor.
//   input_to_* and recurrent_to_* weights - each read-only weight tensor is
//     followed by a read-only `*_ledger` tensor.
//     NOTE(review): ledgers presumably describe sparse weights and may be
//     null for dense ones - TODO confirm.
//   cell_to_* weights, layer-norm coefficients, aux_input and its
//   aux_input_to_* weights, gate biases - read-only tensors.
//   projection_weights, projection_weights_ledger, projection_bias -
//     read-only tensors.
//   params - read-only TfLiteLSTMParams.
//   forward_sequence, time_major - boolean flags.
//   output_offset - integer offset applied to the output (by name).
//   scratch_buffer, input_sf, aux_input_sf, output_state_sf,
//   prod_scaling_factors, recovered_cell_weights, input_quantized,
//   aux_input_quantized, output_state_quantized, cell_state_quantized,
//   output_scratch_buffer, input_zp, aux_input_zp, output_state_zp,
//   row_sums - writable tensors. NOTE(review): presumably temporaries
//     allocated by the caller; required shapes are not visible here.
//   output_state, cell_state, output - writable state / result tensors.
//   row_sums_size - integer size associated with `row_sums`.
//   compute_row_sums - writable bool pointer.
//     NOTE(review): may be read and/or updated by the body - confirm.
//   recurrent_to_{input,forget,cell,output}_is_diag - boolean flags.
//   context - backend context pointer.
//
// Returns a TfLiteStatus; the conditions under which a non-OK status is
// returned are not visible from the signature.
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_input_weights_ledger,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_forget_weights_ledger,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_cell_weights_ledger,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* input_to_output_weights_ledger,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_input_weights_ledger,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_forget_weights_ledger,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_cell_weights_ledger,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* recurrent_to_output_weights_ledger,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights,
const TfLiteTensor* projection_weights_ledger,
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
bool* compute_row_sums, bool recurrent_to_input_is_diag,
bool recurrent_to_forget_is_diag, bool recurrent_to_cell_is_diag,
bool recurrent_to_output_is_diag, CpuBackendContext* context) { … }
// Fully-integer LSTM evaluation entry point for the 8x8_16 variant. Defined
// at lstm_eval namespace scope (after the file-local anonymous namespace
// closes), so it has external linkage. The body is not shown here; the notes
// below come from the signature only.
//
// Unlike EvalFloat / EvalHybrid, this signature has no aux-input tensors, no
// output_offset and no recurrent_to_*_is_diag flags.
//
// Parameters:
//   input, the input_to_* / recurrent_to_* / cell_to_* weight tensors, the
//   four layer-norm coefficient tensors, the four gate-bias tensors and the
//   projection weights/bias - read-only tensors.
//     NOTE(review): several are presumably optional and passed as null -
//     confirm against the body.
//   params - read-only TfLiteLSTMParams.
//   forward_sequence, time_major - boolean flags.
//   integer_lstm_param - read-only IntegerLstmParameter.
//     NOTE(review): presumably carries the precomputed quantization scales
//     and effective biases - TODO confirm.
//   output_state, cell_state, output - writable state / result tensors.
//   scratch0..scratch5 - writable tensors; required types and shapes are not
//     visible from the signature.
//   context - backend context pointer.
//
// Returns a TfLiteStatus; the conditions under which a non-OK status is
// returned are not visible from the signature.
TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
CpuBackendContext* context) { … }
// Fully-integer LSTM evaluation entry point for the 8x8_8 variant. Defined
// at lstm_eval namespace scope (after the file-local anonymous namespace
// closes), so it has external linkage. The body is not shown here; the notes
// below come from the signature only.
//
// Signature differences from EvalInteger8x8_16: there are no
// forward_sequence / time_major flags and no backend context;
// `integer_lstm_param` comes after the state/output tensors instead of
// before them; and there are eight scratch tensors instead of six.
// NOTE(review): the absence of the direction/layout flags suggests a fixed
// iteration order and layout - TODO confirm against the body.
//
// Parameters:
//   input, the input_to_* / recurrent_to_* / cell_to_* weight tensors, the
//   four layer-norm coefficient tensors, the four gate-bias tensors and the
//   projection weights/bias - read-only tensors.
//   params - read-only TfLiteLSTMParams.
//   output_state, cell_state, output - writable state / result tensors.
//   integer_lstm_param - read-only IntegerLstmParameter.
//   scratch0..scratch7 - writable tensors; required types and shapes are not
//     visible from the signature.
//
// Returns a TfLiteStatus; the conditions under which a non-OK status is
// returned are not visible from the signature.
TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
TfLiteTensor* scratch6, TfLiteTensor* scratch7) { … }
}
}
}
}