#include "services/on_device_model/ml/on_device_model_executor.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/logging.h"
#include "base/memory/raw_ref.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/histogram_functions.h"
#include "base/numerics/safe_conversions.h"
#include "base/task/thread_pool.h"
#include "base/timer/elapsed_timer.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "services/on_device_model/ml/chrome_ml.h"
#include "services/on_device_model/ml/session_accessor.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"
#if BUILDFLAG(IS_MAC)
#include "base/apple/foundation_util.h"
#endif
// NOTE(review): this line appears truncated. As written it is not a valid
// declaration. It is presumably meant to be a using-declaration bringing the
// `LoadModelResult` type into scope (it is used unqualified below as a return
// type), e.g. `using on_device_model::mojom::LoadModelResult;` — confirm
// against the original source.
LoadModelResult;
namespace ml {
namespace {
constexpr uint32_t kReserveTokensForSafety = …;
const base::FeatureParam<bool> kPreferTextureWeights{ … };
const base::FeatureParam<bool> kEnableHostMappedPointer{ … };
const base::FeatureParam<bool> kUseLowPower{ … };
const base::FeatureParam<bool> kAllowFp16{ … };
template <typename R, typename C, typename... Args>
std::function<R(Args...)> CreateWeakCallbackFn(R (C::*method)(Args...),
C* that) { … }
template <typename R, typename... Args>
std::function<R(Args...)> ConvertCallbackToFn(
base::OnceCallback<R(Args...)> callback) { … }
int CalculateTokensPerSecond(int num_tokens, base::TimeDelta duration) { … }
float GetTemperature(std::optional<float> temperature) { … }
uint32_t GetTopK(std::optional<uint32_t> top_k) { … }
std::optional<ModelBackendType> ModelBackendTypeFromMojom(
on_device_model::mojom::ModelBackendType backend) { … }
}
class Responder final { … };
class ContextHolder final { … };
SessionImpl::SessionImpl(const ChromeML& chrome_ml,
ChromeMLModel model,
SessionAccessor::Ptr session,
SessionAccessor::Ptr empty_session,
uint32_t max_tokens,
std::optional<uint32_t> adaptation_id)
: … { … }
// Defaulted out-of-line destructor: the compiler-generated destructor simply
// destroys SessionImpl's members in reverse declaration order.
SessionImpl::~SessionImpl() = default;
DISABLE_CFI_DLSYM
void SessionImpl::AddContext(
on_device_model::mojom::InputOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
base::OnceClosure on_complete) { … }
DISABLE_CFI_DLSYM
void SessionImpl::Execute(
on_device_model::mojom::InputOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
base::OnceClosure on_complete) { … }
DISABLE_CFI_DLSYM
void SessionImpl::SizeInTokens(const std::string& text,
base::OnceCallback<void(uint32_t)> callback) { … }
DISABLE_CFI_DLSYM
void SessionImpl::Score(const std::string& text,
base::OnceCallback<void(float)> callback) { … }
std::unique_ptr<SessionImpl> SessionImpl::Clone() { … }
void SessionImpl::RemoveContext(ContextHolder* context) { … }
DISABLE_CFI_DLSYM
void DestroyModel(const ChromeML* chrome_ml, ChromeMLModel model) { … }
OnDeviceModelExecutor::OnDeviceModelExecutor(
base::PassKey<OnDeviceModelExecutor>,
const ChromeML& chrome_ml)
: … { … }
OnDeviceModelExecutor::~OnDeviceModelExecutor() { … }
base::expected<std::unique_ptr<OnDeviceModelExecutor>, LoadModelResult>
OnDeviceModelExecutor::CreateWithResult(
const ChromeML& chrome_ml,
on_device_model::mojom::LoadModelParamsPtr params,
base::OnceClosure on_complete) { … }
std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
std::optional<uint32_t> adaptation_id) { … }
void OnDeviceModelExecutor::DetectLanguage(
const std::string& text,
on_device_model::mojom::OnDeviceModel::DetectLanguageCallback callback) { … }
DISABLE_CFI_DLSYM
void OnDeviceModelExecutor::ClassifyTextSafety(
const std::string& text,
on_device_model::mojom::OnDeviceModel::ClassifyTextSafetyCallback
callback) { … }
DISABLE_CFI_DLSYM
base::expected<uint32_t, LoadModelResult> OnDeviceModelExecutor::LoadAdaptation(
on_device_model::mojom::LoadAdaptationParamsPtr params,
base::OnceClosure on_complete) { … }
DISABLE_CFI_DLSYM
LoadModelResult OnDeviceModelExecutor::Init(
on_device_model::mojom::LoadModelParamsPtr params,
base::OnceClosure on_complete) { … }
void OnDeviceModelExecutor::Schedule(uintptr_t context,
std::function<void()>* fn) { … }
}