#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_
#include <cstdint>
#include <memory>
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {
// Parameter bundle passed by const pointer (`integer_lstm_param`) to the
// integer LSTM entry points EvalInteger8x8_16 and EvalInteger8x8_8 below.
// NOTE(review): the member list is elided in this view; presumably it holds
// precomputed quantization parameters for the integer kernels — confirm
// against the full definition before documenting individual fields.
struct IntegerLstmParameter { … };
// Float LSTM evaluation entry point.
//
// Parameter groups, in declaration order:
//   * `input`, then input-to-gate, recurrent-to-gate and cell-to-gate weights;
//   * four per-gate layer-norm coefficient tensors;
//   * `aux_input` and its four aux-input-to-gate weight tensors;
//   * four gate biases, then projection weights and bias;
//   * `params`, the `forward_sequence` / `time_major` flags, `output_offset`;
//   * non-const tensors the call may write: `scratch_buffer`,
//     `output_state`, `cell_state`, `output`;
//   * four `recurrent_to_*_is_diag` flags and the CPU backend `context`.
//
// Returns a TfLiteStatus. NOTE(review): which tensor pointers may be null
// (e.g. absent gates or no aux input) is decided by the implementation, not
// by this declaration — confirm before relying on it.
TfLiteStatus EvalFloat(
    // Primary input.
    const TfLiteTensor* input,
    // Input-to-gate weights.
    const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    // Recurrent-to-gate weights.
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    // Cell-to-gate weights.
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    // Per-gate layer-norm coefficients.
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    // Auxiliary input and its weights.
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    // Gate biases.
    const TfLiteTensor* input_gate_bias,
    const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias,
    const TfLiteTensor* output_gate_bias,
    // Projection.
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_bias,
    // Op options and layout.
    const TfLiteLSTMParams* params,
    bool forward_sequence,
    bool time_major,
    int output_offset,
    // Mutable state, workspace and result.
    TfLiteTensor* scratch_buffer,
    TfLiteTensor* output_state,
    TfLiteTensor* cell_state,
    TfLiteTensor* output,
    // Whether each recurrent weight tensor is diagonal.
    bool recurrent_to_input_is_diag,
    bool recurrent_to_forget_is_diag,
    bool recurrent_to_cell_is_diag,
    bool recurrent_to_output_is_diag,
    CpuBackendContext* context);
// Hybrid LSTM evaluation entry point.
//
// Same overall parameter layout as EvalFloat, with these differences:
//   * every input-to-gate and recurrent-to-gate weight tensor, and
//     `projection_weights`, is immediately followed by a `*_ledger`
//     companion tensor;
//   * many extra non-const tensors follow `scratch_buffer`. Their names
//     suggest scaling factors (`*_sf`, `prod_scaling_factors`), quantized
//     copies (`*_quantized`), zero points (`*_zp`) and cached row sums
//     (`row_sums`, with `row_sums_size` and the `compute_row_sums` flag
//     pointer).
//
// NOTE(review): the exact meaning of the ledger tensors and of the
// scaling-factor/zero-point buffers is defined by the implementation, which
// is not visible here — confirm before relying on it. Returns a TfLiteStatus.
TfLiteStatus EvalHybrid(
    // Primary input.
    const TfLiteTensor* input,
    // Input-to-gate weights, each followed by its ledger.
    const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_input_weights_ledger,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_forget_weights_ledger,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_cell_weights_ledger,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* input_to_output_weights_ledger,
    // Recurrent-to-gate weights, each followed by its ledger.
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_input_weights_ledger,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_forget_weights_ledger,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_cell_weights_ledger,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* recurrent_to_output_weights_ledger,
    // Cell-to-gate weights.
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    // Per-gate layer-norm coefficients.
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    // Auxiliary input and its weights.
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    // Gate biases.
    const TfLiteTensor* input_gate_bias,
    const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias,
    const TfLiteTensor* output_gate_bias,
    // Projection (weights + ledger, then bias).
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_weights_ledger,
    const TfLiteTensor* projection_bias,
    // Op options and layout.
    const TfLiteLSTMParams* params,
    bool forward_sequence,
    bool time_major,
    int output_offset,
    // Workspace and scaling-factor buffers.
    TfLiteTensor* scratch_buffer,
    TfLiteTensor* input_sf,
    TfLiteTensor* aux_input_sf,
    TfLiteTensor* output_state_sf,
    TfLiteTensor* prod_scaling_factors,
    TfLiteTensor* recovered_cell_weights,
    // Quantized copies.
    TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized,
    TfLiteTensor* output_state_quantized,
    TfLiteTensor* cell_state_quantized,
    // Mutable state and result.
    TfLiteTensor* output_state,
    TfLiteTensor* cell_state,
    TfLiteTensor* output_scratch_buffer,
    TfLiteTensor* output,
    // Zero-point buffers.
    TfLiteTensor* input_zp,
    TfLiteTensor* aux_input_zp,
    TfLiteTensor* output_state_zp,
    // Row-sum cache.
    TfLiteTensor* row_sums,
    int row_sums_size,
    bool* compute_row_sums,
    // Whether each recurrent weight tensor is diagonal.
    bool recurrent_to_input_is_diag,
    bool recurrent_to_forget_is_diag,
    bool recurrent_to_cell_is_diag,
    bool recurrent_to_output_is_diag,
    CpuBackendContext* context);
// Integer LSTM evaluation entry point (the "8x8_16" variant).
//
// This variant takes no auxiliary input. It receives the op options plus a
// precomputed `integer_lstm_param` bundle, the mutable `output_state`,
// `cell_state` and `output` tensors, six scratch tensors (scratch0..5) and
// the CPU backend `context`.
//
// NOTE(review): the suffix presumably describes the integer widths used
// internally — confirm against the implementation. Returns a TfLiteStatus.
TfLiteStatus EvalInteger8x8_16(
    // Primary input.
    const TfLiteTensor* input,
    // Input-to-gate weights.
    const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    // Recurrent-to-gate weights.
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    // Cell-to-gate weights.
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    // Per-gate layer-norm coefficients.
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    // Gate biases.
    const TfLiteTensor* input_gate_bias,
    const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias,
    const TfLiteTensor* output_gate_bias,
    // Projection.
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_bias,
    // Op options, layout and integer parameters.
    const TfLiteLSTMParams* params,
    bool forward_sequence,
    bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    // Mutable state and result.
    TfLiteTensor* output_state,
    TfLiteTensor* cell_state,
    TfLiteTensor* output,
    // Scratch workspace.
    TfLiteTensor* scratch0,
    TfLiteTensor* scratch1,
    TfLiteTensor* scratch2,
    TfLiteTensor* scratch3,
    TfLiteTensor* scratch4,
    TfLiteTensor* scratch5,
    CpuBackendContext* context);
// Integer LSTM evaluation entry point (the "8x8_8" variant).
//
// It differs from EvalInteger8x8_16 in several ways:
//   * it has no `forward_sequence` / `time_major` flags;
//   * it has no CpuBackendContext;
//   * `output_state`, `cell_state` and `output` come BEFORE
//     `integer_lstm_param`;
//   * it takes eight scratch tensors (scratch0..7) instead of six.
//
// NOTE(review): the suffix presumably describes the integer widths used
// internally — confirm against the implementation. Returns a TfLiteStatus.
TfLiteStatus EvalInteger8x8_8(
    // Primary input.
    const TfLiteTensor* input,
    // Input-to-gate weights.
    const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    // Recurrent-to-gate weights.
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    // Cell-to-gate weights.
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    // Per-gate layer-norm coefficients.
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    // Gate biases.
    const TfLiteTensor* input_gate_bias,
    const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias,
    const TfLiteTensor* output_gate_bias,
    // Projection.
    const TfLiteTensor* projection_weights,
    const TfLiteTensor* projection_bias,
    // Op options.
    const TfLiteLSTMParams* params,
    // Mutable state and result.
    TfLiteTensor* output_state,
    TfLiteTensor* cell_state,
    TfLiteTensor* output,
    // Integer parameters.
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    // Scratch workspace.
    TfLiteTensor* scratch0,
    TfLiteTensor* scratch1,
    TfLiteTensor* scratch2,
    TfLiteTensor* scratch3,
    TfLiteTensor* scratch4,
    TfLiteTensor* scratch5,
    TfLiteTensor* scratch6,
    TfLiteTensor* scratch7);
}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_