chromium/chrome/browser/ash/cert_provisioning/cert_provisioning_invalidator.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/cert_provisioning/cert_provisioning_invalidator.h"

#include <utility>

#include "base/functional/overloaded.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_common.h"
#include "chrome/browser/invalidation/profile_invalidation_provider_factory.h"
#include "components/invalidation/invalidation_factory.h"
#include "components/invalidation/invalidation_listener.h"
#include "components/invalidation/profile_invalidation_provider.h"
#include "components/invalidation/public/invalidation.h"
#include "components/invalidation/public/invalidation_service.h"
#include "components/invalidation/public/invalidation_util.h"
#include "components/invalidation/public/invalidator_state.h"
#include "components/policy/core/common/cloud/cloud_policy_constants.h"

namespace ash::cert_provisioning {

namespace {

// Topics that start with this prefix are considered to be "public" FCM topics.
// This allows us to migrate to "private" FCM topics (which would get a
// different prefix) server-side without client-side changes.
// The prefix is matched case-sensitively by
// CertProvisioningInvalidationHandler::IsPublicTopic().
constexpr char kFcmCertProvisioningPublicTopicPrefix[] = "cert-";

// printf-style format of the handler's owner name. It is expanded to
// cert.[scope].[topic], where [scope] is the CertScopeToString() value and
// [topic] is the invalidation topic (see GetOwnerName()).
constexpr char kOwnerNameFormat[] = "cert.%s.%s";

// Returns the lowercase name of `scope` ("user" or "device"). The result is a
// string literal, so it stays valid for the lifetime of the program.
//
// The switch has no `default` label: a value that matches none of the
// enumerators handled below falls out of the switch, hits
// NOTREACHED_IN_MIGRATION() and, if execution continues past it, produces an
// empty string.
const char* CertScopeToString(CertScope scope) {
  switch (scope) {
    case CertScope::kUser:
      return "user";
    case CertScope::kDevice:
      return "device";
  }

  NOTREACHED_IN_MIGRATION()
      << "Unknown cert scope: " << static_cast<int>(scope);
  return "";
}

}  // namespace

//=============== CertProvisioningInvalidationHandler ==========================

namespace internal {

// static
// Creates a handler for `topic` and registers it right away with the supplied
// invalidation service or listener. Returns the registered handler, or nullptr
// when Register() reports failure (the half-built handler is destroyed).
std::unique_ptr<CertProvisioningInvalidationHandler>
CertProvisioningInvalidationHandler::BuildAndRegister(
    CertScope scope,
    std::variant<invalidation::InvalidationService*,
                 invalidation::InvalidationListener*>
        invalidation_service_or_listener,
    const invalidation::Topic& topic,
    const std::string& listener_type,
    OnInvalidationEventCallback on_invalidation_event_callback) {
  auto handler = std::make_unique<CertProvisioningInvalidationHandler>(
      scope, std::move(invalidation_service_or_listener), topic, listener_type,
      std::move(on_invalidation_event_callback));

  if (handler->Register()) {
    return handler;
  }

  return nullptr;
}

// Stores the construction arguments; nothing is registered until Register() is
// called. Crashes (CHECK) when the service/listener pointer held by the variant
// is null or when the callback is null.
CertProvisioningInvalidationHandler::CertProvisioningInvalidationHandler(
    CertScope scope,
    std::variant<invalidation::InvalidationService*,
                 invalidation::InvalidationListener*>
        invalidation_service_or_listener,
    const invalidation::Topic& topic,
    const std::string& listener_type,
    OnInvalidationEventCallback on_invalidation_event_callback)
    : scope_(scope),
      invalidation_service_or_listener_(
          invalidation::PointerVariantToRawPointer(
              invalidation_service_or_listener)),
      topic_(topic),
      listener_type_(listener_type),
      on_invalidation_event_callback_(
          std::move(on_invalidation_event_callback)) {
  using ServicePtr = raw_ptr<invalidation::InvalidationService>;
  using ListenerPtr = raw_ptr<invalidation::InvalidationListener>;

  // Whichever alternative the variant currently holds must be non-null.
  if (std::holds_alternative<ServicePtr>(invalidation_service_or_listener_)) {
    CHECK(std::get<ServicePtr>(invalidation_service_or_listener_))
        << "InvalidationService is used but is null";
  }
  if (std::holds_alternative<ListenerPtr>(invalidation_service_or_listener_)) {
    CHECK(std::get<ListenerPtr>(invalidation_service_or_listener_))
        << "InvalidationListener is used but is null";
  }

  CHECK(!on_invalidation_event_callback_.is_null());
}

CertProvisioningInvalidationHandler::~CertProvisioningInvalidationHandler() {
  // Unregister() is intentionally not called here, so that the subscription
  // can be preserved if the browser restarts. An owner that no longer needs
  // the subscription must call Unregister() explicitly.
}

bool CertProvisioningInvalidationHandler::Register() {
  if (IsRegistered()) {
    return true;
  }

  return std::visit(
      base::Overloaded{[this](invalidation::InvalidationService* service) {
                         return RegisterWithInvalidationService(service);
                       },
                       [this](invalidation::InvalidationListener* listener) {
                         invalidation_listener_observation_.Observe(listener);
                         return true;
                       }},
      invalidation_service_or_listener_);
}

// Starts observing `service` and declares interest in `topic_`.
//
// Returns true when the service accepted the topic. On failure the observation
// set up here is rolled back, so that IsRegistered() keeps reporting false and
// a later Register() call retries instead of returning early with no topic
// interest in place.
bool CertProvisioningInvalidationHandler::RegisterWithInvalidationService(
    invalidation::InvalidationService* service) {
  OnInvalidatorStateChange(service->GetInvalidatorState());
  // The handler starts observing before it updates its topics; this call order
  // is kept as is.
  invalidation_service_observation_.Observe(service);

  if (!service->UpdateInterestedTopics(this,
                                       /*topics=*/{topic_})) {
    LOG(WARNING) << "Failed to register with topic " << topic_;
    // Do not stay half-registered: without this, IsRegistered() would report
    // true even though the topic was rejected.
    invalidation_service_observation_.Reset();
    return false;
  }

  return true;
}

void CertProvisioningInvalidationHandler::Unregister() {
  if (!IsRegistered()) {
    return;
  }

  if (std::holds_alternative<raw_ptr<invalidation::InvalidationService>>(
          invalidation_service_or_listener_)) {
    const auto service = std::get<raw_ptr<invalidation::InvalidationService>>(
        invalidation_service_or_listener_);
    // Assuming that updating invalidator's topics with an empty set can never
    // fail. Failure is only possible when setting a non-empty topic that is
    // already associated with some other invalidator.
    const bool topics_reset =
        service->UpdateInterestedTopics(this, invalidation::TopicSet());
    CHECK(topics_reset);
  }
  invalidation_service_observation_.Reset();
  invalidation_listener_observation_.Reset();
}

// Called from RegisterWithInvalidationService() with the service's current
// state (and presumably by the observed InvalidationService on later state
// changes). Only asserts the calling sequence: `state` is not stored because
// AreInvalidationsEnabled() queries the service directly when needed.
void CertProvisioningInvalidationHandler::OnInvalidatorStateChange(
    invalidation::InvalidatorState state) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Reports a successful subscription to the owner through the event callback.
// `topic` must be the topic this handler was created for; anything else is a
// fatal error.
void CertProvisioningInvalidationHandler::OnSuccessfullySubscribed(
    const invalidation::Topic& topic) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK_EQ(topic, topic_)
      << "Successfully subscribed notification for wrong topic";

  const auto event = InvalidationEvent::kSuccessfullySubscribed;
  on_invalidation_event_callback_.Run(event);
}

// Forwards an invalidation for `topic_` to the owner as
// kInvalidationReceived. An invalidation that arrives while invalidations are
// not enabled is still forwarded, but a warning is logged. An invalidation for
// any other topic is a fatal error.
void CertProvisioningInvalidationHandler::OnIncomingInvalidation(
    const invalidation::Invalidation& invalidation) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  const bool invalidations_enabled = AreInvalidationsEnabled();
  if (!invalidations_enabled) {
    LOG(WARNING) << "Unexpected invalidation received.";
  }

  CHECK(invalidation.topic() == topic_)
      << "Incoming invalidation does not contain invalidation"
         " for certificate topic";

  const auto event = InvalidationEvent::kInvalidationReceived;
  on_invalidation_event_callback_.Run(event);
}

// Builds the owner name "cert.[scope].[topic]" from this handler's scope and
// topic.
std::string CertProvisioningInvalidationHandler::GetOwnerName() const {
  const char* const scope_name = CertScopeToString(scope_);
  return base::StringPrintf(kOwnerNameFormat, scope_name, topic_.c_str());
}

// Returns true when `topic` begins with the public FCM topic prefix. The
// comparison is case-sensitive.
bool CertProvisioningInvalidationHandler::IsPublicTopic(
    const invalidation::Topic& topic) const {
  constexpr auto kCaseMode = base::CompareCase::SENSITIVE;
  return base::StartsWith(topic, kFcmCertProvisioningPublicTopicPrefix,
                          kCaseMode);
}

// Records the most recent expectation value. AreInvalidationsEnabled() compares
// it against InvalidationsExpected::kYes when the handler is backed by an
// InvalidationListener.
void CertProvisioningInvalidationHandler::OnExpectationChanged(
    invalidation::InvalidationsExpected expected) {
  are_invalidations_expected_ = expected;
}

// Reports a received direct invalidation to the owner as
// kInvalidationReceived. The contents of `invalidation` are not inspected.
void CertProvisioningInvalidationHandler::OnInvalidationReceived(
    const invalidation::DirectInvalidation& invalidation) {
  const auto event = InvalidationEvent::kInvalidationReceived;
  on_invalidation_event_callback_.Run(event);
}

// Returns a copy of the listener type string that was supplied at
// construction.
std::string CertProvisioningInvalidationHandler::GetType() const {
  return listener_type_;
}

bool CertProvisioningInvalidationHandler::IsRegistered() const {
  return std::visit(
      base::Overloaded{
          [this](invalidation::InvalidationService* service) {
            return invalidation_service_observation_.IsObservingSource(service);
          },
          [this](invalidation::InvalidationListener* listener) {
            return invalidation_listener_observation_.IsObservingSource(
                listener);
          }},
      invalidation_service_or_listener_);
}

// Returns whether invalidations can currently be expected. An unregistered
// handler always reports false. Otherwise the answer comes from the service's
// invalidator state (kEnabled) or, for a listener, from the last expectation
// recorded by OnExpectationChanged() (kYes).
bool CertProvisioningInvalidationHandler::AreInvalidationsEnabled() const {
  return IsRegistered() &&
         std::visit(
             base::Overloaded{
                 [](invalidation::InvalidationService* service) {
                   const auto state = service->GetInvalidatorState();
                   return state == invalidation::InvalidatorState::kEnabled;
                 },
                 [this](invalidation::InvalidationListener* listener) {
                   return are_invalidations_expected_ ==
                          invalidation::InvalidationsExpected::kYes;
                 }},
             invalidation_service_or_listener_);
}

}  // namespace internal

//=============== CertProvisioningInvalidator ==================================

// Defaulted out of line. Note that the destructor does not call Unregister();
// it only destroys the members.
CertProvisioningInvalidator::CertProvisioningInvalidator() = default;
CertProvisioningInvalidator::~CertProvisioningInvalidator() = default;

// Unregisters and destroys the current invalidation handler, if there is one.
void CertProvisioningInvalidator::Unregister() {
  if (!invalidation_handler_) {
    return;
  }

  invalidation_handler_->Unregister();
  invalidation_handler_.reset();
}

//=============== CertProvisioningUserInvalidatorFactory =======================

// Stores `profile` for later use by Create(). `profile` must be non-null; this
// is enforced with a CHECK.
CertProvisioningUserInvalidatorFactory::CertProvisioningUserInvalidatorFactory(
    Profile* profile)
    : profile_(profile) {
  CHECK(profile_);
}

// Returns a new CertProvisioningUserInvalidator bound to the profile this
// factory was constructed with. The returned invalidator is not registered for
// any topic yet.
std::unique_ptr<CertProvisioningInvalidator>
CertProvisioningUserInvalidatorFactory::Create() {
  return std::make_unique<CertProvisioningUserInvalidator>(profile_);
}

//=============== CertProvisioningUserInvalidator ==============================

// Stores `profile`, which Register() later uses to look up the profile's
// invalidation provider. `profile` must be non-null; this is enforced with a
// CHECK.
CertProvisioningUserInvalidator::CertProvisioningUserInvalidator(
    Profile* profile)
    : profile_(profile) {
  CHECK(profile_);
}

// Registers for invalidations on `topic` using the invalidation service or
// listener obtained from the profile's invalidation provider. The provider
// must exist (CHECK). On failure `invalidation_handler_` is left null and an
// error is logged.
void CertProvisioningUserInvalidator::Register(
    const invalidation::Topic& topic,
    const std::string& listener_type,
    OnInvalidationEventCallback on_invalidation_event_callback) {
  auto* const provider =
      invalidation::ProfileInvalidationProviderFactory::GetForProfile(profile_);
  CHECK(provider);

  auto service_or_listener = provider->GetInvalidationServiceOrListener(
      policy::kPolicyFCMInvalidationSenderID,
      invalidation::InvalidationListener::kProjectNumberEnterprise);

  invalidation_handler_ =
      internal::CertProvisioningInvalidationHandler::BuildAndRegister(
          CertScope::kUser, std::move(service_or_listener), topic,
          listener_type, std::move(on_invalidation_event_callback));

  if (invalidation_handler_) {
    return;
  }

  LOG(ERROR) << "Failed to register for invalidation topic";
}

//=============== CertProvisioningDeviceInvalidatorFactory =====================

// Stores the service provider or listener that invalidators built by Create()
// will use. Whichever pointer the variant holds must be non-null (CHECK).
CertProvisioningDeviceInvalidatorFactory::
    CertProvisioningDeviceInvalidatorFactory(
        std::variant<policy::AffiliatedInvalidationServiceProvider*,
                     invalidation::InvalidationListener*>
            invalidation_service_provider_or_listener)
    : invalidation_service_provider_or_listener_(
          invalidation::PointerVariantToRawPointer(
              invalidation_service_provider_or_listener)) {
  using ProviderPtr = raw_ptr<policy::AffiliatedInvalidationServiceProvider>;
  using ListenerPtr = raw_ptr<invalidation::InvalidationListener>;

  const auto& stored = invalidation_service_provider_or_listener_;
  if (std::holds_alternative<ProviderPtr>(stored)) {
    CHECK(std::get<ProviderPtr>(stored))
        << "AffiliatedInvalidationServiceProvider is used but is null";
  }
  if (std::holds_alternative<ListenerPtr>(stored)) {
    CHECK(std::get<ListenerPtr>(stored))
        << "InvalidationListener is used but is null";
  }
}

// Defaulted out of line. The default constructor leaves
// `invalidation_service_provider_or_listener_` default-initialized, i.e.
// without the non-null checks performed by the other constructor.
CertProvisioningDeviceInvalidatorFactory::
    CertProvisioningDeviceInvalidatorFactory() = default;
CertProvisioningDeviceInvalidatorFactory::
    ~CertProvisioningDeviceInvalidatorFactory() = default;

// Returns a new CertProvisioningDeviceInvalidator that targets the same
// service provider or listener as this factory. The stored variant is passed
// through RawPointerVariantToPointer() to obtain the plain-pointer variant the
// invalidator's constructor takes.
std::unique_ptr<CertProvisioningInvalidator>
CertProvisioningDeviceInvalidatorFactory::Create() {
  return std::make_unique<CertProvisioningDeviceInvalidator>(
      invalidation::RawPointerVariantToPointer(
          invalidation_service_provider_or_listener_));
}

//=============== CertProvisioningDeviceInvalidator ============================

// Stores the service provider or listener this invalidator will register
// with. Nothing is registered until Register() is called. Whichever pointer
// the variant holds must be non-null (CHECK).
CertProvisioningDeviceInvalidator::CertProvisioningDeviceInvalidator(
    std::variant<policy::AffiliatedInvalidationServiceProvider*,
                 invalidation::InvalidationListener*>
        invalidation_service_provider_or_listener)
    : invalidation_service_provider_or_listener_(
          invalidation::PointerVariantToRawPointer(
              invalidation_service_provider_or_listener)) {
  using ProviderPtr = raw_ptr<policy::AffiliatedInvalidationServiceProvider>;
  using ListenerPtr = raw_ptr<invalidation::InvalidationListener>;

  const auto& stored = invalidation_service_provider_or_listener_;
  if (std::holds_alternative<ProviderPtr>(stored)) {
    CHECK(std::get<ProviderPtr>(stored))
        << "AffiliatedInvalidationServiceProvider is used but is null";
  }
  if (std::holds_alternative<ListenerPtr>(stored)) {
    CHECK(std::get<ListenerPtr>(stored))
        << "InvalidationListener is used but is null";
  }
}

CertProvisioningDeviceInvalidator::~CertProvisioningDeviceInvalidator() {
  // As mentioned in the class-level comment, Unregister() is intentionally not
  // called here so that a subscription can be preserved across process
  // restarts. Only the consumer registration with the service provider is
  // removed.
  //
  // Note that it is OK to call UnregisterConsumer() even if this instance has
  // not called RegisterConsumer() yet.
  using ProviderPtr = raw_ptr<policy::AffiliatedInvalidationServiceProvider>;

  if (std::holds_alternative<ProviderPtr>(
          invalidation_service_provider_or_listener_)) {
    std::get<ProviderPtr>(invalidation_service_provider_or_listener_)
        ->UnregisterConsumer(this);
  }
  // Nothing to do when an InvalidationListener is used.
}

// Remembers `topic`, `listener_type` and the callback, then starts the
// registration. `topic` must be non-empty (CHECK).
//
// With an AffiliatedInvalidationServiceProvider this only registers the
// invalidator as a consumer; the handler itself is created later in
// OnInvalidationServiceSet(). With an InvalidationListener the handler is
// built and registered immediately, and a failure is logged the same way as in
// the other registration paths.
void CertProvisioningDeviceInvalidator::Register(
    const invalidation::Topic& topic,
    const std::string& listener_type,
    OnInvalidationEventCallback on_invalidation_event_callback) {
  topic_ = topic;
  listener_type_ = listener_type;
  CHECK(!topic_.empty());
  on_invalidation_event_callback_ = std::move(on_invalidation_event_callback);
  std::visit(
      base::Overloaded{
          [this](
              policy::AffiliatedInvalidationServiceProvider* service_provider) {
            service_provider->RegisterConsumer(this);
          },
          [this](invalidation::InvalidationListener* listener) {
            invalidation_handler_ =
                internal::CertProvisioningInvalidationHandler::BuildAndRegister(
                    CertScope::kDevice, listener, topic_, listener_type_,
                    on_invalidation_event_callback_);
            // Do not drop a failed registration silently.
            if (!invalidation_handler_) {
              LOG(ERROR) << "Failed to register for invalidation topic";
            }
          }},
      invalidation_service_provider_or_listener_);
}

void CertProvisioningDeviceInvalidator::Unregister() {
  std::visit(
      base::Overloaded{
          [this](
              policy::AffiliatedInvalidationServiceProvider* service_provider) {
            service_provider->UnregisterConsumer(this);
          },
          [](invalidation::InvalidationListener* listener) {
            // Do nothing.
          }},
      invalidation_service_provider_or_listener_);
  CertProvisioningInvalidator::Unregister();
  topic_.clear();
}

// Rebuilds the invalidation handler for a newly set invalidation service.
// The previous handler is always destroyed first (without unregistering). When
// `invalidation_service` is null no new handler is created.
void CertProvisioningDeviceInvalidator::OnInvalidationServiceSet(
    invalidation::InvalidationService* invalidation_service) {
  // Register() must have run before this can be called, so `topic_` has to be
  // non-empty here.
  CHECK(!topic_.empty());

  // Drop any handler that is still active: it may refer to the previous
  // `invalidation_service`.
  invalidation_handler_.reset();

  if (invalidation_service) {
    invalidation_handler_ =
        internal::CertProvisioningInvalidationHandler::BuildAndRegister(
            CertScope::kDevice, invalidation_service, topic_, listener_type_,
            on_invalidation_event_callback_);
    if (!invalidation_handler_) {
      LOG(ERROR) << "Failed to register for invalidation topic";
    }
  }
}

}  // namespace ash::cert_provisioning