#include "components/optimization_guide/core/model_execution/session_impl.h"
#include <optional>
#include "base/containers/contains.h"
#include "base/functional/callback.h"
#include "base/metrics/histogram_functions.h"
#include "base/strings/stringprintf.h"
#include "base/timer/elapsed_timer.h"
#include "base/token.h"
#include "base/uuid.h"
#include "components/optimization_guide/core/model_execution/feature_keys.h"
#include "components/optimization_guide/core/model_execution/model_execution_util.h"
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
#include "components/optimization_guide/core/model_execution/on_device_model_feature_adapter.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/redactor.h"
#include "components/optimization_guide/core/model_execution/repetition_checker.h"
#include "components/optimization_guide/core/model_execution/substitution.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/model_quality_metadata.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
namespace optimization_guide {
namespace {
// NOTE(review): the next two lines are not valid C++ as written. They look
// like `using` declarations whose leading text was lost, e.g. a
// `using <ns>::RepeatedPtrField;` declaration and a
// `using ModelExecutionError = <...>;` alias. Confirm against the full file
// before editing.
RepeatedPtrField;
ModelExecutionError;
// Takes the capability key of the executing feature and whether the response
// was found to contain repeats. Presumably records this as a metric
// (base/metrics/histogram_functions.h is included) -- TODO confirm.
void LogResponseHasRepeats(ModelBasedCapabilityKey feature, bool has_repeats) { … }
// Returns a newly generated execution identifier as a string. The format is
// not established here (base/uuid.h and base/strings/stringprintf.h are
// included) -- TODO confirm.
std::string GenerateExecutionId() { … }
// Takes ownership of a streaming `callback`, a remote execution `result`, and
// the associated quality `log_entry`. Presumably invokes `callback` with the
// remote result and log entry -- confirm against the implementation.
void InvokeStreamingCallbackWithRemoteResult(
OptimizationGuideModelExecutionResultStreamingCallback callback,
OptimizationGuideModelExecutionResult result,
std::unique_ptr<ModelQualityLogEntry> log_entry) { … }
// Returns a proto::InternalOnDeviceModelExecutionInfo built from the checked
// `text`, the model-reported `safety_info`, and the caller-computed
// `is_unsafe` verdict. Presumably this is the log record for one text-safety
// check -- TODO confirm which fields are populated.
proto::InternalOnDeviceModelExecutionInfo MakeTextSafetyExecutionLog(
const std::string& text,
const on_device_model::mojom::SafetyInfoPtr& safety_info,
bool is_unsafe) { … }
// Returns the SamplingParams to use for a session, derived from the optional
// session `config_params` and the optional on-device options. Either argument
// may be std::nullopt; the precedence between them is not visible here --
// TODO confirm.
SamplingParams ResolveSamplingParams(
const std::optional<SessionConfigParams>& config_params,
const std::optional<SessionImpl::OnDeviceOptions>& on_device_opts) { … }
}  // namespace
// Nested helper of SessionImpl that implements the
// on_device_model::mojom::ContextClient interface. Presumably it receives
// notifications while context is being added to the on-device session --
// TODO confirm against the class body.
class SessionImpl::ContextProcessor
: public on_device_model::mojom::ContextClient { … };
// Out-of-line defaulted destructor for the OnDeviceModelClient interface.
SessionImpl::OnDeviceModelClient::~OnDeviceModelClient() = default;
// OnDeviceOptions special members are defaulted out of line: default
// construction, move construction, and destruction.
SessionImpl::OnDeviceOptions::OnDeviceOptions() = default;
SessionImpl::OnDeviceOptions::OnDeviceOptions(OnDeviceOptions&&) = default;
SessionImpl::OnDeviceOptions::~OnDeviceOptions() = default;
// Returns whether these on-device options are usable for execution. The
// exact conditions checked are not visible here -- TODO confirm.
bool SessionImpl::OnDeviceOptions::ShouldUse() const { … }
// Creates a session for `feature`.
//   on_device_opts: optional on-device configuration; std::nullopt when none
//     is supplied.
//   execute_remote_fn: presumably used to run requests remotely -- confirm.
//   optimization_guide_logger / model_quality_uploader_service: held as
//     base::WeakPtr, so they may be invalidated before the session ends.
//   config_params: optional session configuration; std::nullopt when absent.
SessionImpl::SessionImpl(
ModelBasedCapabilityKey feature,
std::optional<OnDeviceOptions> on_device_opts,
ExecuteRemoteFn execute_remote_fn,
base::WeakPtr<OptimizationGuideLogger> optimization_guide_logger,
base::WeakPtr<ModelQualityLogsUploaderService>
model_quality_uploader_service,
const std::optional<SessionConfigParams>& config_params)
: … { … }
// Destructor; its body is not shown here. Presumably it releases any pending
// on-device state -- TODO confirm.
SessionImpl::~SessionImpl() { … }
// Returns the session's token limits by const reference. The referenced
// object's owner and lifetime are not visible here -- TODO confirm.
const TokenLimits& SessionImpl::GetTokenLimits() const { … }
// Public entry point for adding `request_metadata` as session context.
// Returns nothing to the caller.
void SessionImpl::AddContext(
const google::protobuf::MessageLite& request_metadata) { … }
// Implementation helper for adding `request_metadata` as context. Reports
// the outcome as an AddContextResult, whose values are defined in the header.
SessionImpl::AddContextResult SessionImpl::AddContextImpl(
const google::protobuf::MessageLite& request_metadata) { … }
// Requests a score for `text` and delivers it via `callback`
// (OptimizationGuideModelScoreCallback). Presumably served by the on-device
// model session -- TODO confirm.
void SessionImpl::Score(const std::string& text,
OptimizationGuideModelScoreCallback callback) { … }
// Executes the model for `request_metadata`. Results are delivered through
// the streaming `callback`, which may presumably be invoked more than once
// for partial responses -- confirm. Whether execution is on-device or remote
// is decided inside the (elided) body.
void SessionImpl::ExecuteModel(
const google::protobuf::MessageLite& request_metadata,
optimization_guide::OptimizationGuideModelExecutionResultStreamingCallback
callback) { … }
// Request-check pipeline. `options` (the pending input options) is passed
// along by value through each step. `request_check_idx` identifies which
// request check is being run.
//
// Either runs the request check at `request_check_idx` or, presumably once
// no checks remain, begins execution with `options` -- TODO confirm.
void SessionImpl::RunNextRequestSafetyCheckOrBeginExecution(
on_device_model::mojom::InputOptionsPtr options,
int request_check_idx) { … }
// Handles the language-detection `result` for `check_input_text` produced
// by request check `request_check_idx`.
void SessionImpl::OnRequestDetectLanguageResult(
on_device_model::mojom::InputOptionsPtr options,
int request_check_idx,
std::string check_input_text,
on_device_model::mojom::LanguageDetectionResultPtr result) { … }
// Handles the `safety_info` for `check_input_text` produced by request check
// `request_check_idx`.
void SessionImpl::OnRequestSafetyResult(
on_device_model::mojom::InputOptionsPtr options,
int request_check_idx,
std::string check_input_text,
on_device_model::mojom::SafetyInfoPtr safety_info) { … }
// Starts execution of the request described by `options`.
void SessionImpl::BeginRequestExecution(
on_device_model::mojom::InputOptionsPtr options) { … }
// Receives one streamed response `chunk` from the on-device model.
void SessionImpl::OnResponse(on_device_model::mojom::ResponseChunkPtr chunk) { … }
// Receives the end-of-response `summary` from the on-device model.
void SessionImpl::OnComplete(
on_device_model::mojom::ResponseSummaryPtr summary) { … }
// Starts a safety check over the raw model output accumulated so far --
// presumably; TODO confirm which text is checked.
void SessionImpl::RunRawOutputSafetyCheck() { … }
// Handles the `safety_info` for `safety_check_text`. `raw_output_size`
// presumably records how much raw output had been produced when the check
// was issued -- TODO confirm.
void SessionImpl::OnRawOutputSafetyResult(
std::string safety_check_text,
size_t raw_output_size,
on_device_model::mojom::SafetyInfoPtr safety_info) { … }
// Sends the complete response if the conditions for doing so are met. The
// conditions are not visible here -- TODO confirm.
void SessionImpl::MaybeSendCompleteResponse() { … }
// Returns a reference to the on-device mojom Session, creating it first if
// needed (per the name -- confirm). Callers must not retain the reference
// beyond the session's lifetime.
on_device_model::mojom::Session& SessionImpl::GetOrCreateSession() { … }
// Handler for loss of the on-device connection (presumably a mojo
// disconnect -- confirm).
void SessionImpl::OnDisconnect() { … }
// Cancels the in-flight response, recording `result` and reporting `error`
// (a ModelExecutionError).
void SessionImpl::CancelPendingResponse(ExecuteModelResult result,
ModelExecutionError error) { … }
// Sends a response of kind `response_type` (ResponseType is declared in the
// header).
void SessionImpl::SendResponse(ResponseType response_type) { … }
// Handles a parsed model `output`, either a proto::Any or a
// ResponseParsingError. `is_complete` indicates whether this is the final
// response.
void SessionImpl::OnParsedResponse(
bool is_complete,
base::expected<proto::Any, ResponseParsingError> output) { … }
// Delivers `success_response_metadata` to the caller as a partial response.
void SessionImpl::SendPartialResponseCallback(
const proto::Any& success_response_metadata) { … }
// Delivers `success_response_metadata` to the caller as the final,
// successful response.
void SessionImpl::SendSuccessCompletionCallback(
const proto::Any& success_response_metadata) { … }
// Returns whether this session should execute on the on-device model.
bool SessionImpl::ShouldUseOnDeviceModel() const { … }
// Handler invoked when the session times out. Presumably driven by a timer
// (base/timer/elapsed_timer.h is included) -- TODO confirm.
void SessionImpl::OnSessionTimedOut() { … }
// Tears down on-device state and falls back to remote execution, recording
// `result`.
void SessionImpl::DestroyOnDeviceStateAndFallbackToRemote(
ExecuteModelResult result) { … }
// Tears down on-device state without falling back.
void SessionImpl::DestroyOnDeviceState() { … }
// Returns a newly allocated message that combines `request` with previously
// added context (per the name -- TODO confirm). The caller owns the result.
std::unique_ptr<google::protobuf::MessageLite> SessionImpl::MergeContext(
const google::protobuf::MessageLite& request) { … }
// Presumably runs a remote text-safety check before completing with
// `success_response_metadata` -- TODO confirm.
void SessionImpl::RunTextSafetyRemoteFallbackAndCompletionCallback(
proto::Any success_response_metadata) { … }
// Handles the remote text-safety `result` and `remote_log_entry`.
// `remote_ts_model_execution_info` and `success_response_metadata` are
// carried through from the call that issued the remote request.
void SessionImpl::OnTextSafetyRemoteResponse(
proto::InternalOnDeviceModelExecutionInfo remote_ts_model_execution_info,
proto::Any success_response_metadata,
OptimizationGuideModelExecutionResult result,
std::unique_ptr<ModelQualityLogEntry> remote_log_entry) { … }
// Per-session on-device state, built from moved-in `options`. `session` is a
// raw back-pointer to the owning SessionImpl; presumably non-owning -- confirm.
SessionImpl::OnDeviceState::OnDeviceState(OnDeviceOptions&& options,
SessionImpl* session)
: … { … }
SessionImpl::OnDeviceState::~OnDeviceState() = default;
// Returns a mutable pointer to the logged OnDeviceModelServiceResponse proto.
proto::OnDeviceModelServiceResponse*
SessionImpl::OnDeviceState::MutableLoggedResponse() { … }
// Adds `log` (an InternalOnDeviceModelExecutionInfo) to the execution log.
void SessionImpl::OnDeviceState::AddModelExecutionLog(
const proto::InternalOnDeviceModelExecutionInfo& log) { … }
// Resets per-request state. Which fields are cleared is not visible here --
// TODO confirm.
void SessionImpl::OnDeviceState::ResetRequestState() { … }
// SafeRawOutput special members are defaulted out of line.
SessionImpl::OnDeviceState::SafeRawOutput::SafeRawOutput() = default;
SessionImpl::OnDeviceState::SafeRawOutput::~SafeRawOutput() = default;
// Destructor with a non-default body. Presumably records the ExecuteModel
// outcome histogram when the logger goes out of scope -- TODO confirm.
SessionImpl::ExecuteModelHistogramLogger::~ExecuteModelHistogramLogger() { … }
// Computes the size of `text` in tokens and reports it via `callback`
// (OptimizationGuideModelSizeInTokenCallback).
void SessionImpl::GetSizeInTokens(
const std::string& text,
OptimizationGuideModelSizeInTokenCallback callback) { … }
// Returns this session's sampling parameters by value.
// NOTE(review): the top-level `const` on a by-value return has no benefit
// and prevents callers from moving from the result. Consider dropping it in
// both the header and here.
const SamplingParams SessionImpl::GetSamplingParams() const { … }
}