chromium/chrome/browser/ash/policy/enrollment/auto_enrollment_client_impl.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/policy/enrollment/auto_enrollment_client_impl.h"

#include <stdint.h>

#include <memory>
#include <optional>
#include <string>

#include "ash/constants/ash_switches.h"
#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/time/time.h"
#include "base/uuid.h"
#include "base/values.h"
#include "chrome/browser/ash/policy/enrollment/auto_enrollment_state.h"
#include "chrome/browser/ash/policy/enrollment/auto_enrollment_state_message_processor.h"
#include "chrome/browser/ash/policy/enrollment/enrollment_token_provider.h"
#include "chrome/browser/ash/policy/enrollment/psm/rlwe_dmserver_client.h"
#include "chrome/browser/ash/policy/server_backed_state/server_backed_device_state.h"
#include "chrome/common/pref_names.h"
#include "components/policy/core/common/cloud/device_management_service.h"
#include "components/policy/core/common/cloud/dm_auth.h"
#include "components/policy/core/common/cloud/dmserver_job_configurations.h"
#include "components/policy/core/common/cloud/enterprise_metrics.h"
#include "components/policy/proto/device_management_backend.pb.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/pref_service.h"
#include "components/prefs/scoped_user_pref_update.h"
#include "crypto/sha2.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "third_party/abseil-cpp/absl/types/variant.h"

namespace policy {

namespace {

namespace em = ::enterprise_management;
namespace psm_rlwe = ::private_membership::rlwe;
using EnrollmentCheckType =
    em::DeviceAutoEnrollmentRequest::EnrollmentCheckType;

// Finds the smallest exponent `p` in [0, kMaximumPower] such that 2^p is
// greater than or equal to |value|. If no exponent in that range qualifies,
// kMaximumPower + 1 is returned as an out-of-range marker.
int NextPowerOf2(int64_t value) {
  int power = 0;
  // The bound is tested first so the shift never exceeds kMaximumPower.
  while (power <= AutoEnrollmentClient::kMaximumPower &&
         (INT64_C(1) << power) < value) {
    ++power;
  }
  return power;
}

// Provides device identifier for Forced Re-Enrollment (FRE), where the
// server-backed state key is used. It will set the identifier for the
// DeviceAutoEnrollmentRequest.
class DeviceIdentifierProviderFRE {
 public:
  explicit DeviceIdentifierProviderFRE(
      const std::string& server_backed_state_key) {
    CHECK(!server_backed_state_key.empty());
    server_backed_state_key_hash_ =
        crypto::SHA256HashString(server_backed_state_key);
  }

  DeviceIdentifierProviderFRE(const DeviceIdentifierProviderFRE&) = delete;
  DeviceIdentifierProviderFRE& operator=(const DeviceIdentifierProviderFRE&) =
      delete;

  ~DeviceIdentifierProviderFRE() = default;

  // Should return the `EnrollmentCheckType` to be used in the
  // DeviceAutoEnrollmentRequest. This specifies the identifier set used on
  // the server.
  em::DeviceAutoEnrollmentRequest::EnrollmentCheckType GetEnrollmentCheckType()
      const {
    return em::DeviceAutoEnrollmentRequest::ENROLLMENT_CHECK_TYPE_FRE;
  }

  // Should return the hash of this device's identifier. The
  // DeviceAutoEnrollmentRequest exchange will check if this hash is in the
  // server-side identifier set specified by `GetEnrollmentCheckType()`
  const std::string& GetIdHash() const { return server_backed_state_key_hash_; }

 private:
  // SHA-256 digest of the stable identifier.
  std::string server_backed_state_key_hash_;
};

}  // namespace

// Non-error outcomes of a server state availability request, as reported by
// the requesters in this file through their completion callbacks.
enum class AutoEnrollmentClientImpl::ServerStateAvailabilitySuccess {
  // Indicates that request has been successful and server state availability is
  // known.
  kSuccess,
  // Special case for server state availability result via auto enrollment
  // request.
  // Indicates that request shall be immediately retried.
  kRetry,
};

// Base class to handle server state availability requests. Concrete
// requesters in this file implement it for forced re-enrollment, initial
// enrollment and token-based enrollment.
class AutoEnrollmentClientImpl::ServerStateAvailabilityRequester {
 public:
  // Invoked once with the outcome of the request started by `Start()`.
  using CompletionCallback =
      base::OnceCallback<void(ServerStateAvailabilityResult)>;

  // Virtual so that implementations can be destroyed through a base pointer.
  virtual ~ServerStateAvailabilityRequester() = default;

  // Initiates request and reports back with `callback` once request is
  // finished.
  virtual void Start(CompletionCallback callback) = 0;

  // Returns:
  // * nullopt if server state is not obtained yet,
  // * false if server state has been obtained and the answer is: it is not
  // available.
  // * true if server state has been obtained and the answer is: it is
  // available.
  virtual std::optional<bool> GetServerStateIfObtained() const = 0;
};

// Responsible for resolving server state availability status via auto
// enrollment requests for force re-enrollment.
//
// Each request carries a `modulus` of 2^current_power_ and a `remainder` made
// of the trailing `current_power_` bits of the hashed device identifier. The
// server answers either with a list of hashes to compare the identifier
// against, or with a different modulus to retry with. A final decision is
// cached in local state.
class AutoEnrollmentClientImpl::FREServerStateAvailabilityRequester
    : public ServerStateAvailabilityRequester {
 public:
  // Registers the local state prefs used to cache the decision and the power
  // limit it was obtained with.
  static void RegisterPrefs(PrefRegistrySimple* registry) {
    registry->RegisterBooleanPref(prefs::kShouldAutoEnroll, false);
    registry->RegisterIntegerPref(prefs::kAutoEnrollmentPowerLimit, -1);
  }

  // `current_power` is the exponent of the modulus used for the first
  // request; `power_limit` is the largest exponent accepted from the server.
  // `current_power` must not exceed `power_limit`, and
  // `server_backed_state_key` must not be empty (checked by
  // DeviceIdentifierProviderFRE).
  FREServerStateAvailabilityRequester(
      DeviceManagementService* device_management_service,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      PrefService* local_state,
      const std::string& device_id,
      const std::string& uma_suffix,
      int current_power,
      int power_limit,
      const std::string& server_backed_state_key)
      : device_management_service_(device_management_service),
        url_loader_factory_(url_loader_factory),
        local_state_(local_state),
        device_id_(device_id),
        uma_suffix_(uma_suffix),
        current_power_(current_power),
        power_limit_(power_limit),
        device_identifier_provider_fre_(server_backed_state_key) {
    DCHECK_LE(current_power_, power_limit_);
  }

  FREServerStateAvailabilityRequester(
      const FREServerStateAvailabilityRequester&) = delete;
  FREServerStateAvailabilityRequester& operator=(
      const FREServerStateAvailabilityRequester&) = delete;

  // ServerStateAvailabilityRequester:
  void Start(CompletionCallback callback) override {
    StartImpl(std::move(callback));
  }

  // Returns the decision cached in local state. Yields nullopt when either
  // pref is missing or still at its default value, or when the cached decision
  // was made with a smaller power limit than the one this instance uses.
  std::optional<bool> GetServerStateIfObtained() const override {
    const PrefService::Preference* has_server_state_pref =
        local_state_->FindPreference(prefs::kShouldAutoEnroll);
    const PrefService::Preference* previous_limit_pref =
        local_state_->FindPreference(prefs::kAutoEnrollmentPowerLimit);

    if (!has_server_state_pref || has_server_state_pref->IsDefaultValue() ||
        !previous_limit_pref || previous_limit_pref->IsDefaultValue()) {
      return std::nullopt;
    }

    DCHECK(has_server_state_pref->GetValue()->is_bool());
    DCHECK(previous_limit_pref->GetValue()->is_int());

    if (power_limit_ > previous_limit_pref->GetValue()->GetInt()) {
      return std::nullopt;
    }

    return has_server_state_pref->GetValue()->GetBool();
  }

 private:
  // Builds and sends one auto enrollment request for the current modulus.
  // `callback` is stored and run from `HandleRequestCompletion()`.
  void StartImpl(CompletionCallback callback) {
    DCHECK(!request_job_);
    DCHECK(callback);
    DCHECK(!completion_callback_);

    completion_callback_ = std::move(callback);

    // Start the Hash dance timer during the first attempt.
    if (hash_dance_time_start_.is_null())
      hash_dance_time_start_ = base::TimeTicks::Now();

    // Bound by reference: the provider owns the hash and outlives this call.
    const std::string& id_hash = device_identifier_provider_fre_.GetIdHash();
    // Currently AutoEnrollmentClientImpl supports working with hashes that are
    // at least 8 bytes long. If this is reduced, the computation of the
    // remainder must also be adapted to handle the case of a shorter hash
    // gracefully.
    DCHECK_GE(id_hash.size(), 8u);

    // Assemble the remainder from the trailing bytes of the hash (the last
    // byte is the least significant one), then keep only the lowest
    // `current_power_` bits.
    uint64_t remainder = 0;
    const size_t last_byte_index = id_hash.size() - 1;
    for (int i = 0; 8 * i < current_power_; ++i) {
      uint64_t byte = id_hash[last_byte_index - i] & 0xff;
      remainder = remainder | (byte << (8 * i));
    }
    remainder = remainder & ((UINT64_C(1) << current_power_) - 1);

    // Record the time when the bucket download request is started. Note that
    // the time may be set multiple times. This is fine, only the last request
    // is the one where the hash bucket is actually downloaded.
    time_start_bucket_download_ = base::TimeTicks::Now();

    // TODO(crbug.com/40805389): Logging as "WARNING" to make sure it's
    // preserved in the logs.
    LOG(WARNING) << "Request bucket #" << remainder;

    // base::Unretained is used because the job is owned by `request_job_`,
    // a member of this object.
    std::unique_ptr<DMServerJobConfiguration> config =
        std::make_unique<DMServerJobConfiguration>(
            device_management_service_,
            DeviceManagementService::JobConfiguration::TYPE_AUTO_ENROLLMENT,
            device_id_,
            /*critical=*/false, DMAuth::NoAuth(),
            /*oauth_token=*/std::nullopt, url_loader_factory_,
            base::BindOnce(
                &FREServerStateAvailabilityRequester::HandleRequestCompletion,
                base::Unretained(this)));

    em::DeviceAutoEnrollmentRequest* request =
        config->request()->mutable_auto_enrollment_request();
    request->set_remainder(remainder);
    request->set_modulus(INT64_C(1) << current_power_);
    request->set_enrollment_check_type(
        device_identifier_provider_fre_.GetEnrollmentCheckType());

    request_job_ = device_management_service_->CreateJob(std::move(config));
  }

  // Processes the DMServer reply to the request issued by `StartImpl()` and
  // runs the stored completion callback exactly once with one of:
  // * an error (request failure, missing response, unacceptable modulus),
  // * kRetry, when the server asked for a different, acceptable modulus,
  // * kSuccess, after the decision has been written to local state.
  void HandleRequestCompletion(DMServerJobResult result) {
    DCHECK(request_job_);
    DCHECK(completion_callback_);

    request_job_.reset();

    base::UmaHistogramSparse(kUMAHashDanceRequestStatus + uma_suffix_,
                             result.dm_status);

    if (result.dm_status != DM_STATUS_SUCCESS) {
      LOG(ERROR) << "Auto enrollment error: " << result.dm_status;

      const auto error =
          AutoEnrollmentDMServerError::FromDMServerJobResult(result);
      if (error.network_error.has_value()) {
        base::UmaHistogramSparse(kUMAHashDanceNetworkErrorCode + uma_suffix_,
                                 -error.network_error.value());
      }

      return RunCallback(base::unexpected(error));
    }

    ServerStateAvailabilityResult availability_result =
        ServerStateAvailabilitySuccess::kSuccess;
    const em::DeviceAutoEnrollmentResponse& enrollment_response =
        result.response.auto_enrollment_response();
    if (!result.response.has_auto_enrollment_response()) {
      LOG(ERROR) << "Server failed to provide auto-enrollment response.";
      availability_result =
          base::unexpected(AutoEnrollmentStateAvailabilityResponseError{});
    } else if (enrollment_response.has_expected_modulus()) {
      // Server is asking us to retry with a different modulus.
      modulus_updates_received_++;

      int64_t modulus = enrollment_response.expected_modulus();
      int power = NextPowerOf2(modulus);
      if ((INT64_C(1) << power) != modulus) {
        // Only logged, not treated as a failure on its own: the rounded-up
        // power goes through the checks below and, if accepted, is the one
        // used for the retry.
        LOG(ERROR) << "Auto enrollment: the server didn't ask for a power-of-2 "
                   << "modulus. Using the closest power-of-2 instead "
                   << "(" << modulus << " vs 2^" << power << ")";
      }
      if (modulus_updates_received_ >= 2) {
        LOG(ERROR) << "Auto enrollment error: already retried with an updated "
                   << "modulus but the server asked for a new one again: "
                   << power;
        availability_result =
            base::unexpected(AutoEnrollmentStateAvailabilityResponseError{});
      } else if (power > power_limit_) {
        LOG(ERROR) << "Auto enrollment error: the server asked for a larger "
                   << "modulus than the client accepts (" << power << " vs "
                   << power_limit_ << ").";
        availability_result =
            base::unexpected(AutoEnrollmentStateAvailabilityResponseError{});
      } else {
        // Retry at most once with the modulus that the server requested.
        if (power <= current_power_) {
          LOG(WARNING) << "Auto enrollment: the server asked to use a modulus ("
                       << power << ") that isn't larger than the first used ("
                       << current_power_ << "). Retrying anyway.";
        }
        // Remember this value, so that eventual retries start with the correct
        // modulus.
        current_power_ = power;
        DCHECK(!GetServerStateIfObtained());
        RunCallback(ServerStateAvailabilitySuccess::kRetry);
        return;
      }
    } else {
      // Server should have sent down a list of hashes to try.
      const bool has_server_state =
          IsIdHashInProtobuf(enrollment_response.hashes());
      // Cache the current decision in local_state, so that it is reused in case
      // the device reboots before enrolling.
      local_state_->SetBoolean(prefs::kShouldAutoEnroll, has_server_state);
      local_state_->SetInteger(prefs::kAutoEnrollmentPowerLimit, power_limit_);
      local_state_->CommitPendingWrite();

      // TODO(crbug.com/40805389): Logging as "WARNING" to make sure it's
      // preserved in the logs.
      LOG(WARNING) << "Received has_state=" << has_server_state;

      availability_result = ServerStateAvailabilitySuccess::kSuccess;
      RecordHashDanceSuccessTimeHistogram();
    }

    // Sanity check: a successful result must come with a cached decision and a
    // failed one without.
    const bool succeeded_with_result =
        availability_result == ServerStateAvailabilitySuccess::kSuccess &&
        GetServerStateIfObtained();
    const bool failed_without_result =
        availability_result != ServerStateAvailabilitySuccess::kSuccess &&
        !GetServerStateIfObtained();
    DCHECK(succeeded_with_result || failed_without_result);

    // Bucket download done, update UMA.
    UpdateBucketDownloadTimingHistograms();
    RunCallback(availability_result);
  }

  // Hands `availability_result` to the stored completion callback, consuming
  // the callback.
  void RunCallback(ServerStateAvailabilityResult availability_result) {
    DCHECK(completion_callback_);
    std::move(completion_callback_).Run(availability_result);
  }

  // Returns true if this device's identifier hash is one of `hashes`.
  bool IsIdHashInProtobuf(
      const google::protobuf::RepeatedPtrField<std::string>& hashes) const {
    const std::string& id_hash = device_identifier_provider_fre_.GetIdHash();
    for (const std::string& candidate : hashes) {
      if (candidate == id_hash)
        return true;
    }
    return false;
  }

  // Reports the time elapsed since the first request and since the most
  // recent request was started.
  void UpdateBucketDownloadTimingHistograms() const {
    // These values determine bucketing of the histogram, they should not be
    // changed.
    // The minimum time can't be 0, must be at least 1.
    static const base::TimeDelta kMin = base::Milliseconds(1);
    static const base::TimeDelta kMax = base::Minutes(5);
    static const int kBuckets = 50;

    base::TimeTicks now = base::TimeTicks::Now();
    if (!hash_dance_time_start_.is_null()) {
      base::TimeDelta delta = now - hash_dance_time_start_;
      base::UmaHistogramCustomTimes(kUMAHashDanceProtocolTime + uma_suffix_,
                                    delta, kMin, kMax, kBuckets);
    }
    if (!time_start_bucket_download_.is_null()) {
      base::TimeDelta delta = now - time_start_bucket_download_;
      base::UmaHistogramCustomTimes(
          kUMAHashDanceBucketDownloadTime + uma_suffix_, delta, kMin, kMax,
          kBuckets);
    }
  }

  // Reports the time elapsed since the first request. Called only when the
  // server returned a list of hashes and the decision was stored.
  void RecordHashDanceSuccessTimeHistogram() const {
    // These values determine bucketing of the histogram, they should not be
    // changed.
    static const base::TimeDelta kMin = base::Milliseconds(1);
    static const base::TimeDelta kMax = base::Seconds(25);
    static const int kBuckets = 50;

    base::TimeTicks now = base::TimeTicks::Now();
    if (!hash_dance_time_start_.is_null()) {
      base::TimeDelta delta = now - hash_dance_time_start_;
      base::UmaHistogramCustomTimes(kUMAHashDanceSuccessTime + uma_suffix_,
                                    delta, kMin, kMax, kBuckets);
    }
  }

  raw_ptr<DeviceManagementService, DanglingUntriaged>
      device_management_service_;
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
  raw_ptr<PrefService> local_state_;
  const std::string device_id_;
  const std::string uma_suffix_;

  // Power-of-2 modulus to try next.
  int current_power_;

  // Power of the maximum power-of-2 modulus that this client will accept from
  // a retry response from the server.
  const int power_limit_;

  // Number of requests for a different modulus received from the server.
  // Used to determine if the server keeps asking for different moduli.
  int modulus_updates_received_ = 0;

  // Times used to determine the duration of the protocol, and the extra time
  // needed to complete after the signin was complete.
  // If `hash_dance_time_start_` is not null, the protocol is still running.
  base::TimeTicks hash_dance_time_start_;

  // The time when the bucket download part of the protocol started.
  base::TimeTicks time_start_bucket_download_;

  // The in-flight auto enrollment request, if any.
  std::unique_ptr<DeviceManagementService::Job> request_job_;

  const DeviceIdentifierProviderFRE device_identifier_provider_fre_;

  // Set by `StartImpl()` and consumed by `RunCallback()`.
  CompletionCallback completion_callback_;
};

// Determines server state availability for initial enrollment by running a
// private membership (PSM) check through `psm::RlweDmserverClient`, and
// records the outcome in local state.
class AutoEnrollmentClientImpl::InitialServerStateAvailabilityRequester
    : public ServerStateAvailabilityRequester {
 public:
  // Registers the local state prefs written by this requester.
  static void RegisterPrefs(PrefRegistrySimple* registry) {
    registry->RegisterBooleanPref(prefs::kShouldRetrieveDeviceState, false);
    registry->RegisterIntegerPref(prefs::kEnrollmentPsmResult, -1);
    registry->RegisterTimePref(prefs::kEnrollmentPsmDeterminationTime,
                               base::Time());
  }

  explicit InitialServerStateAvailabilityRequester(
      std::unique_ptr<psm::RlweDmserverClient> psm_rlwe_dmserver_client,
      PrefService* local_state)
      : psm_rlwe_dmserver_client_(std::move(psm_rlwe_dmserver_client)),
        local_state_(local_state) {}

  InitialServerStateAvailabilityRequester(
      const InitialServerStateAvailabilityRequester&) = delete;
  InitialServerStateAvailabilityRequester& operator=(
      const InitialServerStateAvailabilityRequester&) = delete;

  // ServerStateAvailabilityRequester:
  void Start(CompletionCallback callback) override {
    StartImpl(std::move(callback));
  }

  // Reads the membership decision back from local state. nullopt means the
  // pref is absent or has never been written.
  std::optional<bool> GetServerStateIfObtained() const override {
    const PrefService::Preference* has_psm_server_state_pref =
        local_state_->FindPreference(prefs::kShouldRetrieveDeviceState);

    const bool decision_stored =
        has_psm_server_state_pref &&
        !has_psm_server_state_pref->IsDefaultValue();
    if (!decision_stored) {
      return std::nullopt;
    }

    DCHECK(has_psm_server_state_pref->GetValue()->is_bool());

    return has_psm_server_state_pref->GetValue()->GetBool();
  }

 private:
  // Resets the PSM bookkeeping prefs, stores `callback` and kicks off the
  // membership check.
  void StartImpl(CompletionCallback callback) {
    DCHECK(callback);
    DCHECK(!completion_callback_);
    DCHECK(!psm_rlwe_dmserver_client_->IsCheckMembershipInProgress());

    PrepareLocalState();

    completion_callback_ = std::move(callback);

    // base::Unretained: the PSM client is owned by this object.
    auto on_membership_checked = base::BindOnce(
        &InitialServerStateAvailabilityRequester::HandlePsmCompletion,
        base::Unretained(this));
    psm_rlwe_dmserver_client_->CheckMembership(
        std::move(on_membership_checked));
  }

  // Persists the PSM outcome and translates it into the result reported to
  // the completion callback.
  void HandlePsmCompletion(
      psm::RlweDmserverClient::ResultHolder psm_result_holder) {
    UpdateLocalState(psm_result_holder);

    switch (psm_result_holder.psm_result) {
      // Membership was determined; the decision is now in local state.
      case psm::RlweResult::kSuccessfulDetermination:
        DCHECK(GetServerStateIfObtained());
        RunCallback(ServerStateAvailabilitySuccess::kSuccess);
        return;

      // Failures talking to DMServer: forward the DMServer error.
      case psm::RlweResult::kConnectionError:
      case psm::RlweResult::kServerError:
        DCHECK(psm_result_holder.dm_server_error.has_value());
        RunCallback(
            base::unexpected(psm_result_holder.dm_server_error.value()));
        return;

      // The server replied, but with an empty response.
      case psm::RlweResult::kEmptyOprfResponseError:
      case psm::RlweResult::kEmptyQueryResponseError:
        RunCallback(
            base::unexpected(AutoEnrollmentStateAvailabilityResponseError{}));
        return;

      // Failures inside the PSM library.
      case psm::RlweResult::kCreateRlweClientLibraryError:
      case psm::RlweResult::kCreateOprfRequestLibraryError:
      case psm::RlweResult::kCreateQueryRequestLibraryError:
      case psm::RlweResult::kProcessingQueryResponseLibraryError:
        DCHECK(!GetServerStateIfObtained());
        RunCallback(base::unexpected(AutoEnrollmentPsmError{}));
        return;
    }
  }

  // Consumes the stored completion callback, passing it
  // `availability_result`.
  void RunCallback(ServerStateAvailabilityResult availability_result) {
    DCHECK(completion_callback_);
    std::move(completion_callback_).Run(availability_result);
  }

  // Marks the PSM execution result as unknown and drops any previous
  // determination timestamp, so stale values do not survive into a new run.
  void PrepareLocalState() {
    local_state_->SetInteger(prefs::kEnrollmentPsmResult,
                             em::DeviceRegisterRequest::PSM_RESULT_UNKNOWN);
    local_state_->ClearPref(prefs::kEnrollmentPsmDeterminationTime);
  }

  // Writes the outcome of the PSM run to local state. On error only the
  // result code is recorded; otherwise the membership decision, its
  // determination time and the matching result code are stored.
  void UpdateLocalState(
      const psm::RlweDmserverClient::ResultHolder& psm_result_holder) {
    if (psm_result_holder.IsError()) {
      local_state_->SetInteger(prefs::kEnrollmentPsmResult,
                               em::DeviceRegisterRequest::PSM_RESULT_ERROR);
      return;
    }

    const bool has_state = psm_result_holder.membership_result.value();
    local_state_->SetBoolean(prefs::kShouldRetrieveDeviceState, has_state);
    local_state_->SetTime(
        prefs::kEnrollmentPsmDeterminationTime,
        psm_result_holder.membership_determination_time.value());

    const auto psm_result_code =
        has_state
            ? em::DeviceRegisterRequest::PSM_RESULT_SUCCESSFUL_WITH_STATE
            : em::DeviceRegisterRequest::PSM_RESULT_SUCCESSFUL_WITHOUT_STATE;
    local_state_->SetInteger(prefs::kEnrollmentPsmResult, psm_result_code);
  }

  // Runs the PSM protocol, including all related communication with DMServer.
  std::unique_ptr<psm::RlweDmserverClient> psm_rlwe_dmserver_client_;

  raw_ptr<PrefService> local_state_;

  // Set by `StartImpl()` and consumed by `RunCallback()`.
  CompletionCallback completion_callback_;
};

// ServerStateAvailabilityRequester stub for token-based enrollment: it sends
// no request, completes successfully right away and reports that server state
// should be retrieved.
class AutoEnrollmentClientImpl::TokenBasedEnrollmentStateAvailabilityRequester
    : public ServerStateAvailabilityRequester {
 public:
  // Records in local state that PSM was skipped and that device state should
  // be retrieved.
  explicit TokenBasedEnrollmentStateAvailabilityRequester(
      std::optional<std::string> enrollment_token,
      PrefService* local_state)
      : enrollment_token_(std::move(enrollment_token)),
        local_state_(local_state) {
    local_state_->SetInteger(
        prefs::kEnrollmentPsmResult,
        em::DeviceRegisterRequest::PSM_SKIPPED_FOR_FLEX_AUTO_ENROLLMENT);
    local_state_->SetBoolean(prefs::kShouldRetrieveDeviceState, true);
  }

  TokenBasedEnrollmentStateAvailabilityRequester(
      const TokenBasedEnrollmentStateAvailabilityRequester&) = delete;
  TokenBasedEnrollmentStateAvailabilityRequester& operator=(
      const TokenBasedEnrollmentStateAvailabilityRequester&) = delete;

  // Completes synchronously: `callback` is run with kSuccess before this
  // method returns.
  void Start(CompletionCallback callback) override {
    std::move(callback).Run(ServerStateAvailabilitySuccess::kSuccess);
  }

  // Never returns nullopt. The value is expected to be true, since this class
  // is meant to be created only once an enrollment token is known to exist;
  // the token is nevertheless re-checked defensively.
  std::optional<bool> GetServerStateIfObtained() const override {
    DCHECK(enrollment_token_.has_value());
    return std::make_optional(enrollment_token_.has_value());
  }

 private:
  const std::optional<std::string> enrollment_token_;
  raw_ptr<PrefService> local_state_;
};

// Responsible for resolving server state status for both force re-enrollment
// and initial enrollment.
//
// Sends a single state retrieval request to DMServer, using the supplied
// AutoEnrollmentStateMessageProcessor to fill the request and parse the
// response, and stores the parsed device state in local state.
class AutoEnrollmentClientImpl::ServerStateRetriever {
 public:
  // Invoked once with the outcome of the request started by `Start()`.
  using CompletionCallback =
      base::OnceCallback<void(ServerStateRetrievalResult)>;

  ServerStateRetriever(
      DeviceManagementService* device_management_service,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      PrefService* local_state,
      const std::string& device_id,
      const std::string& uma_suffix,
      std::unique_ptr<AutoEnrollmentStateMessageProcessor>
          state_download_message_processor)
      : device_management_service_(device_management_service),
        url_loader_factory_(url_loader_factory),
        local_state_(local_state),
        device_id_(device_id),
        uma_suffix_(uma_suffix),
        state_download_message_processor_(
            std::move(state_download_message_processor)) {}

  ServerStateRetriever(const ServerStateRetriever&) = delete;
  ServerStateRetriever& operator=(const ServerStateRetriever&) = delete;

  // Starts the state retrieval request and reports back through `callback`
  // once it has finished.
  void Start(CompletionCallback callback) { StartImpl(std::move(callback)); }

  // Returns nullopt until a state download has completed successfully.
  // Afterwards maps the device state mode to the corresponding auto
  // enrollment result.
  std::optional<AutoEnrollmentState> GetAutoEnrollmentStateIfObtained() const {
    if (!device_state_available_) {
      return std::nullopt;
    }

    const DeviceStateMode device_state_mode = GetDeviceStateMode();
    switch (device_state_mode) {
      case RESTORE_MODE_NONE:
        return AutoEnrollmentResult::kNoEnrollment;
      case RESTORE_MODE_DISABLED:
        return AutoEnrollmentResult::kDisabled;
      case RESTORE_MODE_REENROLLMENT_REQUESTED:
        return AutoEnrollmentResult::kSuggestedEnrollment;
      case RESTORE_MODE_REENROLLMENT_ENFORCED:
      case INITIAL_MODE_ENROLLMENT_ENFORCED:
      case RESTORE_MODE_REENROLLMENT_ZERO_TOUCH:
      case INITIAL_MODE_ENROLLMENT_ZERO_TOUCH:
      case INITIAL_MODE_ENROLLMENT_TOKEN_ENROLLMENT:
        return AutoEnrollmentResult::kEnrollment;
    }
  }

 private:
  // Builds and sends the state retrieval request. `callback` is stored and
  // run from `HandleRequestCompletion()`.
  void StartImpl(CompletionCallback callback) {
    DCHECK(!request_job_);
    DCHECK(callback);
    DCHECK(!completion_callback_);
    DCHECK(!device_state_available_);

    completion_callback_ = std::move(callback);

    // The completion handler runs at most once per job, so it is bound as a
    // OnceCallback, matching the availability requester in this file.
    // base::Unretained is used because the job is owned by `request_job_`,
    // a member of this object.
    std::unique_ptr<DMServerJobConfiguration> config =
        std::make_unique<DMServerJobConfiguration>(
            device_management_service_,
            state_download_message_processor_->GetJobType(), device_id_,
            /*critical=*/false, DMAuth::NoAuth(),
            /*oauth_token=*/std::nullopt, url_loader_factory_,
            base::BindOnce(&ServerStateRetriever::HandleRequestCompletion,
                           base::Unretained(this)));

    state_download_message_processor_->FillRequest(config->request());
    request_job_ = device_management_service_->CreateJob(std::move(config));
  }

  // Processes the DMServer reply and runs the stored completion callback
  // exactly once: with an error if the request failed or the response could
  // not be parsed, otherwise with success after the device state has been
  // written to local state.
  void HandleRequestCompletion(DMServerJobResult result) {
    DCHECK(request_job_);
    DCHECK(completion_callback_);

    request_job_.reset();

    base::UmaHistogramSparse(kUMAHashDanceRequestStatus + uma_suffix_,
                             result.dm_status);
    if (result.dm_status != DM_STATUS_SUCCESS) {
      LOG(ERROR) << "Auto enrollment error: " << result.dm_status;

      const auto error =
          AutoEnrollmentDMServerError::FromDMServerJobResult(result);
      if (error.network_error.has_value()) {
        base::UmaHistogramSparse(kUMAHashDanceNetworkErrorCode + uma_suffix_,
                                 -error.network_error.value());
      }
      return RunCallback(base::unexpected(error));
    }

    std::optional<AutoEnrollmentStateMessageProcessor::ParsedResponse>
        parsed_response_result =
            state_download_message_processor_->ParseResponse(result.response);
    if (!parsed_response_result) {
      return RunCallback(
          base::unexpected(AutoEnrollmentStateRetrievalResponseError{}));
    }

    AutoEnrollmentStateMessageProcessor::ParsedResponse& parsed_response =
        *parsed_response_result;

    // Copy only the fields the server actually provided into the dictionary.
    base::Value::Dict state;
    if (parsed_response.management_domain.has_value()) {
      state.Set(kDeviceStateManagementDomain,
                *parsed_response.management_domain);
    }

    if (!parsed_response.restore_mode.empty()) {
      state.Set(kDeviceStateMode, parsed_response.restore_mode);
    }

    if (parsed_response.disabled_message.has_value()) {
      state.Set(kDeviceStateDisabledMessage, *parsed_response.disabled_message);
    }

    if (parsed_response.is_license_packaged_with_device.has_value()) {
      state.Set(kDeviceStatePackagedLicense,
                *parsed_response.is_license_packaged_with_device);
    }

    if (parsed_response.license_type.has_value()) {
      state.Set(kDeviceStateLicenseType, *parsed_response.license_type);
    }

    if (parsed_response.assigned_upgrade_type.has_value()) {
      state.Set(kDeviceStateAssignedUpgradeType,
                *parsed_response.assigned_upgrade_type);
    }

    // Store the enrollment state obtained from the server to local state.
    // Depending on the value, this can be used later to trigger enrollment or
    // to disable the device.
    local_state_->SetDict(prefs::kServerBackedDeviceState, std::move(state));

    device_state_available_ = true;
    RunCallback(base::ok());
  }

  // Hands `result` to the stored completion callback, consuming the callback.
  void RunCallback(ServerStateRetrievalResult result) {
    DCHECK(completion_callback_);
    std::move(completion_callback_).Run(result);
  }

  raw_ptr<DeviceManagementService, DanglingUntriaged>
      device_management_service_;
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
  raw_ptr<PrefService> local_state_;
  const std::string device_id_;
  const std::string uma_suffix_;

  // Whether the download of server-kept device state completed successfully.
  bool device_state_available_ = false;

  // The in-flight state retrieval request, if any.
  std::unique_ptr<DeviceManagementService::Job> request_job_;

  // Fills and parses state retrieval request / response.
  std::unique_ptr<AutoEnrollmentStateMessageProcessor>
      state_download_message_processor_;

  // Set by `StartImpl()` and consumed by `RunCallback()`.
  CompletionCallback completion_callback_;
};

// Defaulted out of line: no custom construction or teardown logic is defined
// here.
AutoEnrollmentClientImpl::FactoryImpl::FactoryImpl() = default;
AutoEnrollmentClientImpl::FactoryImpl::~FactoryImpl() = default;

std::unique_ptr<AutoEnrollmentClient>
AutoEnrollmentClientImpl::FactoryImpl::CreateForFRE(
    const ProgressCallback& progress_callback,
    DeviceManagementService* device_management_service,
    PrefService* local_state,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
    const std::string& server_backed_state_key,
    int power_initial,
    int power_limit) {
  const std::string device_id =
      base::Uuid::GenerateRandomV4().AsLowercaseString();
  return base::WrapUnique(new AutoEnrollmentClientImpl(
      progress_callback,
      std::make_unique<FREServerStateAvailabilityRequester>(
          device_management_service, url_loader_factory, local_state, device_id,
          kUMASuffixFRE, power_initial, power_limit, server_backed_state_key),
      std::make_unique<ServerStateRetriever>(
          device_management_service, url_loader_factory, local_state, device_id,
          kUMASuffixFRE,
          AutoEnrollmentStateMessageProcessor::CreateForFRE(
              server_backed_state_key))));
}

std::unique_ptr<AutoEnrollmentClient>
AutoEnrollmentClientImpl::FactoryImpl::CreateForInitialEnrollment(
    const ProgressCallback& progress_callback,
    DeviceManagementService* device_management_service,
    PrefService* local_state,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
    const std::string& device_serial_number,
    const std::string& device_brand_code,
    std::unique_ptr<psm::RlweDmserverClient> psm_rlwe_dmserver_client,
    ash::OobeConfiguration* oobe_config) {
  std::unique_ptr<ServerStateAvailabilityRequester>
      server_state_availability_requester;
  const std::optional<std::string> enrollment_token =
      GetEnrollmentToken(oobe_config);
  if (enrollment_token.has_value()) {
    server_state_availability_requester =
        std::make_unique<TokenBasedEnrollmentStateAvailabilityRequester>(
            enrollment_token, local_state);
  } else {
    server_state_availability_requester =
        std::make_unique<InitialServerStateAvailabilityRequester>(
            std::move(psm_rlwe_dmserver_client), local_state);
  }
  return base::WrapUnique(new AutoEnrollmentClientImpl(
      progress_callback, std::move(server_state_availability_requester),
      std::make_unique<ServerStateRetriever>(
          device_management_service, url_loader_factory, local_state,
          /*device_id=*/base::Uuid::GenerateRandomV4().AsLowercaseString(),
          kUMASuffixInitialEnrollment,
          AutoEnrollmentStateMessageProcessor::CreateForInitialEnrollment(
              device_serial_number, device_brand_code, enrollment_token))));
}

// static
// Registers the prefs of the FRE and the initial-enrollment availability
// requesters by forwarding `registry` to each of their `RegisterPrefs()`
// functions.
void AutoEnrollmentClientImpl::RegisterPrefs(PrefRegistrySimple* registry) {
  FREServerStateAvailabilityRequester::RegisterPrefs(registry);
  InitialServerStateAvailabilityRequester::RegisterPrefs(registry);
}

// Stores the progress callback and takes over the two step objects:
// `server_state_availability_requester` (first step) and
// `server_state_retriever` (second step).
//
// `callback` must be non-null (DCHECKed); it is the callback that
// `ReportProgress()` later runs.
AutoEnrollmentClientImpl::AutoEnrollmentClientImpl(
    ProgressCallback callback,
    std::unique_ptr<ServerStateAvailabilityRequester>
        server_state_availability_requester,
    std::unique_ptr<ServerStateRetriever> server_state_retriever)
    : progress_callback_(std::move(callback)),
      server_state_availability_requester_(
          std::move(server_state_availability_requester)),
      server_state_retriever_(std::move(server_state_retriever)) {
  DCHECK(progress_callback_);
}

// Defaulted: no explicit teardown is performed here; members are destroyed
// implicitly.
AutoEnrollmentClientImpl::~AutoEnrollmentClientImpl() = default;

// Kicks off the flow by requesting server state availability.
//
// Expects the client to still be idle and the retriever to not have obtained
// an auto-enrollment state yet (both DCHECKed).
void AutoEnrollmentClientImpl::Start() {
  DCHECK_EQ(state_, State::kIdle);
  DCHECK(!server_state_retriever_->GetAutoEnrollmentStateIfObtained());

  RequestServerStateAvailability();
}

// Re-drives the state machine from whichever step last failed. Does nothing
// while a request is in progress or once the flow has finished.
void AutoEnrollmentClientImpl::Retry() {
  switch (state_) {
    // Nothing has been attempted yet: behave like a first start.
    case State::kIdle:
      Start();
      return;

    // The availability step failed: run that step again.
    case State::kRequestServerStateAvailabilityError:
      RequestServerStateAvailability();
      return;

    // The retrieval step failed: run that step again.
    case State::kRequestStateRetrievalError:
      RequestStateRetrieval();
      return;

    // Either a request is already in progress, or all possible requests are
    // done and the final device state has been reported. Nothing to do.
    case State::kRequestingServerStateAvailability:
    case State::kRequestingStateRetrieval:
    case State::kFinished:
      return;

    // Transient state that is not expected to be observable from here.
    case State::kRequestServerStateAvailabilitySuccess:
      NOTREACHED_IN_MIGRATION()
          << "kRequestServerStateAvailabilitySuccess supposed to "
             "immediately resolve to kRequestingStateRetrieval.";
      return;
  }
}

// Runs the first step: determining whether server state is available.
//
// May only be entered from the idle state or after this step failed
// (DCHECKed). If the requester has not obtained an answer yet, it is started
// and completion arrives via `OnServerStateAvailabilityCompleted()`. If the
// answer is already there, completion is signalled synchronously as a
// success.
void AutoEnrollmentClientImpl::RequestServerStateAvailability() {
  DCHECK(state_ == State::kIdle ||
         state_ == State::kRequestServerStateAvailabilityError);
  state_ = State::kRequestingServerStateAvailability;

  if (!server_state_availability_requester_->GetServerStateIfObtained()) {
    // NOTE(review): `base::Unretained(this)` presumably relies on the
    // requester being owned by this object - confirm against the header.
    auto on_completed = base::BindOnce(
        &AutoEnrollmentClientImpl::OnServerStateAvailabilityCompleted,
        base::Unretained(this));
    server_state_availability_requester_->Start(std::move(on_completed));
    return;
  }

  // Nothing to request: the availability is already known.
  OnServerStateAvailabilityCompleted(ServerStateAvailabilitySuccess::kSuccess);
}

// Handles the outcome of the availability step.
//
// On success:
//   * `kRetry`   - marks the step as failed and immediately retries it.
//   * `kSuccess` - if the obtained server state is true, moves on to the
//                  retrieval step; otherwise finishes and reports the final
//                  state.
// On error:
//   * PSM errors finish the flow and report the final state.
//   * Any other error marks the step as failed and is reported through the
//     progress callback, so that `Retry()` can run the step again.
void AutoEnrollmentClientImpl::OnServerStateAvailabilityCompleted(
    ServerStateAvailabilityResult result) {
  DCHECK(state_ == State::kRequestingServerStateAvailability);

  if (result.has_value()) {
    switch (result.value()) {
      case ServerStateAvailabilitySuccess::kRetry:
        state_ = State::kRequestServerStateAvailabilityError;
        Retry();
        return;

      case ServerStateAvailabilitySuccess::kSuccess:
        DCHECK(
            server_state_availability_requester_->GetServerStateIfObtained());
        if (!server_state_availability_requester_->GetServerStateIfObtained()
                 .value()) {
          // No server state to fetch: the flow ends here.
          state_ = State::kFinished;
          ReportFinished();
          return;
        }
        state_ = State::kRequestServerStateAvailabilitySuccess;
        RequestStateRetrieval();
        return;
    }
    return;
  }

  const bool is_psm_error =
      absl::holds_alternative<AutoEnrollmentPsmError>(result.error());
  if (!is_psm_error) {
    state_ = State::kRequestServerStateAvailabilityError;
    ReportProgress(base::unexpected(result.error()));
    return;
  }

  // At the moment, `AutoEnrollmentClientImpl` does not distinguish between
  // the PSM errors (except for connection error and server error) and reports
  // final progress with the given server state even if it is not available.
  DCHECK(!server_state_availability_requester_->GetServerStateIfObtained());
  state_ = State::kFinished;
  ReportFinished();
}

// Runs the second step: retrieving the state from the server.
//
// Preconditions (all DCHECKed): the availability step just succeeded or a
// previous retrieval failed, the obtained server state is true, and the
// retriever has not obtained an auto-enrollment state yet. Completion arrives
// via `OnStateRetrievalCompleted()`.
void AutoEnrollmentClientImpl::RequestStateRetrieval() {
  DCHECK(state_ == State::kRequestServerStateAvailabilitySuccess ||
         state_ == State::kRequestStateRetrievalError);
  DCHECK(server_state_availability_requester_->GetServerStateIfObtained());
  DCHECK(
      server_state_availability_requester_->GetServerStateIfObtained().value());
  DCHECK(!server_state_retriever_->GetAutoEnrollmentStateIfObtained());
  state_ = State::kRequestingStateRetrieval;

  // NOTE(review): `base::Unretained(this)` presumably relies on the retriever
  // being owned by this object - confirm against the header.
  auto on_retrieved =
      base::BindOnce(&AutoEnrollmentClientImpl::OnStateRetrievalCompleted,
                     base::Unretained(this));
  server_state_retriever_->Start(std::move(on_retrieved));
}

// Handles the outcome of the retrieval step. On success the flow finishes and
// the final state is reported; on failure the step is marked as failed and
// the error is forwarded through the progress callback.
void AutoEnrollmentClientImpl::OnStateRetrievalCompleted(
    ServerStateRetrievalResult result) {
  DCHECK(state_ == State::kRequestingStateRetrieval);

  if (result.has_value()) {
    // The retriever is expected to hold the obtained state at this point.
    DCHECK(server_state_retriever_->GetAutoEnrollmentStateIfObtained());
    state_ = State::kFinished;
    ReportFinished();
    return;
  }

  state_ = State::kRequestStateRetrievalError;
  ReportProgress(base::unexpected(result.error()));
}

// Runs `progress_callback_` with `state`. The callback must be set (DCHECKed).
void AutoEnrollmentClientImpl::ReportProgress(AutoEnrollmentState state) const {
  DCHECK(progress_callback_);
  progress_callback_.Run(state);
}

// Reports the final result of the flow. Must only be called in the finished
// state (DCHECKed). Reports the auto-enrollment state held by the retriever
// if there is one, and `AutoEnrollmentResult::kNoEnrollment` otherwise.
void AutoEnrollmentClientImpl::ReportFinished() const {
  DCHECK_EQ(state_, State::kFinished);

  const auto obtained_state =
      server_state_retriever_->GetAutoEnrollmentStateIfObtained();
  if (!obtained_state) {
    // Retrieval never produced a state: report "no enrollment".
    ReportProgress(AutoEnrollmentResult::kNoEnrollment);
    return;
  }

  ReportProgress(obtained_state.value());
}

}  // namespace policy