// chromium/components/omnibox/browser/on_device_tail_model_executor.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "components/omnibox/browser/on_device_tail_model_executor.h"

#include <cmath>
#include <cstdint>
#include <sstream>
#include <string_view>

#include "base/base64.h"
#include "base/containers/contains.h"
#include "base/files/file_util.h"
#include "base/hash/hash.h"
#include "base/logging.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "components/omnibox/browser/omnibox_field_trial.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/tflite_op_resolver.h"
#include "third_party/tflite/src/tensorflow/lite/c/c_api_types.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/register.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"

namespace {
// NOTE(review): every constant initializer in this snapshot has been stripped
// to "=;" and both helper bodies are empty — this translation unit is not
// valid C++ as-is. Restore the literal values/bodies from the upstream file
// before compiling; the comments below only describe intent inferred from the
// identifiers (unverified).

// The names of the subgraphs.
// Presumably the TFLite signature/subgraph keys used to select the previous-
// query encoder and the single-RNN-step subgraph — TODO confirm values.
static constexpr char kPreviousQueryEncoder[] =;
static constexpr char kRnnStep[] =;

// The names of input & output node.
// Tensor names used to bind inputs/outputs on the interpreter — values
// stripped; verify against the deployed model's signature.
static constexpr char kPrevQueryTokenIdsNodeName[] =;
static constexpr char kPrevQueryEncodingOutputNodeName[] =;

static constexpr char kRnnStepInputIdsNodeName[] =;
static constexpr char kRnnStepPrevQueryEncodingInputNodeName[] =;

// Prefixes (presumably suffixed with a layer index) for the per-layer LSTM
// c/m state tensors on the RNN-step subgraph — TODO confirm.
static constexpr std::string_view kRnnStepCStateInputNamePrefix =;
static constexpr std::string_view kRnnStepMStateInputNamePrefix =;

static constexpr std::string_view kRnnStepCStateOutputNamePrefix =;
static constexpr std::string_view kRnnStepMStateOutputNamePrefix =;

static constexpr char kRnnStepOutputProbsNodeName[] =;

// Some default values of params needed to run the model.
// Used when the model metadata does not override them — values stripped.
static constexpr size_t kDefaultMaxNumSteps =;
static constexpr float kDefaultProbabilityThreshold =;

// The sizes of the caches.
static constexpr size_t kPreQueryEncodingCacheSize =;
static constexpr size_t kRnnStepOutputCacheSize =;

// Maximum file size that will be loaded in bytes.
// Guard against loading oversized vocab/deny-list files — value stripped.
static constexpr size_t kFileSizeLimit =;

// Keywords to identify additional files needed by the executor.
// Matched against the additional-file paths passed to Init() — TODO confirm.
static constexpr char kVocabFileNameKeyword[] =;
static constexpr char kBadwordHashesFileNameKeyword[] =;
static constexpr char kBadSubstringDenyListFileNameKeyword[] =;

// Streams a token-id sequence, presumably for logging/cache-key debugging.
// Body stripped in this snapshot.
std::ostream& operator<<(std::ostream& os,
                         const OnDeviceTailTokenizer::TokenIds& ids) {}

// Reads the file at `file_path` into a string (likely bounded by
// kFileSizeLimit — unverified; body stripped). Note the parameter is taken
// by const value; upstream may intend `const base::FilePath&`.
std::string LoadFileContent(const base::FilePath file_path) {}

}  // namespace

// ModelInput bundles what one inference run needs: the typed prefix, the
// user's previous query, and a cap on suggestions to return.
OnDeviceTailModelExecutor::ModelInput::ModelInput() = default;

// Member-initializer list stripped in this snapshot (":{}"); presumably it
// moves/copies the three arguments into the corresponding members — verify
// against upstream.
OnDeviceTailModelExecutor::ModelInput::ModelInput(std::string prefix,
                                                  std::string previous_query,
                                                  size_t max_num_suggestions)
    :{}

OnDeviceTailModelExecutor::ModelInput::~ModelInput() = default;

// RnnCellStates presumably holds the per-layer LSTM c/m state vectors carried
// between RNN steps — bodies stripped in this snapshot.
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates() = default;

// Likely sizes the state containers to num_layer x state_size (body stripped
// — TODO confirm).
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates(size_t num_layer,
                                                        size_t state_size) {}

// NOTE(review): a user-provided copy ctor with an empty body would leave
// members default-constructed; upstream presumably copies `other`'s state.
// Consider `= default` (Rule of Zero) once the real body is restored.
OnDeviceTailModelExecutor::RnnCellStates::RnnCellStates(
    const RnnCellStates& other) {}

OnDeviceTailModelExecutor::RnnCellStates::~RnnCellStates() = default;

// RnnStepOutput presumably pairs the updated cell states with the
// vocab-sized probability distribution produced by one RNN step.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput() = default;

// Member-initializer list stripped (":{}") — likely sizes states by
// num_layer/state_size and probs by vocab_size; verify against upstream.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput(size_t num_layer,
                                                        size_t state_size,
                                                        size_t vocab_size)
    :{}

// NOTE(review): empty copy-ctor body as captured; upstream presumably copies
// `other` — restore before use.
OnDeviceTailModelExecutor::RnnStepOutput::RnnStepOutput(
    const RnnStepOutput& other) {}

OnDeviceTailModelExecutor::RnnStepOutput::~RnnStepOutput() = default;

// BeamNode presumably represents one partial suggestion in the beam search:
// token ids so far, accumulated log probability, and the RNN states needed to
// extend it — exact members not visible in this snapshot.
OnDeviceTailModelExecutor::BeamNode::BeamNode() = default;

// Member-initializer list stripped (":{}"); likely initializes the embedded
// RnnCellStates with num_layer/state_size — TODO confirm.
OnDeviceTailModelExecutor::BeamNode::BeamNode(int num_layer, int state_size)
    :{}

// NOTE(review): empty copy-ctor body as captured; upstream presumably copies
// `other`.
OnDeviceTailModelExecutor::BeamNode::BeamNode(const BeamNode& other) {}

OnDeviceTailModelExecutor::BeamNode::~BeamNode() = default;

// Default constructor; member-initializer list stripped (":{}") — upstream
// presumably seeds the LRU caches with the kPre*/kRnn* cache sizes above.
OnDeviceTailModelExecutor::OnDeviceTailModelExecutor()
    :{}

OnDeviceTailModelExecutor::~OnDeviceTailModelExecutor() = default;

// Re-initializes the executor, presumably from previously-remembered
// model/file paths (body stripped). Returns whether init succeeded.
bool OnDeviceTailModelExecutor::Init() {}

// Full initialization: loads the TFLite model at `model_filepath`, picks the
// vocab / badword-hash / bad-substring files out of `additional_files` via
// the kw constants above, and applies `metadata` overrides — inferred from
// identifiers only; body stripped. Returns false on any load failure
// (presumably).
bool OnDeviceTailModelExecutor::Init(
    const base::FilePath& model_filepath,
    const base::flat_set<base::FilePath>& additional_files,
    const ModelMetadata& metadata) {}

// Builds the TFLite interpreter for the model file — presumably via
// FlatBufferModel + TFLiteOpResolver given the includes above; body stripped.
// Returns whether the interpreter was created successfully.
bool OnDeviceTailModelExecutor::InitModelInterpreter(
    const base::FilePath& model_filepath) {}

// Runs the previous-query encoder subgraph on `prev_query_token_ids`, writing
// the resulting embedding into `prev_query_encoding` (out-param) — likely
// consulting/filling the encoding LRU cache first; body stripped, TODO
// confirm. Returns false on inference failure (presumably).
bool OnDeviceTailModelExecutor::EncodePreviousQuery(
    const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
    std::vector<float>* prev_query_encoding) {}

// Clears the prev-query-encoding and RNN-step LRU caches (inferred from the
// cache-size constants above; body stripped).
void OnDeviceTailModelExecutor::ResetCaches() {}

// Loads the bad-substring deny list from the file matched by
// kBadSubstringDenyListFileNameKeyword — body stripped.
void OnDeviceTailModelExecutor::LoadBadSubstringSet() {}

// Loads the badword hash set from the file matched by
// kBadwordHashesFileNameKeyword — body stripped.
void OnDeviceTailModelExecutor::LoadBadwordHashSet() {}

// True if `suggestion` hits the badword-hash or bad-substring deny lists
// (presumably — body stripped). NOTE(review): takes its argument by const
// value; upstream may intend `const std::string&` or std::string_view.
bool OnDeviceTailModelExecutor::IsSuggestionBad(const std::string suggestion) {}

// Tears the executor back down to an uninitialized state (body stripped —
// presumably drops interpreter, tokenizer, and caches).
void OnDeviceTailModelExecutor::Reset() {}

// Executes one RNN decoding step: feeds `input_id`, the previous-query
// encoding, and `previous_states` into the kRnnStep subgraph and fills
// `rnn_step_output` with new states + next-token probabilities.
// `rnn_step_cache_key` presumably keys the RNN-step LRU cache so repeated
// prefixes skip inference — inferred from names; body stripped. Returns
// false on inference failure (presumably).
bool OnDeviceTailModelExecutor::RunRnnStep(
    const OnDeviceTailTokenizer::TokenIds& rnn_step_cache_key,
    const OnDeviceTailTokenizer::TokenId& input_id,
    const std::vector<float>& prev_query_encoding,
    const RnnCellStates& previous_states,
    RnnStepOutput* rnn_step_output) {}

// Beam-search expansion: from `current_beam` and the step's token
// probabilities, pushes extended beams into `partial_candidates` and
// finished suggestions into `completed_candidates`, pruning anything below
// `log_prob_threshold` and beyond `max_num_suggestions` — inferred from the
// signature; body stripped.
void OnDeviceTailModelExecutor::CreateNewBeams(
    const RnnStepOutput& rnn_step_output,
    const BeamNode& current_beam,
    size_t max_num_suggestions,
    float log_prob_threshold,
    CandidateQueue* partial_candidates,
    CandidateQueue* completed_candidates) {}

// Builds a new BeamNode by extending `current_beam` with
// `token_id_and_prob` (carrying `states` forward) and inserts it into
// `queue`, presumably skipping entries below `log_prob_threshold` and
// evicting to keep at most `max_num_suggestions` — inferred from the
// signature; body stripped.
void OnDeviceTailModelExecutor::InsertBeamNodeToCandidateQueue(
    const TokenIdAndProb& token_id_and_prob,
    const RnnCellStates& states,
    const BeamNode& current_beam,
    float log_prob_threshold,
    size_t max_num_suggestions,
    CandidateQueue* queue) {}

// Prepares the beam-search root: encodes the previous query into
// `prev_query_encoding` and seeds `root_beam` from the tokenized input
// prefix, presumably running RNN steps over the unambiguous prefix tokens —
// inferred from names; body stripped. Returns false if setup fails.
bool OnDeviceTailModelExecutor::GetRootBeamNode(
    const OnDeviceTailTokenizer::Tokenization& input_tokenization,
    const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
    std::vector<float>* prev_query_encoding,
    BeamNode* root_beam) {}

// static
// Converts a raw probability to log space (presumably std::log, with some
// floor/guard for probability <= 0 — body stripped, TODO confirm).
float OnDeviceTailModelExecutor::GetLogProbability(float probability) {}

// Top-level entry point: tokenizes input.prefix, seeds the root beam (see
// GetRootBeamNode), runs beam search up to the max-step limit, filters bad
// suggestions, and returns up to input.max_num_suggestions predictions —
// inferred from the surrounding API; body stripped in this snapshot.
std::vector<OnDeviceTailModelExecutor::Prediction>
OnDeviceTailModelExecutor::GenerateSuggestionsForPrefix(
    const ModelInput& input) {}