chromium/components/language_detection/core/embedding_lookup.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "components/language_detection/core/embedding_lookup.h"

#include "base/check_op.h"
#include "components/language_detection/core/quantization_utils.h"
#include "third_party/flatbuffers/src/include/flatbuffers/flexbuffers.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/kernel_util.h"

namespace language_detection {

namespace {

// NOTE(review): the declarations below are incomplete in this copy of the
// file and will not compile. The first two lines have lost their
// `using <namespace>::` prefix (presumably `flexbuffers::`, given the
// flexbuffers.h include), and every constant is missing its initializer.
// Restore them from upstream rather than guessing the values.
GetRoot;
Map;
// Presumably tensor indices: inputs (message, embedding table, min, max) and
// the single output — TODO confirm against upstream.
constexpr int kInputMessage =;
constexpr int kEmbeddingTable =;
constexpr int kMinVal =;
constexpr int kMaxVal =;
constexpr int kOutputLabel =;
// Presumably the bit width of one packed table element. GetEmbedding's
// comment uses 32 for this ("m = 32 / n") — TODO confirm.
constexpr int kNumFloatBits =;

// Per-node configuration for the embedding-lookup op. GetEmbedding receives a
// pointer to it. Its members are absent in this copy. Judging by GetEmbedding's
// comment, it presumably carries at least `is_quantized` and
// `num_precision_bits` — TODO confirm against upstream.
class EmbeddingLookupOpParams {};

// Presumably computes the width of the aggregated output embedding for an
// embedding table whose rows have `input_embedding_size` elements.
// When `is_quantized` is true, each table element packs 32 / `num_precision_bits`
// dimensions (see the comment on GetEmbedding). The expected result is then
// input_embedding_size * (32 / num_precision_bits). When it is false, the
// expected result is input_embedding_size. TODO confirm against upstream.
//
// NOTE(review): the body is empty in this copy. Falling off the end of a
// non-void function is undefined behavior if this is ever called.
int GetOutputEmbeddingSize(const int input_embedding_size,
                           const bool is_quantized,
                           const int num_precision_bits) {}

// Has the signature of a TfLiteRegistration `init` callback. It receives the
// custom-op option buffer (`buffer`, `length` bytes) and returns the node's
// user data. Presumably it parses the buffer (via the flexbuffers GetRoot/Map
// helpers) into a heap-allocated EmbeddingLookupOpParams — TODO confirm.
//
// NOTE(review): the body is empty in this copy. A non-void function that falls
// off the end is undefined behavior if called.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {}

// Has the signature of a TfLiteRegistration `free` callback. It is meant to
// release the user data returned by Init (`buffer`), presumably an
// EmbeddingLookupOpParams — TODO confirm.
// NOTE(review): the body is empty in this copy, so as written nothing is
// released.
void Free(TfLiteContext* context, void* buffer) {}

// Has the signature of a TfLiteRegistration `prepare` callback. Presumably it
// validates the node's input tensors and resizes the output tensor, using
// GetOutputEmbeddingSize — TODO confirm.
//
// NOTE(review): the body is empty in this copy. A non-void function that falls
// off the end is undefined behavior if called.
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {}

// This is the core method that generates the aggregated embedding from the
// given `input` and `embedding_table` tensors.
//
// `is_quantized` and `num_precision_bits` below are not parameters of this
// function. They are presumably fields of `params` — TODO confirm.
//
// If `is_quantized` is false, `embedding_table` is treated as a regular
// floating-point tensor. Each row is an embedding vector, and each element in
// that vector is one embedding dimension.
//
// If `is_quantized` is true, `embedding_table` is treated as a packed
// quantized tensor. Each row is still an embedding vector, but each element
// packs the n-bit quantized representations of m embedding dimensions:
//
// n = `num_precision_bits`,
// m = 32 / n   (32 being the bit width of one packed table element).
//
// `min_val` / `max_val` are presumably the range used to dequantize packed
// values. `data` is presumably a caller-owned output buffer sized by
// GetOutputEmbeddingSize. TODO confirm both against upstream.
//
// NOTE(review): the body is empty in this copy, so as written it does nothing.
void GetEmbedding(const TfLiteTensor* input,
                  const TfLiteTensor* embedding_table,
                  const float min_val,
                  const float max_val,
                  float* data,
                  const EmbeddingLookupOpParams* params) {}

// Has the signature of a TfLiteRegistration `invoke` callback. Presumably it
// reads the input, embedding-table and min/max tensors and calls GetEmbedding
// to fill the output tensor — TODO confirm.
//
// NOTE(review): the body is empty in this copy. A non-void function that falls
// off the end is undefined behavior if called.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {}

}  // namespace

// Externally visible entry point (it sits outside the anonymous namespace)
// that returns the TFLite registration for this custom op. Presumably it
// returns a pointer to a static TfLiteRegistration that wires up
// Init / Free / Resize / Eval — TODO confirm.
//
// NOTE(review): the body is empty in this copy. A non-void function that falls
// off the end is undefined behavior if called.
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {}

}  // namespace language_detection