#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#include <cmath>
#include <cstdint>
#include <sstream>
#include <string_view>
#include "base/base64.h"
#include "base/containers/contains.h"
#include "base/files/file_util.h"
#include "base/hash/hash.h"
#include "base/logging.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "components/omnibox/browser/omnibox_field_trial.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/tflite_op_resolver.h"
#include "third_party/tflite/src/tensorflow/lite/c/c_api_types.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/register.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"
// File-local constants and helpers.
namespace {
// NOTE(review): the meanings in the comments below are inferred from the
// identifier names; confirm them against the model definition and Init().
// Presumably the names of the two sub-models: the previous-query encoder and
// a single RNN decoding step.
static constexpr char kPreviousQueryEncoder[] = …;
static constexpr char kRnnStep[] = …;
// Presumably the input (token ids) and output (encoding) tensor names of the
// previous-query encoder.
static constexpr char kPrevQueryTokenIdsNodeName[] = …;
static constexpr char kPrevQueryEncodingOutputNodeName[] = …;
// Presumably the RNN step's input tensor names: the current token id and the
// previous-query encoding.
static constexpr char kRnnStepInputIdsNodeName[] = …;
static constexpr char kRnnStepPrevQueryEncodingInputNodeName[] = …;
// Name prefixes for the RNN's C and M cell-state tensors, on the input and
// output side of a step. Presumably completed with a per-layer suffix at the
// use site — confirm.
static constexpr std::string_view kRnnStepCStateInputNamePrefix = …;
static constexpr std::string_view kRnnStepMStateInputNamePrefix = …;
static constexpr std::string_view kRnnStepCStateOutputNamePrefix = …;
static constexpr std::string_view kRnnStepMStateOutputNamePrefix = …;
// Presumably the RNN step output tensor holding next-token probabilities.
static constexpr char kRnnStepOutputProbsNodeName[] = …;
// Decoding defaults; presumably used when the model metadata does not supply
// its own values — confirm in Init().
static constexpr size_t kDefaultMaxNumSteps = …;
static constexpr float kDefaultProbabilityThreshold = …;
// Capacities of the previous-query-encoding cache and the RNN-step-output
// cache (presumably entry counts — confirm).
static constexpr size_t kPreQueryEncodingCacheSize = …;
static constexpr size_t kRnnStepOutputCacheSize = …;
// Presumably the maximum size, in bytes, of an auxiliary file read via
// LoadFileContent() — confirm.
static constexpr size_t kFileSizeLimit = …;
// Keywords presumably matched against the names of the files in
// `additional_files` (see Init()) to identify the vocabulary, the bad-word
// hash list, and the bad-substring deny list — confirm.
static constexpr char kVocabFileNameKeyword[] = …;
static constexpr char kBadwordHashesFileNameKeyword[] = …;
static constexpr char kBadSubstringDenyListFileNameKeyword[] = …;
// Output-stream operator for a TokenIds sequence; presumably used for
// logging. Returns a stream reference (presumably `os`, for chaining).
std::ostream& operator<<(std::ostream& os,
const OnDeviceTailTokenizer::TokenIds& ids) { … }
// Presumably reads the contents of `file_path` into a string, bounded by
// kFileSizeLimit, returning an empty string on failure — confirm.
// NOTE(review): `file_path` is taken by const value, which copies the path;
// `const base::FilePath&` would avoid the copy.
std::string LoadFileContent(const base::FilePath file_path) { … }
}  // namespace
// ModelInput: the request type consumed by GenerateSuggestionsForPrefix().
OnDeviceTailModelExecutor::ModelInput::ModelInput() = default;
// Builds an input from `prefix`, `previous_query`, and `max_num_suggestions`.
// Presumably the members are initialized directly from these arguments —
// confirm.
OnDeviceTailModelExecutor::ModelInput::ModelInput(std::string prefix,
std::string previous_query,
size_t max_num_suggestions)
: … { … }
OnDeviceTailModelExecutor::ModelInput::~ModelInput() = default;
// RnnCellStates: presumably the per-layer C and M cell states of the RNN
// (cf. the kRnnStep{C,M}State{Input,Output}NamePrefix constants) — confirm.
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates() = default;
// Presumably sizes the states for `num_layer` layers of `state_size`
// elements each — confirm.
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates(size_t num_layer,
size_t state_size) { … }
// User-provided copy constructor.
// NOTE(review): a user-declared copy constructor suppresses the implicitly
// declared move constructor, so rvalues of this type are copied. If this body
// is a plain member-wise copy, consider `= default` here and defaulted moves
// in the header.
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates(
const RnnCellStates& other) { … }
OnDeviceTailModelExecutor::RnnCellStates::~RnnCellStates() = default;
// RnnStepOutput: the result of one RunRnnStep() call. Presumably holds the
// updated cell states and next-token probabilities (cf.
// kRnnStepOutputProbsNodeName) — confirm.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput() = default;
// Presumably sizes the states as `num_layer` x `state_size` and the
// probability vector as `vocab_size` — confirm.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput(size_t num_layer,
size_t state_size,
size_t vocab_size)
: … { … }
// User-provided copy constructor.
// NOTE(review): as with RnnCellStates, declaring this suppresses the implicit
// move constructor; prefer `= default` if the body is a member-wise copy.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput(
const RnnStepOutput& other) { … }
OnDeviceTailModelExecutor::RnnStepOutput::~RnnStepOutput() = default;
// BeamNode: one hypothesis in the beam search. Presumably carries the decoded
// token ids so far, the RNN cell states, and an accumulated log probability —
// confirm.
OnDeviceTailModelExecutor::BeamNode::BeamNode() = default;
// NOTE(review): takes `int` sizes whereas RnnCellStates and RnnStepOutput
// take `size_t`; callers mixing the two risk sign-conversion warnings.
OnDeviceTailModelExecutor::BeamNode::BeamNode(int num_layer, int state_size)
: … { … }
// User-provided copy constructor (suppresses the implicit move constructor).
OnDeviceTailModelExecutor::BeamNode::BeamNode(const BeamNode& other) { … }
OnDeviceTailModelExecutor::BeamNode::~BeamNode() = default;
// Constructs the executor. Presumably it is unusable until Init() succeeds —
// confirm.
OnDeviceTailModelExecutor::OnDeviceTailModelExecutor()
: … { … }
OnDeviceTailModelExecutor::~OnDeviceTailModelExecutor() = default;
// Argument-less (re)initialization. Presumably reloads using the model path,
// auxiliary files, and metadata stored by a previous Init(...) call — confirm.
// Returns a success flag (presumably true on success).
bool OnDeviceTailModelExecutor::Init() { … }
// Initializes the executor from the TFLite model at `model_filepath`, the
// auxiliary files in `additional_files`, and `metadata`. The auxiliary files
// are presumably identified by the k*FileNameKeyword constants (vocabulary,
// bad-word hashes, bad-substring deny list) — confirm.
// Returns a success flag (presumably true on success).
bool OnDeviceTailModelExecutor::Init(
const base::FilePath& model_filepath,
const base::flat_set<base::FilePath>& additional_files,
const ModelMetadata& metadata) { … }
// Presumably loads the TFLite model at `model_filepath` and builds its
// interpreter (cf. the tflite model_builder / kernels register includes).
// Returns a success flag (presumably false if loading fails).
bool OnDeviceTailModelExecutor::InitModelInterpreter(
const base::FilePath& model_filepath) { … }
// Encodes `prev_query_token_ids` and writes the encoding to
// `*prev_query_encoding` (out-parameter). Presumably runs the previous-query
// encoder sub-model and consults a cache bounded by
// kPreQueryEncodingCacheSize — confirm.
// Returns a success flag (presumably true on success).
bool OnDeviceTailModelExecutor::EncodePreviousQuery(
const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
std::vector<float>* prev_query_encoding) { … }
// Presumably clears the previous-query-encoding and RNN-step-output caches —
// confirm.
void OnDeviceTailModelExecutor::ResetCaches() { … }
// Presumably loads the bad-substring deny list (cf.
// kBadSubstringDenyListFileNameKeyword) into a set used by IsSuggestionBad()
// — confirm.
void OnDeviceTailModelExecutor::LoadBadSubstringSet() { … }
// Presumably loads the bad-word hash list (cf. kBadwordHashesFileNameKeyword;
// hashes likely computed with base/hash/hash.h) into a set used by
// IsSuggestionBad() — confirm.
void OnDeviceTailModelExecutor::LoadBadwordHashSet() { … }
// Returns whether `suggestion` should be filtered out; presumably checked
// against the bad-substring set and the bad-word hash set — confirm.
// NOTE(review): `suggestion` is taken by const value, which copies the string
// on every call; `const std::string&` or `std::string_view` would avoid it
// (the header declaration must change together with this definition).
bool OnDeviceTailModelExecutor::IsSuggestionBad(const std::string suggestion) { … }
// Presumably returns the executor to an uninitialized state (interpreter,
// loaded files, caches). ResetCaches() is the narrower, cache-only operation —
// confirm.
void OnDeviceTailModelExecutor::Reset() { … }
// Runs one RNN step on `input_id`, `prev_query_encoding`, and
// `previous_states`, writing the result to `*rnn_step_output` (out-parameter).
// `rnn_step_cache_key` is presumably the token prefix used as the key for the
// RNN-step-output cache (kRnnStepOutputCacheSize) — confirm.
// Returns a success flag (presumably true on success).
bool OnDeviceTailModelExecutor::RunRnnStep(
const OnDeviceTailTokenizer::TokenIds& rnn_step_cache_key,
const OnDeviceTailTokenizer::TokenId& input_id,
const std::vector<float>& prev_query_encoding,
const RnnCellStates& previous_states,
RnnStepOutput* rnn_step_output) { … }
// Expands `current_beam` with the next-token results in `rnn_step_output`.
// New nodes go to `*partial_candidates` or `*completed_candidates`.
// Presumably they are pruned by `log_prob_threshold` and each queue is capped
// at `max_num_suggestions` — confirm.
void OnDeviceTailModelExecutor::CreateNewBeams(
const RnnStepOutput& rnn_step_output,
const BeamNode& current_beam,
size_t max_num_suggestions,
float log_prob_threshold,
CandidateQueue* partial_candidates,
CandidateQueue* completed_candidates) { … }
// Builds a beam node by extending `current_beam` with `token_id_and_prob` and
// `states`, then inserts it into `*queue`. Presumably the node is skipped when
// below `log_prob_threshold`, and the queue is kept to at most
// `max_num_suggestions` entries — confirm.
// NOTE(review): the parameter order here is (log_prob_threshold,
// max_num_suggestions), the reverse of CreateNewBeams(). float and size_t
// convert implicitly, so swapped arguments at a call site would compile.
void OnDeviceTailModelExecutor::InsertBeamNodeToCandidateQueue(
const TokenIdAndProb& token_id_and_prob,
const RnnCellStates& states,
const BeamNode& current_beam,
float log_prob_threshold,
size_t max_num_suggestions,
CandidateQueue* queue) { … }
// Produces the starting beam for decoding from `input_tokenization` and
// `prev_query_token_ids`. Writes the previous-query encoding to
// `*prev_query_encoding` and the root node to `*root_beam` (out-parameters).
// Presumably it feeds the already-typed tokens through RunRnnStep() — confirm.
// Returns a success flag (presumably true on success).
bool OnDeviceTailModelExecutor::GetRootBeamNode(
const OnDeviceTailTokenizer::Tokenization& input_tokenization,
const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
std::vector<float>* prev_query_encoding,
BeamNode* root_beam) { … }
// Maps `probability` to log space (presumably via <cmath>).
// NOTE(review): confirm the behavior for probability <= 0, where a plain
// logarithm yields -inf or NaN.
float OnDeviceTailModelExecutor::GetLogProbability(float probability) { … }
// Main entry point: returns predictions for `input`. Presumably it runs a beam
// search (BeamNode / CandidateQueue) over the tail model and returns at most
// the input's max_num_suggestions results — confirm.
std::vector<OnDeviceTailModelExecutor::Prediction>
OnDeviceTailModelExecutor::GenerateSuggestionsForPrefix(
const ModelInput& input) { … }