#include "components/segmentation_platform/internal/data_collection/training_data_collector_impl.h"
#include <cstdint>
#include "base/containers/contains.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/metrics_hashes.h"
#include "base/metrics/user_metrics.h"
#include "base/notreached.h"
#include "base/rand_util.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/config_parser.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
#include "components/segmentation_platform/internal/database/config_holder.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/local_state_helper.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/trigger.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace segmentation_platform {
namespace {
FeatureListQueryProcessor;
// File-local integer; by its name, a minimum interval between reports
// expressed in hours (the initializer is not visible in this view).
// NOTE(review): declared as a mutable `static int` even though the `k` prefix
// conventionally marks a constant -- confirm whether it is intentionally
// writable. `static` is also redundant inside the unnamed namespace.
static int kMinimumReportingIntervalInHours = …;
// Maps `last_report_time` to another base::Time and returns it by value.
// NOTE(review): body not visible here; by name this yields the earliest time
// the next report may be sent -- confirm how the interval is applied.
base::Time GetNextReportTime(base::Time last_report_time) { … }
// Takes ownership of `segment_list` and returns a map keyed by SegmentId whose
// values are non-owning `const proto::SegmentInfo*` pointers.
// NOTE(review): body not visible here; the rule used to pick the "preferred"
// entry per segment id is unknown from this view. Because the list is consumed
// by value, confirm that the returned pointers refer to storage that outlives
// this call.
std::map<SegmentId, const proto::SegmentInfo*> GetPreferredSegmentInfo(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_list) { … }
// Predicate over a segment's `info`.
// NOTE(review): body not visible here; by name it reports whether the segment
// uses periodic (as opposed to trigger-based) training data collection --
// confirm which metadata fields decide this.
bool IsPeriodic(const proto::SegmentInfo& info) { … }
// Predicate over `segment_info`.
// NOTE(review): body not visible here; by name it reports whether training
// data for this segment must be computed at the exact prediction time --
// confirm the metadata condition it checks.
bool NeedsExactPredictionTime(const proto::SegmentInfo& segment_info) { … }
// Integer-valued feature parameter. The owning feature, parameter name and
// default value are not visible in this view.
// NOTE(review): presumably a sampling rate applied to time-delay-triggered
// collection -- confirm where it is read.
constexpr base::FeatureParam<int> TimeDelaySamplingRate{ … };
}
// Out-of-line definition of the nested TrainingTimings struct. In this file it
// is returned by ComputeDecisionTiming() and passed by const reference to
// OnGetTrainingTensorsAtDecisionTime() and FillTrainingData().
// NOTE(review): members not visible here; presumably holds the timestamps
// associated with one training request -- confirm.
struct TrainingDataCollectorImpl::TrainingTimings { … };
// Constructs the collector from platform options, a feature list query
// processor, histogram and user-action signal handlers, the storage service,
// profile prefs, a clock and a cached result provider.
// NOTE(review): every collaborator except `platform_options` is a raw pointer;
// presumably non-owning and required to outlive this object -- confirm against
// the header. The member initializer list and body are not visible here.
TrainingDataCollectorImpl::TrainingDataCollectorImpl(
const PlatformOptions& platform_options,
processing::FeatureListQueryProcessor* processor,
HistogramSignalHandler* histogram_signal_handler,
UserActionSignalHandler* user_action_signal_handler,
StorageService* storage_service,
PrefService* profile_prefs,
base::Clock* clock,
CachedResultProvider* cached_result_provider)
: … { … }
// Destructor.
// NOTE(review): body not visible here; given the signal-handler pointers taken
// by the constructor, presumably unregisters any observers -- confirm.
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() { … }
// Parameterless notification hook; by name, invoked after model metadata
// changes.
// NOTE(review): body not visible here -- confirm what state is refreshed.
void TrainingDataCollectorImpl::OnModelMetadataUpdated() { … }
// Parameterless notification hook; by name, invoked once the segmentation
// service has finished initializing.
// NOTE(review): body not visible here; presumably kicks off the segment info
// fetch handled by OnGetSegmentsInfoList() -- confirm.
void TrainingDataCollectorImpl::OnServiceInitialized() { … }
// Receives ownership of a list of segment infos (`segments`).
// NOTE(review): body not visible here; presumably the completion callback of a
// segment info database query -- confirm the caller and what is cached.
void TrainingDataCollectorImpl::OnGetSegmentsInfoList(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segments) { … }
// Notification carrying the name of a histogram and the `sample` recorded for
// it.
// NOTE(review): body not visible here; presumably an observer callback from
// the HistogramSignalHandler passed to the constructor, used to trigger
// collection for segments interested in `histogram_name` -- confirm.
void TrainingDataCollectorImpl::OnHistogramSignalUpdated(
const std::string& histogram_name,
base::HistogramBase::Sample sample) { … }
// Notification carrying a user action name and the tick time at which it
// occurred.
// NOTE(review): body not visible here; presumably an observer callback from
// the UserActionSignalHandler passed to the constructor -- confirm.
void TrainingDataCollectorImpl::OnUserAction(const std::string& user_action,
base::TimeTicks action_time) { … }
// Test-only hook that accepts an unsigned 64-bit `sampling_rate`.
// NOTE(review): body not visible here; presumably overrides the sampling rate
// used when deciding whether to collect -- confirm which member is set.
void TrainingDataCollectorImpl::SetSamplingRateForTesting(
uint64_t sampling_rate) { … }
// Handles a UMA-driven report for one segment. `param` is optional and
// `segment` is passed as a raw, non-owning pointer.
// NOTE(review): body not visible here; confirm whether `segment` may be null
// and how an absent `param` changes the behavior.
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
const std::optional<ImmediateCollectionParam>& param,
const proto::SegmentInfo* segment) { … }
// Returns whether training data may be reported for `segment_info`;
// `include_output` is an additional boolean input to that decision.
// NOTE(review): body not visible here; the exact gating conditions (and any
// metrics recorded on rejection) must be confirmed from the implementation.
bool TrainingDataCollectorImpl::CanReportTrainingData(
const proto::SegmentInfo& segment_info,
bool include_output) { … }
// Receives computed input and output tensors for `segment_info`, together
// with a `has_error` flag and an optional immediate-collection `param`.
// NOTE(review): body not visible here; presumably the completion callback of a
// feature processing request that records the tensors when `has_error` is
// false -- confirm.
void TrainingDataCollectorImpl::OnGetTrainingTensors(
const std::optional<ImmediateCollectionParam>& param,
const proto::SegmentInfo& segment_info,
bool has_error,
const ModelProvider::Request& input_tensors,
const ModelProvider::Response& output_tensors) { … }
// Parameterless entry point; by name, reports training data collected for
// continuously-collecting segments.
// NOTE(review): body not visible here -- confirm which segments are covered
// and how report throttling is applied.
void TrainingDataCollectorImpl::ReportCollectedContinuousTrainingData() { … }
// Requests training data collection for the given segment and training
// request id, associated with `ukm_source_id` and the labels in `param`.
// `callback` is taken by value.
// NOTE(review): body not visible here; presumably `callback` is run with the
// success status once collection completes -- confirm it runs on every path.
void TrainingDataCollectorImpl::CollectTrainingData(
SegmentId segment_id,
TrainingRequestId request_id,
ukm::SourceId ukm_source_id,
const TrainingLabels& param,
SuccessCallback callback) { … }
// Called at a decision point for `segment_id` and returns a TrainingRequestId.
// Takes a ref-counted `input_context`, the decision `type`, optional
// pre-computed `inputs` and a flag indicating whether the call was caused by a
// decision-result update.
// NOTE(review): body not visible here; presumably the returned id identifies
// the training request started by this call -- confirm what is returned when
// no collection is started.
TrainingRequestId TrainingDataCollectorImpl::OnDecisionTime(
proto::SegmentId segment_id,
scoped_refptr<InputContext> input_context,
DecisionType type,
std::optional<ModelProvider::Request> inputs,
bool decision_result_update_trigger) { … }
// Continues decision-time handling once `segment_info` for `segment_id` is
// available; carries the request id, decision type, input context and the
// optional pre-computed `inputs`.
// NOTE(review): body not visible here; presumably either uses `inputs`
// directly or asks the feature processor to compute input tensors -- confirm.
void TrainingDataCollectorImpl::OnGetSegmentInfoAtDecisionTime(
proto::SegmentId segment_id,
TrainingRequestId request_id,
DecisionType type,
scoped_refptr<InputContext> input_context,
const proto::SegmentInfo& segment_info,
std::optional<ModelProvider::Request> inputs) { … }
// Receives the tensors computed for training request `request_id` at decision
// time, along with the request's timings, the segment info and a `has_error`
// flag.
// NOTE(review): body not visible here; presumably stores the inputs and
// schedules the observation step -- confirm, including the `has_error` path.
void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
TrainingRequestId request_id,
const TrainingTimings& training_request,
const proto::SegmentInfo& segment_info,
bool has_error,
const ModelProvider::Request& input_tensors,
const ModelProvider::Response& output_tensors) { … }
// Takes a request id, the segment info, a `delay` and a stats `event`.
// NOTE(review): body not visible here; by name it posts a task that runs the
// observation step after `delay` and records `event` -- confirm the task
// runner used and how object lifetime is protected across the delay.
void TrainingDataCollectorImpl::PostObservationTask(
TrainingRequestId request_id,
const proto::SegmentInfo& segment_info,
const base::TimeDelta& delay,
stats::TrainingDataCollectionEvent event) { … }
// Handles the observation step for training request `request_id` of
// `segment_info`, with an optional immediate-collection `param`. `callback` is
// taken by value.
// NOTE(review): body not visible here; presumably looks up the inputs stored
// at decision time and forwards to OnGetStoredTrainingData() -- confirm.
void TrainingDataCollectorImpl::OnObservationTrigger(
const std::optional<ImmediateCollectionParam>& param,
TrainingRequestId request_id,
const proto::SegmentInfo& segment_info,
SuccessCallback callback) { … }
// Receives previously stored training data as the optional `input`, along with
// the segment info, optional `param` and a `callback` taken by value.
// NOTE(review): body not visible here; confirm how an empty `input` is handled
// and whether `callback` is invoked in that case.
void TrainingDataCollectorImpl::OnGetStoredTrainingData(
const std::optional<ImmediateCollectionParam>& param,
const proto::SegmentInfo& segment_info,
SuccessCallback callback,
std::optional<proto::TrainingData> input) { … }
// Receives tensors computed at observation time. Note that two sets of inputs
// are passed: `cached_input_tensors` and the freshly delivered
// `input_tensors`, plus `output_tensors` and a `has_error` flag.
// NOTE(review): body not visible here; presumably the cached inputs are the
// ones reported alongside the new outputs -- confirm which set is used.
void TrainingDataCollectorImpl::OnGetOutputsOnObservationTrigger(
const std::optional<ImmediateCollectionParam>& param,
const proto::SegmentInfo& segment_info,
const ModelProvider::Request& cached_input_tensors,
bool has_error,
const ModelProvider::Request& input_tensors,
const ModelProvider::Response& output_tensors) { … }
// Const member that builds and returns a TrainingTimings value from the
// segment's `info`.
// NOTE(review): body not visible here; confirm which fields are populated and
// whether the injected clock is the time source.
TrainingDataCollectorImpl::TrainingTimings
TrainingDataCollectorImpl::ComputeDecisionTiming(
const proto::SegmentInfo& info) const { … }
// Const member that returns a base::Time derived from the segment's `info`
// and the given `prediction_time`.
// NOTE(review): body not visible here; by name the result is the observation
// time for a prediction made at `prediction_time` -- confirm the offset used.
base::Time TrainingDataCollectorImpl::ComputeObservationTiming(
const proto::SegmentInfo& info,
base::Time prediction_time) const { … }
// Writes into the out-parameter `training_data` using the request id, the
// request's timings, the input tensors and the segment info; returns a bool.
// NOTE(review): body not visible here; presumably the return value signals
// whether `training_data` was filled successfully -- confirm, and confirm the
// state of `training_data` when false is returned.
bool TrainingDataCollectorImpl::FillTrainingData(
TrainingRequestId request_id,
const TrainingTimings& training_request,
const ModelProvider::Request& input_tensors,
const proto::SegmentInfo& segment_info,
proto::TrainingData& training_data) { … }
}