#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {
namespace {
// Per-node state for this kernel. Its fields are elided in this view.
// NOTE(review): presumably allocated by Init() and released by Free() in
// this file — confirm against those (elided) bodies.
struct OpData { … };
}  // namespace
// Named indices for the tensors this kernel uses. The values are elided in
// this view.
// NOTE(review): presumably positions in the node's input list (input, weights,
// recurrent weights, bias, hidden state) and output list (output) — confirm
// against the op schema.
constexpr int kInputTensor = …;
constexpr int kWeightsTensor = …;
constexpr int kRecurrentWeightsTensor = …;
constexpr int kBiasTensor = …;
// NOTE(review): the hidden state is listed alongside the inputs by name;
// presumably it is a variable tensor updated in place by Eval — confirm.
constexpr int kHiddenStateTensor = …;
constexpr int kOutputTensor = …;
// Init hook: receives the context plus a raw (buffer, length) byte range and
// returns an opaque per-node pointer (void*).
// NOTE(review): body elided here; presumably allocates the OpData declared
// above, which Free() later releases — confirm.
void* Init(TfLiteContext* context, const char* buffer, size_t length) { … }
// Free hook: receives the opaque per-node pointer as `buffer` and returns
// nothing.
// NOTE(review): body elided; presumably releases the pointer returned by
// Init() — confirm.
void Free(TfLiteContext* context, void* buffer) { … }
// Prepare hook for `node`; reports success or failure through TfLiteStatus.
// NOTE(review): body elided; presumably validates input tensor types/shapes
// and sizes the output and any scratch tensors used by EvalHybrid — confirm.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … }
// Float evaluation path.
// Read-only inputs (const in the signature): input, input_weights,
// recurrent_weights, bias, and the op's TfLiteSequenceRNNParams.
// Mutable tensors (non-const): hidden_state and output.
// NOTE(review): presumably hidden_state is both read as the previous step's
// state and overwritten with the new one for each timestep — confirm against
// the (elided) body.
// Returns a TfLiteStatus.
TfLiteStatus EvalFloat(const TfLiteTensor* input,
const TfLiteTensor* input_weights,
const TfLiteTensor* recurrent_weights,
const TfLiteTensor* bias,
const TfLiteSequenceRNNParams* params,
TfLiteTensor* hidden_state, TfLiteTensor* output) { … }
// Hybrid evaluation path. It takes the same read-only inputs and
// hidden_state/output tensors as EvalFloat, plus several caller-provided
// working tensors: input_scratch, hidden_state_scratch, scaling_factors,
// zero_points, accum_scratch and row_sums.
// NOTE(review): the names suggest quantized weights with float activations
// quantized on the fly into the *_scratch buffers, with per-batch
// scaling_factors/zero_points — confirm against the (elided) body.
// `compute_row_sums` is passed by pointer, so the callee can change the
// flag. Presumably it recomputes row_sums once and then clears the flag —
// TODO confirm.
// Returns a TfLiteStatus.
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_weights,
const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
bool* compute_row_sums) { … }
// Eval (invoke) hook for `node`; reports the outcome as a TfLiteStatus.
// NOTE(review): body elided; presumably looks up the tensors by the k*Tensor
// indices above and dispatches to EvalFloat or EvalHybrid depending on the
// weights' type — confirm.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … }
}
// Returns the TfLiteRegistration for the UNIDIRECTIONAL_SEQUENCE_RNN builtin.
// NOTE(review): body elided; presumably returns a pointer to a static
// registration wiring unidirectional_sequence_rnn::Init/Free/Prepare/Eval —
// confirm.
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() { … }
}
}
}