chromium/components/optimization_guide/core/tflite_model_executor.h

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_

#include <optional>

#include "base/files/memory_mapped_file.h"
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/logging.h"
#include "base/metrics/histogram.h"
#include "base/metrics/histogram_functions.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "base/trace_event/trace_event.h"
#include "base/types/expected.h"
#include "components/optimization_guide/core/execution_status.h"
#include "components/optimization_guide/core/model_enums.h"
#include "components/optimization_guide/core/model_execution_timeout_watchdog.h"
#include "components/optimization_guide/core/model_executor.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "third_party/tflite/src/tensorflow/lite/c/common.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"

namespace optimization_guide {

namespace {

// Scoped helper meant to report the outcome of a model execution when it is
// destroyed at the end of its scope.
//
// NOTE(review): the class has no members in this view, so the code shown here
// records nothing on destruction -- confirm against the full definition.
class ScopedExecutionStatusResultRecorder {
};

}  // namespace

// A ModelExecutor that executes tflite models with arbitrary input and output
// types. Note that callers will need to give an implementation of this class
// to a |ModelHandler|, whereas the handler is the actual class that calling
// code would own and call into.
//
// By default, the model file will be (re)loaded for every execution and then
// unloaded from memory after every execution (e.g.: "OnComplete"). This helps
// to keep memory usage of the browser process down, but does delay model
// execution by the time it takes to load the model (about 50ms in practice).
// See |SetShouldUnloadModelOnComplete| to override this behavior.
//
// Note that when built with the MediaPipe backend (non-default), task
// cancellation is not supported.
//
// Template parameters:
//   OutputType             - type produced by one model execution.
//   InputType              - type consumed by one model execution.
//   ModelExecutionTaskType - task type built by |BuildModelExecutionTask| and
//                            passed to |Execute|; defaults to the TFLite Task
//                            Support `BaseTaskApi<OutputType, InputType>`.
template <class OutputType,
          class InputType,
          // TODO(b/283522287): Remove this once all usage of TFLite Task
          // Support are replaced by MediaPipe.
          class ModelExecutionTaskType =
              tflite::task::core::BaseTaskApi<OutputType, InputType>>
class TFLiteModelExecutor : public ModelExecutor<OutputType, InputType> {
 public:
  // Starts without a watchdog: |watchdog_| holds a null pointer and a deleter
  // that is bound to no task runner. The member uses a custom deleter type, so
  // it is initialized explicitly here (an empty initializer list is
  // ill-formed).
  TFLiteModelExecutor()
      : watchdog_(nullptr, base::OnTaskRunnerDeleter(nullptr)) {}

  ~TFLiteModelExecutor() override {}

  // Should be called on the same sequence as the ctor, but once called |this|
  // must only be used from the |execution_task_runner| thread/sequence.
  void InitializeAndMoveToExecutionThread(
      std::optional<base::TimeDelta> model_inference_timeout,
      proto::OptimizationTarget optimization_target,
      scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
      scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override {}

  // Called when a model file is available to load. Immediately loads model into
  // memory when `should_preload_model_` is set.
  void UpdateModelFile(
      base::optional_ref<const base::FilePath> file_path) override {}

  // Calling this method allows the default model loading/unloading behavior to
  // be overridden. Setting this to false will cause the model to remain loaded
  // after a model execution (e.g.: "OnComplete"), until |UnloadModel| is
  // called. True is the default behavior (see class comment).
  //
  // Note that keeping the model in memory for a long duration may be detected
  // as a memory leak in Chrome, and will always increase the private or shared
  // memory used by the browser by the size of the model file and the
  // constructed TFLite graph.
  void SetShouldUnloadModelOnComplete(
      bool should_unload_model_on_complete) override {}

  // Calling this method allows the default model preloading behavior to
  // be overridden. Setting this to true will cause the model to be loaded as
  // soon as its file path is available. Callers may also need to call
  // `SetShouldUnloadModelOnComplete(false)` to keep the model in memory for the
  // lifetime of the entire browsing session.
  //
  // Note that keeping the model in memory for a long duration may be detected
  // as a memory leak in Chrome, and will always increase the private or shared
  // memory used by the browser by the size of the model file and the
  // constructed TFLite graph.
  void SetShouldPreloadModel(bool should_preload_model) override {}

  // Clears the loaded model from memory if it is loaded. Safe to call when the
  // model is already unloaded, in which case it is a no-op.
  void UnloadModel() override {}

  // Callback receiving the output of a single execution.
  using ExecutionCallback =
      base::OnceCallback<void(const std::optional<OutputType>&)>;
  // Callback receiving the outputs of a batch execution.
  using BatchExecutionCallback =
      base::OnceCallback<void(const std::vector<std::optional<OutputType>>&)>;

  // When complete, |callback_on_complete| will be run via |reply_task_runner_|
  // with the outputs of the model.
  void SendForExecution(ExecutionCallback callback_on_complete,
                        base::TimeTicks start_time,
                        InputType input) override {}

  // Starts the batch execution of the model. When complete,
  // |callback_on_complete| will be run via |reply_task_runner_| with the
  // outputs of the model.
  void SendForBatchExecution(
      BatchExecutionCallback callback_on_complete,
      base::TimeTicks start_time,
      ModelExecutor<OutputType, InputType>::ConstRefInputVector inputs)
      override {}

  // Starts the synchronous execution of the model. Returns model outputs.
  // Model needs to be loaded. Synchronous calls do not load or unload model.
  //
  // NOTE(review): the body is empty in this view although a value must be
  // returned -- confirm the implementation before relying on the result.
  std::vector<std::optional<OutputType>> SendForBatchExecutionSync(
      ModelExecutor<OutputType, InputType>::ConstRefInputVector inputs)
      override {}

  // Returns a weak pointer to |this| vended by
  // |execution_sequence_weak_ptr_factory_|.
  //
  // IMPORTANT: These WeakPointers must only be dereferenced on the
  // |execution_task_runner| thread.
  base::WeakPtr<TFLiteModelExecutor> GetWeakPtrForExecutionThread() {
    return execution_sequence_weak_ptr_factory_.GetWeakPtr();
  }

  TFLiteModelExecutor(const TFLiteModelExecutor&) = delete;
  TFLiteModelExecutor& operator=(const TFLiteModelExecutor&) = delete;

 protected:
  using ModelExecutionTask =
      tflite::task::core::BaseTaskApi<OutputType, InputType>;

  // Executes the model using |execution_task| on |args|, returning the model
  // output and setting |out_status| with the status of the execution attempt.
  virtual std::optional<OutputType> Execute(
      ModelExecutionTaskType* execution_task,
      ExecutionStatus* out_status,
      InputType args) = 0;

  // Builds a model execution task using |model_file|. On error, the returned
  // `ExecutionStatus` will never be `ExecutionStatus::kSuccess`.
  virtual base::expected<std::unique_ptr<ModelExecutionTaskType>,
                         ExecutionStatus>
  BuildModelExecutionTask(base::MemoryMappedFile* model_file) = 0;

 private:
  // Loads the model file in the background thread, and calls a callback on
  // model file loaded in memory on the model execution thread.
  void LoadModelFile(
      base::OnceCallback<void(ExecutionStatus)> model_loaded_callback) {}

  // Called on model file loaded in memory. Builds the model execution task from
  // the memory-mapped file, and calls `model_loaded_callback`.
  void OnModelFileLoadedInMemory(
      base::OnceCallback<void(ExecutionStatus)> model_loaded_callback,
      std::pair<ExecutionStatus, std::unique_ptr<base::MemoryMappedFile>>
          status_and_model_fb) {}

  // Loads the model file if not loaded yet on the background thread, and batch
  // executes it on the model execution thread.
  void LoadModelFileAndBatchExecute(
      BatchExecutionCallback callback_on_complete,
      ModelExecutor<OutputType, InputType>::ConstRefInputVector inputs) {}

  // Batch executes the loaded model for inputs, writing results to |outputs|.
  void BatchExecuteLoadedModel(
      ModelExecutor<OutputType, InputType>::ConstRefInputVector inputs,
      std::vector<std::optional<OutputType>>* outputs) {}

  // Batch executes the loaded model and runs callback on the reply thread.
  // Unloads the model if needed.
  void BatchExecuteLoadedModelAndRunCallback(
      BatchExecutionCallback callback_on_complete,
      ModelExecutor<OutputType, InputType>::ConstRefInputVector inputs,
      ExecutionStatus execution_status) {}

  void OnExecutionComplete() {}

  // NOTE(review): the body is empty in this view although a closure must be
  // returned -- confirm the implementation before using the result.
  base::OnceClosure MakeCancelClosure() {}

  proto::OptimizationTarget optimization_target_ =
      proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;

  // Whether the model is unloaded after an execution completes; see
  // |SetShouldUnloadModelOnComplete|.
  bool should_unload_model_on_complete_ = true;

  // Whether the model is loaded as soon as its file is known; see
  // |SetShouldPreloadModel|.
  bool should_preload_model_ = false;

  // Null until a watchdog is installed; constructed with a deleter that has no
  // task runner.
  std::unique_ptr<ModelExecutionTimeoutWatchdog, base::OnTaskRunnerDeleter>
      watchdog_;

  // Main thread for model execution. For synchronous model execution, this
  // needs to be the same caller thread.
  scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;

  // Arbitrary thread for running reply tasks.
  scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;

  // Background thread for model loading file I/O.
  // NOTE(review): no code visible here assigns this runner -- confirm where it
  // is created.
  scoped_refptr<base::SequencedTaskRunner> model_loading_task_runner_;

  // The time that the model was last executed. Logged in metrics for the second
  // and following runs.
  std::optional<base::TimeTicks> last_execution_time_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // The model file path to be loaded. May be nullopt if no model has been
  // downloaded yet.
  std::optional<base::FilePath> model_file_path_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Note on lifetimes: |loaded_model_| and |model_fb_| both share the same
  // lifetime, being set in |LoadModelFile()| and being destroyed in
  // |UnloadModel()|.

  std::unique_ptr<ModelExecutionTaskType> loaded_model_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // This will only be non-null when |model_file_path_| is set, and while the
  // model is loaded which is managed by a feature flag.
  std::unique_ptr<base::MemoryMappedFile> model_fb_
      GUARDED_BY_CONTEXT(sequence_checker_);

  SEQUENCE_CHECKER(sequence_checker_);

  // Vends the weak pointers returned by |GetWeakPtrForExecutionThread|.
  base::WeakPtrFactory<TFLiteModelExecutor>
      execution_sequence_weak_ptr_factory_{this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_