#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include <optional>
#include "base/containers/contains.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/proto/types.pb.h"
namespace segmentation_platform::test {
// Default constructor. Takes no arguments; member setup happens in the
// initializer list that follows.
TestSegmentInfoDatabase::TestSegmentInfoDatabase()
: … { … }
// Destructor is defaulted here, out of line, so no custom teardown is written
// by hand.
TestSegmentInfoDatabase::~TestSegmentInfoDatabase() = default;
// Initialization entry point of the test database.
//   |callback|: SuccessCallback supplied by the caller.
// NOTE(review): presumably |callback| is run to report the initialization
// result; confirm whether it is invoked synchronously or posted.
void TestSegmentInfoDatabase::Initialize(SuccessCallback callback) { … }
// Callback-based lookup for a set of segments.
//   |segment_ids|: the set of SegmentId values being requested.
//   |callback|:    MultipleSegmentInfoCallback supplied by the caller.
// Returns nothing directly.
// NOTE(review): presumably the matching entries are delivered through
// |callback|; confirm how unknown ids are handled and whether the callback
// runs synchronously.
void TestSegmentInfoDatabase::GetSegmentInfoForSegments(
const base::flat_set<SegmentId>& segment_ids,
MultipleSegmentInfoCallback callback) { … }
// Synchronous lookup for a set of segments.
//   |segment_ids|: the set of SegmentId values being requested.
// Returns a caller-owned SegmentInfoDatabase::SegmentInfoList.
// NOTE(review): the name suggests entries are returned regardless of
// ModelSource; confirm against the implementation, and confirm whether the
// returned pointer can be null.
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
TestSegmentInfoDatabase::GetSegmentInfoForBothModels(
const base::flat_set<SegmentId>& segment_ids) { … }
// Looks up a single entry keyed by segment id and model source.
//   |segment_id|:   the segment to look up.
//   |model_source|: which model source the entry belongs to.
// Returns a non-owning pointer to a const SegmentInfo.
// NOTE(review): presumably returns nullptr when no entry matches; confirm,
// and confirm how long the returned pointer stays valid.
const SegmentInfo* TestSegmentInfoDatabase::GetCachedSegmentInfo(
SegmentId segment_id,
ModelSource model_source) { … }
// Updates the entry identified by |segment_id| and |model_source|.
//   |segment_id|:   the segment being updated.
//   |model_source|: which model source the entry belongs to.
//   |segment_info|: optional replacement proto::SegmentInfo.
//   |callback|:     SuccessCallback supplied by the caller.
// NOTE(review): presumably an empty |segment_info| deletes the entry and a
// populated one replaces it; confirm, along with when |callback| is run.
void TestSegmentInfoDatabase::UpdateSegment(
SegmentId segment_id,
ModelSource model_source,
std::optional<proto::SegmentInfo> segment_info,
SuccessCallback callback) { … }
// Saves a prediction result for the entry identified by |segment_id| and
// |model_source|.
//   |segment_id|:   the segment the result belongs to.
//   |model_source|: which model source the entry belongs to.
//   |result|:       optional proto::PredictionResult.
//   |callback|:     SuccessCallback supplied by the caller.
// NOTE(review): presumably an empty |result| clears any stored result;
// confirm, along with the behavior when the segment does not exist yet.
void TestSegmentInfoDatabase::SaveSegmentResult(
SegmentId segment_id,
ModelSource model_source,
std::optional<proto::PredictionResult> result,
SuccessCallback callback) { … }
// Saves training data for the entry identified by |segment_id| and
// |model_source|.
//   |segment_id|:   the segment the data belongs to.
//   |model_source|: which model source the entry belongs to.
//   |data|:         the proto::TrainingData to store, passed by const reference.
//   |callback|:     SuccessCallback supplied by the caller.
// NOTE(review): presumably |data| is appended to the segment's existing
// training data; confirm whether it appends or replaces.
void TestSegmentInfoDatabase::SaveTrainingData(SegmentId segment_id,
ModelSource model_source,
const proto::TrainingData& data,
SuccessCallback callback) { … }
// Callback-based lookup of training data.
//   |segment_id|:     the segment whose training data is requested.
//   |model_source|:   which model source the entry belongs to.
//   |request_id|:     TrainingRequestId identifying the training data.
//   |delete_from_db|: flag supplied by the caller.
//   |callback|:       TrainingDataCallback supplied by the caller.
// Returns nothing directly.
// NOTE(review): presumably the matching data is delivered through |callback|
// and removed from the stored entry when |delete_from_db| is true; confirm,
// along with what is reported when nothing matches.
void TestSegmentInfoDatabase::GetTrainingData(SegmentId segment_id,
ModelSource model_source,
TrainingRequestId request_id,
bool delete_from_db,
TrainingDataCallback callback) { … }
// Test helper that adds a user-action feature for a segment.
//   |segment_id|:    the segment to modify.
//   |name|:          feature name string.
//   |bucket_count|:  unsigned 64-bit bucket count for the feature.
//   |tensor_length|: unsigned 64-bit tensor length for the feature.
//   |aggregation|:   proto::Aggregation value for the feature.
//   |model_source|:  which model source the entry belongs to.
// NOTE(review): presumably the feature is appended to the segment's model
// metadata, creating the segment if it is missing; confirm. Also confirm how
// |name| is stored (raw or hashed).
void TestSegmentInfoDatabase::AddUserActionFeature(
SegmentId segment_id,
const std::string& name,
uint64_t bucket_count,
uint64_t tensor_length,
proto::Aggregation aggregation,
ModelSource model_source) { … }
// Test helper that adds a histogram-value feature for a segment. Takes the
// same parameter list as AddUserActionFeature().
//   |segment_id|:    the segment to modify.
//   |name|:          feature (histogram) name string.
//   |bucket_count|:  unsigned 64-bit bucket count for the feature.
//   |tensor_length|: unsigned 64-bit tensor length for the feature.
//   |aggregation|:   proto::Aggregation value for the feature.
//   |model_source|:  which model source the entry belongs to.
// NOTE(review): presumably the feature is appended to the segment's model
// metadata, creating the segment if it is missing; confirm.
void TestSegmentInfoDatabase::AddHistogramValueFeature(
SegmentId segment_id,
const std::string& name,
uint64_t bucket_count,
uint64_t tensor_length,
proto::Aggregation aggregation,
ModelSource model_source) { … }
// Test helper that adds a histogram-enum feature for a segment. Takes the
// AddHistogramValueFeature() parameters plus a list of enum ids.
//   |segment_id|:        the segment to modify.
//   |name|:              feature (histogram) name string.
//   |bucket_count|:      unsigned 64-bit bucket count for the feature.
//   |tensor_length|:     unsigned 64-bit tensor length for the feature.
//   |aggregation|:       proto::Aggregation value for the feature.
//   |accepted_enum_ids|: list of 32-bit enum ids, passed by const reference.
//   |model_source|:      which model source the entry belongs to.
// NOTE(review): presumably |accepted_enum_ids| restricts which enum values
// are counted for the feature; confirm, including what an empty list means.
void TestSegmentInfoDatabase::AddHistogramEnumFeature(
SegmentId segment_id,
const std::string& name,
uint64_t bucket_count,
uint64_t tensor_length,
proto::Aggregation aggregation,
const std::vector<int32_t>& accepted_enum_ids,
ModelSource model_source) { … }
// Test helper that adds a SQL feature for a segment.
//   |segment_id|:   the segment to modify.
//   |feature|:      MetadataWriter::SqlFeature description, passed by const
//                   reference.
//   |model_source|: which model source the entry belongs to.
// NOTE(review): presumably delegates to MetadataWriter to append the feature
// to the segment's model metadata; confirm.
void TestSegmentInfoDatabase::AddSqlFeature(
SegmentId segment_id,
const MetadataWriter::SqlFeature& feature,
ModelSource model_source) { … }
// Test helper that adds a prediction result for a segment.
//   |segment_id|:   the segment to modify.
//   |score|:        float score of the prediction.
//   |timestamp|:    base::Time associated with the prediction.
//   |model_source|: which model source the entry belongs to.
// NOTE(review): presumably overwrites any result already stored on the
// segment; confirm, and confirm how |timestamp| is encoded in the proto.
void TestSegmentInfoDatabase::AddPredictionResult(SegmentId segment_id,
float score,
base::Time timestamp,
ModelSource model_source) { … }
// Test helper that adds a discrete mapping for a segment.
//   |segment_id|:           the segment to modify.
//   |mappings|:             raw array whose rows are pairs of floats.
//   |num_pairs|:            int supplied alongside |mappings|.
//   |discrete_mapping_key|: string key for the mapping.
//   |model_source|:         which model source the entry belongs to.
// NOTE(review): presumably |num_pairs| is the number of rows in |mappings|
// and must not exceed the array's real length, since the raw array carries no
// size of its own; confirm. Also confirm the meaning of each element in a
// pair.
void TestSegmentInfoDatabase::AddDiscreteMapping(
SegmentId segment_id,
const float mappings[][2],
int num_pairs,
const std::string& discrete_mapping_key,
ModelSource model_source) { … }
// Test helper that sets the bucket duration for a segment.
//   |segment_id|:      the segment to modify.
//   |bucket_duration|: unsigned 64-bit duration value.
//   |time_unit|:       proto::TimeUnit supplied with |bucket_duration|.
//   |model_source|:    which model source the entry belongs to.
// NOTE(review): presumably |bucket_duration| is expressed in |time_unit| and
// both are written to the segment's model metadata; confirm.
void TestSegmentInfoDatabase::SetBucketDuration(SegmentId segment_id,
uint64_t bucket_duration,
proto::TimeUnit time_unit,
ModelSource model_source) { … }
// Finds the entry identified by |segment_id| and |model_source|.
//   |segment_id|:   the segment to find.
//   |model_source|: which model source the entry belongs to.
// Returns a non-owning, mutable pointer to a proto::SegmentInfo.
// NOTE(review): per the name, a new entry is presumably created when none
// matches, so the result would never be null; confirm. Also confirm whether
// later insertions can invalidate the returned pointer.
proto::SegmentInfo* TestSegmentInfoDatabase::FindOrCreateSegment(
SegmentId segment_id,
ModelSource model_source) { … }
}