chromium/chromeos/ash/components/report/device_metrics/churn/cohort_impl.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/report/device_metrics/churn/cohort_impl.h"

#include "ash/constants/ash_features.h"
#include "chromeos/ash/components/report/device_metrics/churn/active_status.h"
#include "chromeos/ash/components/report/prefs/fresnel_pref_names.h"
#include "chromeos/ash/components/report/utils/device_metadata_utils.h"
#include "chromeos/ash/components/report/utils/network_utils.h"
#include "chromeos/ash/components/report/utils/psm_utils.h"
#include "chromeos/ash/components/report/utils/time_utils.h"
#include "chromeos/ash/components/report/utils/uma_utils.h"
#include "components/prefs/pref_service.h"
#include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
#include "third_party/private_membership/src/private_membership_rlwe.pb.h"

namespace psm_rlwe = private_membership::rlwe;

namespace ash::report::device_metrics {

namespace {

// PSM use case enum for churn monthly cohort use case.
constexpr psm_rlwe::RlweUseCase kPsmUseCase =
    psm_rlwe::RlweUseCase::CROS_FRESNEL_CHURN_MONTHLY_COHORT;

}  // namespace

CohortImpl::CohortImpl(UseCaseParameters* params)
    : UseCase(params),
      active_status_(std::make_unique<ActiveStatus>(params->GetLocalState())) {}

CohortImpl::~CohortImpl() = default;

void CohortImpl::Run(base::OnceCallback<void()> callback) {
  callback_ = std::move(callback);

  if (!IsDevicePingRequired()) {
    utils::RecordIsDevicePingRequired(utils::PsmUseCase::kCohort, false);
    std::move(callback_).Run();
    return;
  }

  utils::RecordIsDevicePingRequired(utils::PsmUseCase::kCohort, true);

  // Perform check membership if the local state pref has default value.
  // This is done to avoid duplicate check in if the device pinged already.
  if (base::FeatureList::IsEnabled(
          features::kDeviceActiveClientChurnCohortCheckMembership) &&
      (GetLastPingTimestamp() == base::Time::UnixEpoch() ||
       GetLastPingTimestamp() == base::Time())) {
    CheckMembershipOprf();
  } else {
    CheckIn();
  }
}

base::WeakPtr<CohortImpl> CohortImpl::GetWeakPtr() {
  return weak_factory_.GetWeakPtr();
}

void CohortImpl::CheckMembershipOprf() {
  PsmClientManager* psm_client_manager = GetParams()->GetPsmClientManager();

  psm_client_manager->SetPsmRlweClient(kPsmUseCase, GetPsmIdentifiersToQuery());

  if (!psm_client_manager->GetPsmRlweClient()) {
    LOG(ERROR) << "Check membership failed since the PSM RLWE client could "
               << "not be initialized.";
    std::move(callback_).Run();
    return;
  }

  // Generate PSM Oprf request body.
  const auto status_or_oprf_request = psm_client_manager->CreateOprfRequest();
  if (!status_or_oprf_request.ok()) {
    LOG(ERROR) << "Failed to create OPRF request.";
    std::move(callback_).Run();
    return;
  }

  psm_rlwe::PrivateMembershipRlweOprfRequest oprf_request =
      status_or_oprf_request.value();

  // Wrap PSM Oprf request body by FresnelPsmRlweOprfRequest proto.
  // This proto is expected by the Fresnel service.
  report::FresnelPsmRlweOprfRequest fresnel_oprf_request;
  *fresnel_oprf_request.mutable_rlwe_oprf_request() = oprf_request;

  std::string request_body;
  fresnel_oprf_request.SerializeToString(&request_body);

  auto resource_request =
      utils::GenerateResourceRequest(utils::GetOprfRequestURL());

  url_loader_ = network::SimpleURLLoader::Create(
      std::move(resource_request), GetCheckMembershipTrafficTag());
  url_loader_->AttachStringForUpload(request_body, "application/x-protobuf");
  url_loader_->SetTimeoutDuration(utils::GetOprfRequestTimeout());
  url_loader_->DownloadToString(
      GetParams()->GetUrlLoaderFactory().get(),
      base::BindOnce(&CohortImpl::OnCheckMembershipOprfComplete,
                     weak_factory_.GetWeakPtr()),
      utils::GetMaxFresnelResponseSizeBytes());
}

void CohortImpl::OnCheckMembershipOprfComplete(
    std::unique_ptr<std::string> response_body) {
  // Use RAII to reset |url_loader_| after current function scope.
  auto url_loader = std::move(url_loader_);

  int net_code = url_loader->NetError();
  utils::RecordNetErrorCode(utils::PsmUseCase::kCohort,
                            utils::PsmRequest::kOprf, net_code);

  // Convert serialized response body to oprf response protobuf.
  FresnelPsmRlweOprfResponse psm_oprf_response;
  bool is_response_body_set = response_body.get() != nullptr;

  if (!is_response_body_set ||
      !psm_oprf_response.ParseFromString(*response_body)) {
    LOG(ERROR) << "Oprf response net code = " << net_code;
    LOG(ERROR) << "Response body was not set or could not be parsed into "
               << "FresnelPsmRlweOprfResponse proto. "
               << "Is response body set = " << is_response_body_set;
    std::move(callback_).Run();
    return;
  }

  if (!psm_oprf_response.has_rlwe_oprf_response()) {
    LOG(ERROR) << "Oprf response net code = " << net_code;
    LOG(ERROR) << "FresnelPsmRlweOprfResponse is missing the actual oprf "
                  "response from server.";
    std::move(callback_).Run();
    return;
  }

  psm_rlwe::PrivateMembershipRlweOprfResponse oprf_response =
      psm_oprf_response.rlwe_oprf_response();

  CheckMembershipQuery(oprf_response);
}

void CohortImpl::CheckMembershipQuery(
    const psm_rlwe::PrivateMembershipRlweOprfResponse& oprf_response) {
  PsmClientManager* psm_client_manager = GetParams()->GetPsmClientManager();

  // Generate PSM Query request body.
  const auto status_or_query_request =
      psm_client_manager->CreateQueryRequest(oprf_response);
  if (!status_or_query_request.ok()) {
    std::move(callback_).Run();
    return;
  }

  psm_rlwe::PrivateMembershipRlweQueryRequest query_request =
      status_or_query_request.value();

  // Wrap PSM Query request body by FresnelPsmRlweQueryRequest proto.
  // This proto is expected by the Fresnel service.
  report::FresnelPsmRlweQueryRequest fresnel_query_request;
  *fresnel_query_request.mutable_rlwe_query_request() = query_request;

  std::string request_body;
  fresnel_query_request.SerializeToString(&request_body);

  auto resource_request =
      utils::GenerateResourceRequest(utils::GetQueryRequestURL());

  url_loader_ = network ::SimpleURLLoader ::Create(
      std::move(resource_request), GetCheckMembershipTrafficTag());
  url_loader_->AttachStringForUpload(request_body, "application/x-protobuf");
  url_loader_->SetTimeoutDuration(utils::GetQueryRequestTimeout());
  url_loader_->DownloadToString(
      GetParams()->GetUrlLoaderFactory().get(),
      base::BindOnce(&CohortImpl::OnCheckMembershipQueryComplete,
                     weak_factory_.GetWeakPtr()),
      utils::GetMaxFresnelResponseSizeBytes());
}

void CohortImpl::OnCheckMembershipQueryComplete(
    std::unique_ptr<std::string> response_body) {
  // Use RAII to reset |url_loader_| after current function scope.
  auto url_loader = std::move(url_loader_);

  int net_code = url_loader->NetError();
  utils::RecordNetErrorCode(utils::PsmUseCase::kCohort,
                            utils::PsmRequest::kQuery, net_code);

  // Convert serialized response body to fresnel query response protobuf.
  FresnelPsmRlweQueryResponse psm_query_response;
  bool is_response_body_set = response_body.get() != nullptr;

  if (!is_response_body_set ||
      !psm_query_response.ParseFromString(*response_body)) {
    LOG(ERROR) << "Query response net code = " << net_code;
    LOG(ERROR) << "Response body was not set or could not be parsed into "
               << "FresnelPsmRlweQueryResponse proto. "
               << "Is response body set = " << is_response_body_set;
    std::move(callback_).Run();
    return;
  }

  if (!psm_query_response.has_rlwe_query_response()) {
    LOG(ERROR) << "Query response net code = " << net_code;
    LOG(ERROR) << "FresnelPsmRlweQueryResponse is missing the actual query "
                  "response from server.";
    std::move(callback_).Run();
    return;
  }

  psm_rlwe::PrivateMembershipRlweQueryResponse query_response =
      psm_query_response.rlwe_query_response();
  auto status_or_response =
      GetParams()->GetPsmClientManager()->ProcessQueryResponse(query_response);

  if (!status_or_response.ok()) {
    LOG(ERROR) << "Failed to process query response.";
    std::move(callback_).Run();
    return;
  }

  psm_rlwe::RlweMembershipResponses rlwe_membership_responses =
      status_or_response.value();

  // TODO: Update logic below here to handle cohort check membership...
  if (rlwe_membership_responses.membership_responses_size() == 0) {
    LOG(ERROR) << "Check Membership for Cohort should query for greater "
               << "than 0 memberships. Size = "
               << rlwe_membership_responses.membership_responses_size();
    std::move(callback_).Run();
    return;
  }

  LOG(ERROR) << "TODO: Implement logic to find last ping Cohort use case. ";
  private_membership::MembershipResponse membership_response =
      rlwe_membership_responses.membership_responses(0).membership_response();

  bool is_psm_id_member = membership_response.is_member();

  if (is_psm_id_member) {
    LOG(ERROR) << "Check In ping was already sent this month.";
    SetLastPingTimestamp(GetParams()->GetActiveTs());
    std::move(callback_).Run();
    return;
  }

  CheckIn();
}

void CohortImpl::CheckIn() {
  std::optional<FresnelImportDataRequest> import_request =
      GenerateImportRequestBody();
  if (!import_request.has_value()) {
    LOG(ERROR) << "Failed to create the import request body.";
    std::move(callback_).Run();
    return;
  }

  std::string request_body;
  import_request.value().SerializeToString(&request_body);

  auto resource_request =
      utils::GenerateResourceRequest(utils::GetImportRequestURL());

  url_loader_ = network::SimpleURLLoader::Create(std::move(resource_request),
                                                 GetCheckInTrafficTag());
  url_loader_->AttachStringForUpload(request_body, "application/x-protobuf");
  url_loader_->SetTimeoutDuration(utils::GetImportRequestTimeout());
  url_loader_->DownloadToString(GetParams()->GetUrlLoaderFactory().get(),
                                base::BindOnce(&CohortImpl::OnCheckInComplete,
                                               weak_factory_.GetWeakPtr()),
                                utils::GetMaxFresnelResponseSizeBytes());
}

void CohortImpl::OnCheckInComplete(std::unique_ptr<std::string> response_body) {
  // Use RAII to reset |url_loader_| after current function scope.
  auto url_loader = std::move(url_loader_);

  int net_code = url_loader->NetError();
  utils::RecordNetErrorCode(utils::PsmUseCase::kCohort,
                            utils::PsmRequest::kImport, net_code);

  if (net_code == net::OK) {
    UpdateLocalStateOnCheckInSuccess();
  } else {
    LOG(ERROR) << "Failed to check in successfully. Net code = " << net_code;
  }

  // Check-in completed - use case is done running.
  std::move(callback_).Run();
}

base::Time CohortImpl::GetLastPingTimestamp() {
  return GetParams()->GetLocalState()->GetTime(
      ash::report::prefs::kDeviceActiveChurnCohortMonthlyPingTimestamp);
}

void CohortImpl::SetLastPingTimestamp(base::Time ts) {
  GetParams()->GetLocalState()->SetTime(
      ash::report::prefs::kDeviceActiveChurnCohortMonthlyPingTimestamp, ts);
}

std::vector<psm_rlwe::RlwePlaintextId> CohortImpl::GetPsmIdentifiersToQuery() {
  // TODO: implement methods to generate PSM id.
  std::vector<psm_rlwe::RlwePlaintextId> query_psm_ids = {};
  return query_psm_ids;
}

std::optional<FresnelImportDataRequest>
CohortImpl::GenerateImportRequestBody() {
  FresnelImportDataRequest import_request;
  import_request.set_use_case(kPsmUseCase);

  // Certain metadata is passed by chrome, since it's not available in ash.
  version_info::Channel version_channel =
      GetParams()->GetChromeDeviceParams().chrome_channel;
  ash::report::MarketSegment market_segment =
      GetParams()->GetChromeDeviceParams().market_segment;

  DeviceMetadata* device_metadata = import_request.mutable_device_metadata();
  device_metadata->set_chrome_milestone(utils::GetChromeMilestone());
  device_metadata->set_hardware_id(utils::GetFullHardwareClass());
  device_metadata->set_chromeos_channel(
      utils::GetChromeChannel(version_channel));
  device_metadata->set_market_segment(market_segment);

  std::string window_id = utils::TimeToYYYYMMString(GetParams()->GetActiveTs());
  std::optional<psm_rlwe::RlwePlaintextId> psm_id =
      utils::GeneratePsmIdentifier(GetParams()->GetHighEntropySeed(),
                                   psm_rlwe::RlweUseCase_Name(kPsmUseCase),
                                   window_id);

  if (window_id.empty() || !psm_id.has_value()) {
    LOG(ERROR) << "Window ID or Psm ID is empty.";
    return std::nullopt;
  }

  FresnelImportData* import_data = import_request.add_import_data();
  import_data->set_window_identifier(window_id);
  import_data->set_plaintext_id(psm_id.value().sensitive_id());
  import_data->set_is_pt_window_identifier(true);

  ChurnCohortMetadata* cohort_metadata =
      import_data->mutable_churn_cohort_metadata();
  std::optional<ChurnCohortMetadata> new_cohort_metadata =
      active_status_->CalculateCohortMetadata(GetParams()->GetActiveTs());
  if (!new_cohort_metadata.has_value()) {
    LOG(ERROR) << "Failed to calculate new cohort metadata.";
    return std::nullopt;
  }

  base::Time active_ts = GetParams()->GetActiveTs();
  std::optional<base::Time> first_active_week_ts = utils::GetFirstActiveWeek();

  if (!first_active_week_ts.has_value() ||
      first_active_week_ts.value() == base::Time() ||
      first_active_week_ts.value() == base::Time::UnixEpoch()) {
    LOG(ERROR) << "Failed to retrieve first active week from VPD. "
                  "Setting first active week to UNKNOWN.";
    cohort_metadata->set_first_active_week("UNKNOWN");
  } else {
    int max_days_in_5_weeks = 7 * 5;
    bool within_date_range = utils::IsFirstActiveUnderNDaysAgo(
        active_ts, first_active_week_ts.value(), max_days_in_5_weeks);

    // Privacy approved 5 weeks of first active week history in cohort ping.
    // In order for analysts to avoid double counting on the server-side.
    if (within_date_range) {
      cohort_metadata->set_first_active_week(
          utils::ConvertTimeToISO8601String(first_active_week_ts.value()));
    }
  }

  *cohort_metadata = new_cohort_metadata.value();
  return import_request;
}

void CohortImpl::UpdateLocalStateOnCheckInSuccess() {
  // Check the new cohort active status value is valid.
  std::optional<int> new_active_val =
      active_status_->CalculateNewValue(GetParams()->GetActiveTs());
  if (!new_active_val.has_value()) {
    LOG(ERROR)
        << "Failed to update active status value after successful cohort "
           "check in.";
    return;
  }

  PrefService* local_state = GetParams()->GetLocalState();

  // Read the current relative observation period reported status.
  bool is_active_current_period_minus_0 = local_state->GetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus0);
  bool is_active_current_period_minus_1 = local_state->GetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus1);
  bool is_active_current_period_minus_2 = local_state->GetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus2);

  int active_val = active_status_->GetValue();

  // Shift observation periods relative to the last reported cohort month.
  // NOTE: The observation periods are relative to the last cohort ping month.
  // Shift occurs after a successful churn cohort ping in a new month to
  // account for the relative observation report statuses.
  for (int i = 0; i < 3; i++) {
    if (active_val == new_active_val.value()) {
      break;
    }

    is_active_current_period_minus_2 = is_active_current_period_minus_1;
    is_active_current_period_minus_1 = is_active_current_period_minus_0;
    is_active_current_period_minus_0 = false;

    active_val += 1;
  }

  // Update the local state keys after determining all new values.
  local_state->SetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus0,
      is_active_current_period_minus_0);
  local_state->SetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus1,
      is_active_current_period_minus_1);
  local_state->SetBoolean(
      prefs::kDeviceActiveLastKnownIsActiveCurrentPeriodMinus2,
      is_active_current_period_minus_2);
  active_status_->SetValue(new_active_val.value());
  SetLastPingTimestamp(GetParams()->GetActiveTs());
}

bool CohortImpl::IsDevicePingRequired() {
  base::Time last_ping_ts = GetLastPingTimestamp();
  base::Time cur_ping_ts = GetParams()->GetActiveTs();

  // Safety check to avoid against clock drift, or unexpected timestamps.
  // Check should make sure that we are not reporting window id's for
  // a month previous to one that we reported already.
  if (last_ping_ts >= cur_ping_ts) {
    return false;
  }

  return utils::TimeToYYYYMMString(last_ping_ts) !=
         utils::TimeToYYYYMMString(cur_ping_ts);
}

}  // namespace ash::report::device_metrics