#include "embedding_network.h"
#include "base.h"
#include "embedding_network_params.h"
#include "float16.h"
#include "simple_adder.h"
namespace chrome_lang_id {
namespace {
VectorWrapper;
// Helper that inspects a parameter matrix; the name suggests it asserts the
// matrix is stored without quantization. Body not visible here — TODO confirm
// what "quantization" is checked against and whether it aborts or only logs.
// NOTE(review): `matrix` is taken by value; if EmbeddingNetworkParams::Matrix
// is not a small trivially-copyable handle, consider `const Matrix &`.
void CheckNoQuantization(const EmbeddingNetworkParams::Matrix matrix) { … }
// Populates the EmbeddingNetwork-side matrix `*mat` from a parameter-store
// matrix `source_matrix` (output via pointer, per the signature).
// Presumably converts between EmbeddingNetworkParams::Matrix and
// EmbeddingNetwork::Matrix representations — body not visible; verify.
// NOTE(review): `source_matrix` is taken by value; consider `const &` if the
// type is non-trivial to copy.
void FillMatrixParams(const EmbeddingNetworkParams::Matrix source_matrix,
EmbeddingNetwork::Matrix *mat) { … }
// Computes a layer output into `*y` from `weights`, bias `b`, and input `x`,
// optionally applying ReLU when `apply_relu` is true (inferred from the name —
// body not visible; confirm whether ReLU is applied to `x` or to the result).
// ScaleAdderClass is the accumulation strategy used for the product; presumably
// an adder type such as the one declared in simple_adder.h — TODO confirm.
// "Sparse" in the name suggests zero entries of the (ReLU'd) input are skipped;
// verify against the implementation before relying on it.
template <typename ScaleAdderClass>
void SparseReluProductPlusBias(bool apply_relu,
const EmbeddingNetwork::Matrix &weights,
const EmbeddingNetwork::VectorWrapper &b,
const EmbeddingNetwork::Vector &x,
EmbeddingNetwork::Vector *y) { … }
}
// Builds the concatenated embedding input vector `*concat` from the given
// per-feature-space `feature_vectors` (const member: does not modify the
// network). Exact layout/ordering of the concatenation is not visible here —
// TODO confirm against the implementation.
void EmbeddingNetwork::ConcatEmbeddings(
const std::vector<FeatureVector> &feature_vectors, Vector *concat) const { … }
// Given an already-concatenated input vector `concat`, computes the network's
// output scores into `*scores`. Templated on the accumulation strategy
// ScaleAdderClass; presumably forwards it to the per-layer product helper —
// body not visible, confirm.
template <typename ScaleAdderClass>
void EmbeddingNetwork::FinishComputeFinalScores(const Vector &concat,
Vector *scores) const { … }
// Public entry point: computes output scores into `*scores` from the extracted
// `features`. Presumably concatenates embeddings and then delegates to
// FinishComputeFinalScores with a chosen ScaleAdderClass — body not visible;
// verify which adder is selected.
void EmbeddingNetwork::ComputeFinalScores(
const std::vector<FeatureVector> &features, Vector *scores) const { … }
// Constructs the network from the parameter set pointed to by `model`.
// NOTE(review): `model` is a raw const pointer; ownership and required lifetime
// (whether the network keeps referencing it after construction) are not visible
// here — document/confirm at the declaration.
EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
: … { … }
}