#include "components/segmentation_platform/internal/selection/request_dispatcher.h"
#include <set>
#include <utility>
#include "base/containers/circular_deque.h"
#include "base/functional/callback_forward.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/config_holder.h"
#include "components/segmentation_platform/internal/post_processor/post_processor.h"
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"
namespace segmentation_platform {
namespace {
// Timeout applied to model initialization; judging by the `Ms` suffix, the
// unit is milliseconds.
// NOTE(review): the value and its use site are elided in this view.
// Presumably it bounds how long the dispatcher waits before
// RequestDispatcher::OnModelInitializationTimeout() runs — confirm.
const int kModelInitializationTimeoutMs = …;
// Post-processing step for classification requests. `raw_result` is read-only.
// `result` is taken by non-const reference, so it is presumably an output
// parameter populated from `raw_result` (body elided here — confirm).
void PostProcess(const RawResult& raw_result, ClassificationResult& result) { … }
// Post-processing step for annotated-numeric requests; overload of the
// ClassificationResult variant. `raw_result` is read-only. `result` is taken by
// non-const reference, so it is presumably an output parameter populated from
// `raw_result` (body elided here — confirm).
void PostProcess(const RawResult& raw_result, AnnotatedNumericResult& result) { … }
}
// Constructs the dispatcher around a raw `StorageService*`.
// NOTE(review): ownership and required lifetime of `storage_service` are not
// visible here (member initializers elided). Presumably the pointer is
// non-owning and must outlive this object — confirm.
RequestDispatcher::RequestDispatcher(StorageService* storage_service)
: … { … }
// Defaulted destructor: members are released by their own destructors, with no
// custom teardown logic here.
RequestDispatcher::~RequestDispatcher() = default;
// Called with the outcome of platform initialization.
//   success           - whether initialization succeeded.
//   execution_service - raw pointer; ownership is not visible here (presumably
//                       non-owning — confirm).
//   result_providers  - taken by value as a map of std::unique_ptr, so the
//                       caller must move it in and ownership of every provider
//                       transfers to this call. Keys are presumably
//                       segmentation keys — confirm against callers.
// NOTE(review): body elided. Presumably this unblocks requests queued before
// initialization (see ExecuteAllPendingActions()) — confirm.
void RequestDispatcher::OnPlatformInitialized(
bool success,
ExecutionService* execution_service,
std::map<std::string, std::unique_ptr<SegmentResultProvider>>
result_providers) { … }
// Runs queued pending actions across all segmentation keys.
// NOTE(review): body elided. The name suggests it drains the queue counted by
// GetPendingActionCountForTesting() — confirm.
void RequestDispatcher::ExecuteAllPendingActions() { … }
// Runs pending actions associated with a single `segmentation_key`.
// NOTE(review): body elided. Presumably actions for other keys stay queued —
// confirm.
void RequestDispatcher::ExecutePendingActionsForKey(
const std::string& segmentation_key) { … }
// Notification entry point that receives the `proto::SegmentId` of a model
// update.
// NOTE(review): body and callers elided. Presumably it may release requests
// waiting on that segment's model — confirm.
void RequestDispatcher::OnModelUpdated(proto::SegmentId segment_id) { … }
// Timeout handler for model initialization.
// NOTE(review): body elided. Presumably this is scheduled with
// kModelInitializationTimeoutMs and stops pending requests from waiting
// indefinitely — confirm.
void RequestDispatcher::OnModelInitializationTimeout() { … }
// Adapter between a raw model result and a typed client callback.
//   segmentation_key - key of the originating request.
//   start_time       - base::Time captured by the caller (presumably the request
//                      start, for latency reporting — confirm).
//   callback         - one-shot client callback receiving a `const ResultType&`.
//   is_cached_result - whether `raw_result` came from a cached execution.
//   raw_result       - the untyped result; it is the last parameter, so the
//                      leading arguments can be pre-bound.
// NOTE(review): body elided. Presumably it converts `raw_result` into a
// `ResultType` (e.g. via the PostProcess overloads) before running `callback` —
// confirm.
template <typename ResultType>
void RequestDispatcher::CallbackWrapper(
const std::string& segmentation_key,
base::Time start_time,
base::OnceCallback<void(const ResultType&)> callback,
bool is_cached_result,
const RawResult& raw_result) { … }
// Shared request path that obtains a model result for `segmentation_key`.
//   options       - PredictionOptions controlling the request.
//   input_context - ref-counted scoped_refptr, taken by value.
//   callback      - WrappedCallback to receive the outcome.
// NOTE(review): body elided. Presumably it chooses between ExecuteOnDemand()
// and HandleCachedExecution() based on `options`, or queues the request before
// initialization — confirm.
void RequestDispatcher::GetModelResult(
const std::string& segmentation_key,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
WrappedCallback callback) { … }
// Starts an on-demand model execution for `segmentation_key`.
// `config` is a raw const pointer; ownership is not visible here (presumably
// non-owning — confirm). `input_context` is a ref-counted handle taken by value.
// NOTE(review): body elided. Completion presumably arrives in
// OnFinishedOnDemandExecution(), which takes the same leading parameters plus
// the RawResult — confirm.
void RequestDispatcher::ExecuteOnDemand(
const std::string& segmentation_key,
const Config* config,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
WrappedCallback callback) { … }
// Completion handler for an on-demand execution. It repeats ExecuteOnDemand()'s
// parameters and adds `raw_result` as the final argument; that order suggests
// the leading arguments are pre-bound.
// NOTE(review): body elided. Presumably it forwards `raw_result` to `callback`,
// possibly after a fallback when execution failed — confirm.
void RequestDispatcher::OnFinishedOnDemandExecution(
const std::string& segmentation_key,
const Config* config,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
WrappedCallback callback,
const RawResult& raw_result) { … }
// Serves a request for `segmentation_key` through the cached-execution path.
// Its parameter list matches ExecuteOnDemand().
// NOTE(review): body elided. Presumably it answers from a previously computed
// result rather than running the model now — confirm.
void RequestDispatcher::HandleCachedExecution(
const std::string& segmentation_key,
const Config* config,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
WrappedCallback callback) { … }
// Requests a classification result for `segmentation_key`. The result is
// delivered through `callback` (a ClassificationResultCallback).
// NOTE(review): body elided. Presumably it wraps `callback` with
// CallbackWrapper<ClassificationResult> and delegates to GetModelResult() —
// confirm.
void RequestDispatcher::GetClassificationResult(
const std::string& segmentation_key,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
ClassificationResultCallback callback) { … }
// Requests an annotated numeric result for `segmentation_key`. The result is
// delivered through `callback` (an AnnotatedNumericResultCallback).
// NOTE(review): body elided. Presumably it wraps `callback` with
// CallbackWrapper<AnnotatedNumericResult> and delegates to GetModelResult() —
// confirm.
void RequestDispatcher::GetAnnotatedNumericResult(
const std::string& segmentation_key,
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
AnnotatedNumericResultCallback callback) { … }
// Test-only accessor returning, as an int, the number of pending actions (per
// its name). Body elided; not intended for production callers.
int RequestDispatcher::GetPendingActionCountForTesting() { … }
}