#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "base/files/memory_mapped_file.h"
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/logging.h"
#include "base/metrics/histogram.h"
#include "base/metrics/histogram_functions.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "base/trace_event/trace_event.h"
#include "base/types/expected.h"
#include "components/optimization_guide/core/execution_status.h"
#include "components/optimization_guide/core/model_enums.h"
#include "components/optimization_guide/core/model_execution_timeout_watchdog.h"
#include "components/optimization_guide/core/model_executor.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "third_party/tflite/src/tensorflow/lite/c/common.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"

namespace optimization_guide {

namespace {
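
// RAII helper that records the resulting ExecutionStatus of a single model
// execution, e.g. to UMA histograms, when it goes out of scope.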
class ScopedExecutionStatusResultRecorder { … };
}  // namespace
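
// A ModelExecutor that runs TFLite models with arbitrary input and output
// types. Callers hand an implementation of this class to a ModelHandler,
// which is the object that calling code actually owns and invokes.
//
// A minimal sketch of a concrete executor; the subclass name and the
// input/output types below are hypothetical, not part of this API:
//
//   class MyExecutor
//       : public TFLiteModelExecutor<std::vector<float>, const std::string&> {
//    protected:
//     std::optional<std::vector<float>> Execute(
//         ModelExecutionTask* execution_task,
//         ExecutionStatus* out_status,
//         const std::string& input) override {
//       // Run |execution_task| on |input| and translate its result,
//       // setting |out_status| on failure.
//     }
//
//     base::expected<std::unique_ptr<ModelExecutionTask>, ExecutionStatus>
//     BuildModelExecutionTask(base::MemoryMappedFile* model_file) override {
//       // Wrap |model_file| in a tflite::task::core::BaseTaskApi subclass.
//     }
//   };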
template <class OutputType,
          class InputType,
          class ModelExecutionTaskType =
              tflite::task::core::BaseTaskApi<OutputType, InputType>>
class TFLiteModelExecutor : public ModelExecutor<OutputType, InputType> {
 public:
  TFLiteModelExecutor()
      : … { … }
  ~TFLiteModelExecutor() override { … }
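
  // Should be called on the same sequence as the constructor; once called,
  // |this| must only be used on |execution_task_runner|. If
  // |model_inference_timeout| is provided, executions are guarded by a
  // ModelExecutionTimeoutWatchdog.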
  void InitializeAndMoveToExecutionThread(
      std::optional<base::TimeDelta> model_inference_timeout,
      proto::OptimizationTarget optimization_target,
      scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
      scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override { … }
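
  // Called when a new model file is available; passing std::nullopt clears
  // the currently registered model file.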
  void UpdateModelFile(
      base::optional_ref<const base::FilePath> file_path) override { … }
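
  // Sets whether the loaded model is unloaded from memory after each
  // execution completes. Defaults to true.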
  void SetShouldUnloadModelOnComplete(
      bool should_unload_model_on_complete) override { … }
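
  // Sets whether the model should be loaded as soon as its file is available,
  // rather than lazily on first execution. Defaults to false.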
  void SetShouldPreloadModel(bool should_preload_model) override { … }
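
  // Unloads the model from memory, if loaded. Must be called on
  // |execution_task_runner_|.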
  void UnloadModel() override { … }

  using ExecutionCallback =
      base::OnceCallback<void(const std::optional<OutputType>&)>;
  using BatchExecutionCallback =
      base::OnceCallback<void(const std::vector<std::optional<OutputType>>&)>;
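
  // Executes the model asynchronously on |input|; |callback_on_complete| is
  // run on |reply_task_runner_| with the model output, or std::nullopt on
  // failure. A caller-side sketch, where |executor| and the float-vector
  // output type are hypothetical:
  //
  //   executor->SendForExecution(
  //       base::BindOnce([](const std::optional<std::vector<float>>& output) {
  //         // Consume the model output, if any.
  //       }),
  //       /*start_time=*/base::TimeTicks::Now(), input);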
  void SendForExecution(ExecutionCallback callback_on_complete,
                        base::TimeTicks start_time,
                        InputType input) override { … }
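
  // Batch variant: executes the model on each entry of |inputs| and runs
  // |callback_on_complete| on |reply_task_runner_| with one std::optional
  // output per input.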
  void SendForBatchExecution(
      BatchExecutionCallback callback_on_complete,
      base::TimeTicks start_time,
      typename ModelExecutor<OutputType, InputType>::ConstRefInputVector
          inputs) override { … }
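
  // Synchronous batch execution; must be called on |execution_task_runner_|.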
  std::vector<std::optional<OutputType>> SendForBatchExecutionSync(
      typename ModelExecutor<OutputType, InputType>::ConstRefInputVector
          inputs) override { … }
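
  // IMPORTANT: weak pointers returned here must only be dereferenced on
  // |execution_task_runner_|.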
  base::WeakPtr<TFLiteModelExecutor> GetWeakPtrForExecutionThread() { … }

  TFLiteModelExecutor(const TFLiteModelExecutor&) = delete;
  TFLiteModelExecutor& operator=(const TFLiteModelExecutor&) = delete;

 protected:
  using ModelExecutionTask =
      tflite::task::core::BaseTaskApi<OutputType, InputType>;
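
  // Executes the loaded |execution_task| on |args|, populating |out_status|
  // on failure. Runs on |execution_task_runner_|.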
  virtual std::optional<OutputType> Execute(
      ModelExecutionTaskType* execution_task,
      ExecutionStatus* out_status,
      InputType args) = 0;
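
  // Builds a model execution task from the memory-mapped |model_file|, or
  // returns an ExecutionStatus describing why it could not be built.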
  virtual base::expected<std::unique_ptr<ModelExecutionTaskType>,
                         ExecutionStatus>
  BuildModelExecutionTask(base::MemoryMappedFile* model_file) = 0;

 private:
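  // Loads the model file at |model_file_path_| on
  // |model_loading_task_runner_| and replies with the resulting status.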
  void LoadModelFile(
      base::OnceCallback<void(ExecutionStatus)> model_loaded_callback) { … }
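
  // Receives the memory-mapped model file from the loading task and, on
  // success, builds the execution task via BuildModelExecutionTask().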
  void OnModelFileLoadedInMemory(
      base::OnceCallback<void(ExecutionStatus)> model_loaded_callback,
      std::pair<ExecutionStatus, std::unique_ptr<base::MemoryMappedFile>>
          status_and_model_fb) { … }
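
  // Loads the model file if it is not already loaded, then batch executes it
  // on |inputs|.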
  void LoadModelFileAndBatchExecute(
      BatchExecutionCallback callback_on_complete,
      typename ModelExecutor<OutputType, InputType>::ConstRefInputVector
          inputs) { … }
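
  // Batch executes the loaded model, writing one entry to |outputs| per
  // input.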
  void BatchExecuteLoadedModel(
      typename ModelExecutor<OutputType, InputType>::ConstRefInputVector
          inputs,
      std::vector<std::optional<OutputType>>* outputs) { … }
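
  // Batch executes the model if |execution_status| indicates it loaded
  // successfully, then posts |callback_on_complete| to |reply_task_runner_|.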
  void BatchExecuteLoadedModelAndRunCallback(
      BatchExecutionCallback callback_on_complete,
      typename ModelExecutor<OutputType, InputType>::ConstRefInputVector
          inputs,
      ExecutionStatus execution_status) { … }
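
  // Called after every execution; unloads the model when
  // |should_unload_model_on_complete_| is set.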
  void OnExecutionComplete() { … }

  base::OnceClosure MakeCancelClosure() { … }

  proto::OptimizationTarget optimization_target_ =
      proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
  bool should_unload_model_on_complete_ = true;
  bool should_preload_model_ = false;
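
  // Watches each execution and enforces |model_inference_timeout|; destroyed
  // on its own task runner via base::OnTaskRunnerDeleter.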
  std::unique_ptr<ModelExecutionTimeoutWatchdog, base::OnTaskRunnerDeleter>
      watchdog_;

  scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;
  scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;
  scoped_refptr<base::SequencedTaskRunner> model_loading_task_runner_;

  std::optional<base::TimeTicks> last_execution_time_
      GUARDED_BY_CONTEXT(sequence_checker_);

  std::optional<base::FilePath> model_file_path_
      GUARDED_BY_CONTEXT(sequence_checker_);

  std::unique_ptr<ModelExecutionTaskType> loaded_model_
      GUARDED_BY_CONTEXT(sequence_checker_);
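
  // The memory-mapped model file; must outlive |loaded_model_|, which may
  // reference its mapped memory.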
  std::unique_ptr<base::MemoryMappedFile> model_fb_
      GUARDED_BY_CONTEXT(sequence_checker_);

  SEQUENCE_CHECKER(sequence_checker_);

  base::WeakPtrFactory<TFLiteModelExecutor>
      execution_sequence_weak_ptr_factory_{this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_