// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#import "ios/chrome/browser/segmentation_platform/model/ukm_data_manager_test_utils.h"
#import "base/run_loop.h"
#import "components/history/core/browser/history_service.h"
#import "components/keyed_service/core/service_access_type.h"
#import "components/segmentation_platform/embedder/model_provider_factory_impl.h"
#import "components/segmentation_platform/internal/database/ukm_database.h"
#import "components/segmentation_platform/internal/execution/mock_model_provider.h"
#import "components/segmentation_platform/internal/metadata/metadata_writer.h"
#import "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
#import "components/segmentation_platform/internal/signals/ukm_observer.h"
#import "components/segmentation_platform/internal/ukm_data_manager.h"
#import "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#import "ios/chrome/browser/history/model/history_service_factory.h"
#import "ios/chrome/browser/segmentation_platform/model/segmentation_platform_service_factory.h"
#import "ios/chrome/browser/segmentation_platform/model/ukm_database_client.h"
#import "services/metrics/public/cpp/ukm_builders.h"
#import "testing/gtest/include/gtest/gtest.h"
#import "url/gurl.h"
namespace segmentation_platform {
namespace {
using ::segmentation_platform::proto::SegmentId;
using ::testing::Return;
using ::ukm::builders::PageLoad;
// Builds a PageLoad UKM entry attributed to `source_id`. The entry carries
// fixed values for three PageLoad metrics (CpuTime = 10, IsNewBookmark = 20,
// IsNTPCustomLink = 30).
ukm::mojom::UkmEntryPtr GetSamplePageLoadEntry(ukm::SourceId source_id) {
  auto page_load = ukm::mojom::UkmEntry::New();
  page_load->source_id = source_id;
  page_load->event_hash = PageLoad::kEntryNameHash;

  // Populate the metric map in one place through a short alias.
  auto& metrics = page_load->metrics;
  metrics[PageLoad::kCpuTimeNameHash] = 10;
  metrics[PageLoad::kIsNewBookmarkNameHash] = 20;
  metrics[PageLoad::kIsNTPCustomLinkNameHash] = 30;
  return page_load;
}
// Executes `query` against `database` as a single read-only query (registered
// under index 0) and blocks on a RunLoop until the database replies. Returns
// the first float of the first result tensor when the database reports
// success, and std::nullopt otherwise. See RunReadOnlyQueries() for more info.
std::optional<float> RunQueryAndGetResult(UkmDatabase* database,
                                          UkmDatabase::CustomSqlQuery&& query) {
  std::optional<float> result;
  base::RunLoop run_loop;

  // Reply handler: stores the single expected value on success and always
  // releases the waiting run loop.
  auto on_query_done = [](base::OnceClosure done, std::optional<float>* out,
                          bool success, processing::IndexedTensors tensors) {
    if (success) {
      EXPECT_EQ(1u, tensors.size());
      EXPECT_EQ(1u, tensors.at(0).size());
      *out = tensors.at(0)[0].float_val;
    }
    std::move(done).Run();
  };

  UkmDatabase::QueryList query_list;
  query_list.emplace(0, std::move(query));
  database->RunReadOnlyQueries(
      std::move(query_list),
      base::BindOnce(on_query_done, run_loop.QuitClosure(), &result));

  // `result` outlives the callback because this call does not return until
  // the quit closure has run.
  run_loop.Run();
  return result;
}
} // namespace
// Stores `ukm_recorder` and selects the UkmDatabaseClient to drive:
//  - `owned_db_client` == true: a fresh client is created and owned by this
//    object.
//  - `owned_db_client` == false: the instance returned by
//    UkmDatabaseClientHolder::GetClientInstance(nullptr) is used.
UkmDataManagerTestUtils::UkmDataManagerTestUtils(
    ukm::TestUkmRecorder* ukm_recorder,
    bool owned_db_client)
    : ukm_recorder_(ukm_recorder) {
  if (!owned_db_client) {
    ukm_database_client_ = &UkmDatabaseClientHolder::GetClientInstance(nullptr);
    return;
  }

  owned_db_client_ = std::make_unique<UkmDatabaseClient>();
  ukm_database_client_ = owned_db_client_.get();
}
// Tears down the database client (except on Android) and drops the pointer to
// it.
UkmDataManagerTestUtils::~UkmDataManagerTestUtils() {
#if BUILDFLAG(IS_ANDROID)
  // On Android browser tests the profile is never destroyed, so the client is
  // deliberately left untouched.
#else
  // The client should be torn down after the profile is destroyed.
  ukm_database_client_->TearDownForTesting();
#endif
  // Clear the pointer before the remaining members are destroyed.
  ukm_database_client_ = nullptr;
}
// Prepares the test environment before any profile is created:
//  1. installs the test UKM recorder on the database client,
//  2. registers a MockDefaultModelProvider for every entry of
//     `default_overrides` (segment id -> metadata) and remembers an unowned
//     pointer to each mock for GetDefaultOverride(),
//  3. when this object owns the client, runs its PreProfileInit() with an
//     in-memory database.
void UkmDataManagerTestUtils::PreProfileInit(
    const std::map<SegmentId, proto::SegmentationModelMetadata>&
        default_overrides) {
  // The test recorder has to be in place before the UkmObserver is created.
  ukm_database_client_->set_ukm_recorder_for_testing(ukm_recorder_);

  for (const auto& [segment_id, metadata] : default_overrides) {
    auto mock_provider =
        std::make_unique<MockDefaultModelProvider>(segment_id, metadata);
    // Keep an unowned handle; ownership moves to the override registry below.
    MockDefaultModelProvider* unowned_provider = mock_provider.get();
    default_overrides_[segment_id] = unowned_provider;
    // Default model must be overridden before the platform is created.
    TestDefaultModelOverride::GetInstance().SetModelForTesting(
        segment_id, std::move(mock_provider));
  }

  if (owned_db_client_) {
    owned_db_client_->PreProfileInit(/*in_memory_database=*/true);
  }
}
// Wires `profile` to the database client used by this helper, caches the
// profile's HistoryService, and instantiates the segmentation platform
// service for the profile.
void UkmDataManagerTestUtils::SetupForProfile(ChromeBrowserState* profile) {
  auto* const client = ukm_database_client_.get();
  UkmDatabaseClientHolder::SetUkmClientForTesting(profile, client);
  // Sanity check: the holder must now hand back the client set above.
  CHECK_EQ(client, &UkmDatabaseClientHolder::GetClientInstance(profile));

  history_service_ = ios::HistoryServiceFactory::GetForBrowserState(
      profile, ServiceAccessType::EXPLICIT_ACCESS);

  // Fetching the service creates the platform, which kicks off its
  // initialization; the returned pointer itself is not needed here.
  SegmentationPlatformServiceFactory::GetForBrowserState(profile);
}
// Hands a null client to UkmDatabaseClientHolder for `profile`.
// NOTE(review): presumably this removes the override installed by
// SetupForProfile() so that nothing references the client after the profile
// goes away - confirm against UkmDatabaseClientHolder.
void UkmDataManagerTestUtils::WillDestroyProfile(ChromeBrowserState* profile) {
  UkmDatabaseClientHolder::SetUkmClientForTesting(profile, nullptr);
}
void UkmDataManagerTestUtils::WaitForUkmObserverRegistration() {
UkmObserver* observer = ukm_database_client_->ukm_observer_for_testing();
while (!observer->is_started_for_testing()) {
base::RunLoop().RunUntilIdle();
}
}
// Builds model metadata for a binary classifier (threshold 0.5, labels
// "Show" / "NotShow") with a DAY time unit and a bucket duration of 42. Its
// single input feature is an SQL feature running `query`, whose signal filter
// is the PageLoad UKM event restricted to the CpuTime and IsNewBookmark
// metrics.
proto::SegmentationModelMetadata
UkmDataManagerTestUtils::GetSamplePageLoadMetadata(const std::string& query) {
  proto::SegmentationModelMetadata model_metadata;

  MetadataWriter metadata_writer(&model_metadata);
  metadata_writer.AddOutputConfigForBinaryClassifier(
      /*threshold=*/0.5f,
      /*positive_label=*/"Show",
      /*negative_label=*/"NotShow");
  model_metadata.set_time_unit(proto::TimeUnit::DAY);
  model_metadata.set_bucket_duration(42u);

  // Single SQL input feature backed by `query`.
  auto* sql = model_metadata.add_input_features()->mutable_sql_feature();
  sql->set_sql(query);

  // Only the PageLoad event, and only these two of its metrics, feed the
  // query.
  auto* page_load_filter = sql->mutable_signal_filter()->add_ukm_events();
  page_load_filter->set_event_hash(PageLoad::kEntryNameHash);
  for (const auto metric_hash :
       {PageLoad::kCpuTimeNameHash, PageLoad::kIsNewBookmarkNameHash}) {
    page_load_filter->add_metric_hash_filter(metric_hash);
  }
  return model_metadata;
}
void UkmDataManagerTestUtils::RecordPageLoadUkm(const GURL& url,
base::Time history_timestamp) {
UkmObserver* observer = ukm_database_client_->ukm_observer_for_testing();
// Ensure that the observer is started before recording metrics.
ASSERT_TRUE(observer->is_started_for_testing());
// Ensure that OTR profiles are not started in the test.
ASSERT_FALSE(observer->is_paused_for_testing());
ukm_recorder_->AddEntry(GetSamplePageLoadEntry(source_id_counter_));
ukm_recorder_->UpdateSourceURL(source_id_counter_, url);
source_id_counter_++;
// Without a history service the recorded URLs will not be written to
// database.
ASSERT_TRUE(history_service_);
history_service_->AddPage(url, history_timestamp,
history::VisitSource::SOURCE_BROWSED);
}
bool UkmDataManagerTestUtils::IsUrlInDatabase(const GURL& url) {
UkmDatabase::CustomSqlQuery query("SELECT 1 FROM urls WHERE url=?",
{processing::ProcessedValue(url.spec())});
std::optional<float> result = RunQueryAndGetResult(
ukm_database_client_->GetUkmDataManager()->GetUkmDatabase(),
std::move(query));
return !!result;
}
// Returns the mock default model provider registered for `segment_id` by
// PreProfileInit(), or nullptr when no override exists for that segment.
//
// Uses find() rather than operator[] so that looking up an unknown segment
// does not insert a null entry into `default_overrides_`.
MockDefaultModelProvider* UkmDataManagerTestUtils::GetDefaultOverride(
    proto::SegmentId segment_id) {
  auto it = default_overrides_.find(segment_id);
  if (it == default_overrides_.end()) {
    return nullptr;
  }
  return it->second;
}
} // namespace segmentation_platform