#include <stdint.h>
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_set_diag {
// Positions of this op's tensors in the node's input/output lists.
// NOTE(review): the concrete index values are not visible in this view; by
// name, kInputTensor and kDiagonalTensor are node inputs and kOutputTensor is
// a node output — confirm against Prepare/Eval.
constexpr int kInputTensor = …;
constexpr int kDiagonalTensor = …;
constexpr int kOutputTensor = …;
// Prepare step for MATRIX_SET_DIAG: runs with the interpreter context and the
// node, and returns a TfLiteStatus.
// NOTE(review): body not visible here. Presumably it validates the node's
// input/output tensors and sizes the output (likely matching the input
// shape). TODO confirm against the implementation.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … }
// Typed kernel over raw buffers.
//   in         - source element buffer (read-only).
//   diag       - diagonal values buffer (read-only).
//   out        - destination element buffer (written).
//   batch_size - number of matrices processed.
//   row_size   - rows per matrix.
//   col_size   - columns per matrix.
// NOTE(review): body not visible. Presumably, for each of the batch_size
// row_size x col_size matrices, it copies `in` to `out` but replaces the main
// diagonal with the corresponding values from `diag`. The layout of `diag`
// (e.g. batch_size * min(row_size, col_size) elements) is an assumption —
// TODO confirm.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
const int row_size, const int col_size) { … }
// Tensor-level wrapper templated on element type T. It takes the input,
// diagonal and output tensors plus the batch/row/column sizes.
// NOTE(review): body not visible. Presumably it extracts typed T* data
// pointers from the tensors and forwards them to FillDiagImpl<T>. TODO
// confirm.
template <typename T>
void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
TfLiteTensor* output, const int batch_size, const int row_size,
const int col_size) { … }
// Untyped entry point used on the tensors directly. Unlike FillDiag, it takes
// no explicit batch/row/column sizes.
// NOTE(review): body not visible. Presumably it derives those sizes from the
// input tensor's shape (innermost two dims as rows/cols, the rest as batch)
// and dispatches to FillDiag<T> on the tensor's element type. TODO confirm,
// including how unsupported types are handled.
void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
TfLiteTensor* output) { … }
// Eval step for MATRIX_SET_DIAG: runs with the interpreter context and the
// node, and returns a TfLiteStatus.
// NOTE(review): body not visible. Presumably it fetches the tensors at
// kInputTensor / kDiagonalTensor / kOutputTensor and calls FillDiagHelper.
// TODO confirm.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … }
}
// Returns the TfLiteRegistration for the MATRIX_SET_DIAG builtin op.
// NOTE(review): body not visible. Presumably the registration wires
// matrix_set_diag::Prepare and matrix_set_diag::Eval, and it is likely a
// function-local static whose pointer stays valid after return. TODO
// confirm.
TfLiteRegistration* Register_MATRIX_SET_DIAG() { … }
}
}
}