chromium/services/on_device_model/ml/on_device_model_executor.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/on_device_model/ml/on_device_model_executor.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/logging.h"
#include "base/memory/raw_ref.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/histogram_functions.h"
#include "base/numerics/safe_conversions.h"
#include "base/task/thread_pool.h"
#include "base/timer/elapsed_timer.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "services/on_device_model/ml/chrome_ml.h"
#include "services/on_device_model/ml/session_accessor.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"

#if BUILDFLAG(IS_MAC)
#include "base/apple/foundation_util.h"
#endif

LoadModelResult;

namespace ml {
namespace {

constexpr uint32_t kReserveTokensForSafety =;

const base::FeatureParam<bool> kPreferTextureWeights{};

const base::FeatureParam<bool> kEnableHostMappedPointer{};

const base::FeatureParam<bool> kUseLowPower{};

const base::FeatureParam<bool> kAllowFp16{};

// Helper to bind object methods as weak task-posting callback functions.
template <typename R, typename C, typename... Args>
std::function<R(Args...)> CreateWeakCallbackFn(R (C::*method)(Args...),
                                               C* that) {}

// Helper to convert a OnceCallback to std::function.
template <typename R, typename... Args>
std::function<R(Args...)> ConvertCallbackToFn(
    base::OnceCallback<R(Args...)> callback) {}

int CalculateTokensPerSecond(int num_tokens, base::TimeDelta duration) {}

float GetTemperature(std::optional<float> temperature) {}

uint32_t GetTopK(std::optional<uint32_t> top_k) {}

std::optional<ModelBackendType> ModelBackendTypeFromMojom(
    on_device_model::mojom::ModelBackendType backend) {}

}  // namespace

// Handles sending and canceling responses.
class Responder final {};

// Handles calling the ContextClient on completion and canceling the context
// request.
class ContextHolder final {};

SessionImpl::SessionImpl(const ChromeML& chrome_ml,
                         ChromeMLModel model,
                         SessionAccessor::Ptr session,
                         SessionAccessor::Ptr empty_session,
                         uint32_t max_tokens,
                         std::optional<uint32_t> adaptation_id)
    :{}
SessionImpl::~SessionImpl() = default;

DISABLE_CFI_DLSYM
void SessionImpl::AddContext(
    on_device_model::mojom::InputOptionsPtr input,
    mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
    base::OnceClosure on_complete) {}

DISABLE_CFI_DLSYM
void SessionImpl::Execute(
    on_device_model::mojom::InputOptionsPtr input,
    mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
    base::OnceClosure on_complete) {}

DISABLE_CFI_DLSYM
void SessionImpl::SizeInTokens(const std::string& text,
                               base::OnceCallback<void(uint32_t)> callback) {}

DISABLE_CFI_DLSYM
void SessionImpl::Score(const std::string& text,
                        base::OnceCallback<void(float)> callback) {}

std::unique_ptr<SessionImpl> SessionImpl::Clone() {}

void SessionImpl::RemoveContext(ContextHolder* context) {}

DISABLE_CFI_DLSYM
void DestroyModel(const ChromeML* chrome_ml, ChromeMLModel model) {}

OnDeviceModelExecutor::OnDeviceModelExecutor(
    base::PassKey<OnDeviceModelExecutor>,
    const ChromeML& chrome_ml)
    :{}

OnDeviceModelExecutor::~OnDeviceModelExecutor() {}

// static
base::expected<std::unique_ptr<OnDeviceModelExecutor>, LoadModelResult>
OnDeviceModelExecutor::CreateWithResult(
    const ChromeML& chrome_ml,
    on_device_model::mojom::LoadModelParamsPtr params,
    base::OnceClosure on_complete) {}

std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
    std::optional<uint32_t> adaptation_id) {}

void OnDeviceModelExecutor::DetectLanguage(
    const std::string& text,
    on_device_model::mojom::OnDeviceModel::DetectLanguageCallback callback) {}

DISABLE_CFI_DLSYM
void OnDeviceModelExecutor::ClassifyTextSafety(
    const std::string& text,
    on_device_model::mojom::OnDeviceModel::ClassifyTextSafetyCallback
        callback) {}

DISABLE_CFI_DLSYM
base::expected<uint32_t, LoadModelResult> OnDeviceModelExecutor::LoadAdaptation(
    on_device_model::mojom::LoadAdaptationParamsPtr params,
    base::OnceClosure on_complete) {}

DISABLE_CFI_DLSYM
LoadModelResult OnDeviceModelExecutor::Init(
    on_device_model::mojom::LoadModelParamsPtr params,
    base::OnceClosure on_complete) {}

// static
void OnDeviceModelExecutor::Schedule(uintptr_t context,
                                     std::function<void()>* fn) {}

}  // namespace ml