#include <algorithm>
#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {
namespace {
// Per-node state for this kernel. It lives in an anonymous namespace, so the
// type has internal linkage and is private to this translation unit.
// NOTE(review): presumably allocated in Init() and released in Free() —
// confirm against those bodies.
struct OpData { … };
}  // namespace
// Positions of this op's tensors in the node's tensor lists.
// NOTE(review): presumably the k*Tensor values up to kBwAuxWeightsTensor
// index node->inputs and kFwOutputTensor/kBwOutputTensor index node->outputs —
// confirm against Prepare()/Eval().
// Naming convention: kFw* = forward direction, kBw* = backward direction,
// kAux* / k*AuxWeights* = the auxiliary input and its weights (assumed
// optional — verify how Prepare() treats them when absent).
constexpr int kInputTensor = …;
constexpr int kFwWeightsTensor = …;
constexpr int kFwRecurrentWeightsTensor = …;
constexpr int kFwBiasTensor = …;
constexpr int kFwHiddenStateTensor = …;
constexpr int kBwWeightsTensor = …;
constexpr int kBwRecurrentWeightsTensor = …;
constexpr int kBwBiasTensor = …;
constexpr int kBwHiddenStateTensor = …;
constexpr int kAuxInputTensor = …;
constexpr int kFwAuxWeightsTensor = …;
constexpr int kBwAuxWeightsTensor = …;
constexpr int kFwOutputTensor = …;
constexpr int kBwOutputTensor = …;
// Identifiers for the kernel's scratch (temporary) tensors.
// NOTE(review): presumably these index temporaries allocated in Prepare() and
// consumed by the hybrid path — confirm.
enum TemporaryTensor { … };
// Kernel init hook. Returns an opaque per-node pointer (void*).
// `buffer`/`length` are the raw bytes handed in by the framework.
// NOTE(review): the returned pointer is presumably an OpData owned by the
// caller until Free() — confirm.
void* Init(TfLiteContext* context, const char* buffer, size_t length) { … }
// Kernel free hook: releases the per-node pointer `buffer`.
// NOTE(review): presumably `buffer` is the value returned by Init() — confirm.
void Free(TfLiteContext* context, void* buffer) { … }
// Kernel prepare hook for `node`; returns a TfLiteStatus.
// NOTE(review): presumably validates input shapes/types and sizes the output
// and temporary tensors before Eval() — confirm against the body.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … }
// Float evaluation path for the bidirectional RNN; returns a TfLiteStatus.
//
// Parameters (by naming convention):
//   input, bw_input            - inputs for the forward and backward passes.
//                                NOTE(review): bw_input is presumably a
//                                separate input for the backward cell in some
//                                configurations — confirm.
//   fw_* / bw_*                - input weights, recurrent weights and bias for
//                                the forward / backward direction.
//   aux_input, fw_aux_input_weights, bw_aux_input_weights
//                              - auxiliary input and its per-direction weights.
//                                Assumed optional (may be nullptr) — confirm.
//   params                     - builtin options for this op.
//   fw_hidden_state, bw_hidden_state, fw_output, bw_output
//                              - passed as non-const pointers, so this function
//                                may write to them.
TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
const TfLiteTensor* fw_input_weights,
const TfLiteTensor* fw_recurrent_weights,
const TfLiteTensor* fw_bias,
const TfLiteTensor* bw_input_weights,
const TfLiteTensor* bw_recurrent_weights,
const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input,
const TfLiteTensor* fw_aux_input_weights,
const TfLiteTensor* bw_aux_input_weights,
const TfLiteBidirectionalSequenceRNNParams* params,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { … }
// Hybrid evaluation path for the bidirectional RNN; returns a TfLiteStatus.
// NOTE(review): "hybrid" presumably means quantized weights with float
// activations — confirm against the body and Eval()'s dispatch.
//
// Parameters:
//   input ... params           - same roles as in EvalFloat (aux_fw_/aux_bw_
//                                prefixes here instead of fw_aux_/bw_aux_).
//   scaling_factors, input_quantized, aux_input_quantized,
//   fw_hidden_state_quantized, bw_hidden_state_quantized, zero_points,
//   accum_scratch, fw_row_sums, bw_row_sums
//                              - non-const scratch tensors the function may
//                                write; presumably the kernel's temporaries.
//   fw_hidden_state, bw_hidden_state, fw_output, bw_output
//                              - non-const; may be written.
//   fw_compute_row_sums, bw_compute_row_sums
//                              - in/out flags passed by pointer.
//                                NOTE(review): presumably cleared once row sums
//                                have been computed so later calls can reuse
//                                them — confirm.
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* bw_input,
const TfLiteTensor* fw_input_weights,
const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
const TfLiteTensor* bw_input_weights,
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
const TfLiteTensor* aux_bw_input_weights,
const TfLiteBidirectionalSequenceRNNParams* params,
TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
TfLiteTensor* bw_output, TfLiteTensor* zero_points,
TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
bool* bw_compute_row_sums) { … }
// Kernel invoke hook for `node`; returns a TfLiteStatus.
// NOTE(review): presumably gathers the node's tensors and dispatches to
// EvalFloat() or EvalHybrid() (likely based on the weight tensor type) —
// confirm against the body.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … }
}
// Returns a pointer to the TfLiteRegistration for BIDIRECTIONAL_SEQUENCE_RNN.
// NOTE(review): presumably wires the Init/Free/Prepare/Eval hooks from
// bidirectional_sequence_rnn and returns a pointer to a static object —
// confirm the pointer's lifetime against the body.
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() { … }
}
}
}