chromium/chrome/browser/ash/policy/invalidation/affiliated_invalidation_service_provider_impl.cc

// Copyright 2015 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/policy/invalidation/affiliated_invalidation_service_provider_impl.h"

#include <memory>
#include <vector>

#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "chrome/browser/ash/profiles/profile_helper.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/device_identity/device_identity_provider.h"
#include "chrome/browser/device_identity/device_oauth2_token_service_factory.h"
#include "chrome/browser/invalidation/profile_invalidation_provider_factory.h"
#include "chrome/browser/net/system_network_context_manager.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_manager.h"
#include "components/gcm_driver/instance_id/instance_id_driver.h"
#include "components/invalidation/impl/fcm_invalidation_service.h"
#include "components/invalidation/impl/fcm_network_handler.h"
#include "components/invalidation/impl/per_user_topic_subscription_manager.h"
#include "components/invalidation/invalidation_factory.h"
#include "components/invalidation/invalidation_listener.h"
#include "components/invalidation/profile_invalidation_provider.h"
#include "components/invalidation/public/identity_provider.h"
#include "components/invalidation/public/invalidation_handler.h"
#include "components/invalidation/public/invalidation_service.h"
#include "components/invalidation/public/invalidator_state.h"
#include "components/policy/core/common/cloud/cloud_policy_constants.h"
#include "components/user_manager/user.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"

namespace policy {

namespace {

// Looks up the ProfileInvalidationProvider keyed service for |profile|.
// Callers in this file interpret a null result as "this Profile does not
// support invalidation".
invalidation::ProfileInvalidationProvider* GetInvalidationProvider(
    Profile* profile) {
  auto* const provider =
      invalidation::ProfileInvalidationProviderFactory::GetForProfile(profile);
  return provider;
}

}  // namespace

// Tracks the connection state of a single InvalidationService and forwards
// connect/disconnect transitions to the owning
// AffiliatedInvalidationServiceProviderImpl. Both pointers held here are
// non-owning (raw_ptr).
class AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver
    : public invalidation::InvalidationHandler {
 public:
  // Starts observing |invalidation_service| (must be non-null) and samples its
  // current state. Transitions are reported to |parent|.
  explicit InvalidationServiceObserver(
      AffiliatedInvalidationServiceProviderImpl* parent,
      invalidation::InvalidationService* invalidation_service);

  InvalidationServiceObserver(const InvalidationServiceObserver&) = delete;
  InvalidationServiceObserver& operator=(const InvalidationServiceObserver&) =
      delete;

  ~InvalidationServiceObserver() override;

  // Returns the observed service.
  invalidation::InvalidationService* GetInvalidationService();
  // Returns whether the last observed state was InvalidatorState::kEnabled.
  bool IsServiceConnected() const;

  // public invalidation::InvalidationHandler:
  void OnInvalidatorStateChange(invalidation::InvalidatorState state) override;
  void OnIncomingInvalidation(
      const invalidation::Invalidation& invalidation) override;
  std::string GetOwnerName() const override;

 private:
  // Set once in the constructor and never reseated.
  const raw_ptr<AffiliatedInvalidationServiceProviderImpl> parent_;
  const raw_ptr<invalidation::InvalidationService> invalidation_service_;
  // In-class defaults guarantee neither flag is ever read uninitialized.
  bool is_service_connected_ = false;
  // True only between the end of construction and the start of destruction;
  // OnInvalidatorStateChange() ignores notifications while it is false.
  bool is_observer_ready_ = false;

  base::ScopedObservation<invalidation::InvalidationService,
                          invalidation::InvalidationHandler>
      invalidation_service_observation_{this};
};

AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver::
    InvalidationServiceObserver(
        AffiliatedInvalidationServiceProviderImpl* parent,
        invalidation::InvalidationService* invalidation_service)
    : parent_(parent),
      invalidation_service_(invalidation_service),
      is_service_connected_(false),
      is_observer_ready_(false) {
  DCHECK(invalidation_service_);

  // Register while |is_observer_ready_| is still false, so that
  // OnInvalidatorStateChange() ignores any notification during registration.
  // The initial state is sampled explicitly right afterwards.
  invalidation_service_observation_.Observe(invalidation_service_);

  const invalidation::InvalidatorState initial_state =
      invalidation_service_->GetInvalidatorState();
  is_service_connected_ =
      (initial_state == invalidation::InvalidatorState::kEnabled);

  // From now on, state transitions are forwarded to |parent_|.
  is_observer_ready_ = true;
}

AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver::
    ~InvalidationServiceObserver() {
  // Turn OnInvalidatorStateChange() into a no-op for the rest of destruction.
  // The observation itself is dropped when the ScopedObservation member is
  // destroyed.
  is_observer_ready_ = false;
}

invalidation::InvalidationService* AffiliatedInvalidationServiceProviderImpl::
    InvalidationServiceObserver::GetInvalidationService() {
  // Hands back the raw pointer of the observed (non-owned) service.
  return invalidation_service_.get();
}

bool AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver::
    IsServiceConnected() const {
  // Reflects the most recent state seen: the initial sample taken in the
  // constructor, or the last transition in OnInvalidatorStateChange().
  return is_service_connected_;
}

void AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver::
    OnInvalidatorStateChange(invalidation::InvalidatorState state) {
  // Notifications outside the constructed/not-yet-destroyed window are ignored.
  if (!is_observer_ready_) {
    return;
  }

  const bool connected = state == invalidation::InvalidatorState::kEnabled;
  // Only genuine connected <-> disconnected transitions are reported.
  if (connected == is_service_connected_) {
    return;
  }
  is_service_connected_ = connected;

  if (!connected) {
    parent_->OnInvalidationServiceDisconnected(invalidation_service_);
    return;
  }
  parent_->OnInvalidationServiceConnected(invalidation_service_);
}

void AffiliatedInvalidationServiceProviderImpl::InvalidationServiceObserver::
    OnIncomingInvalidation(const invalidation::Invalidation& invalidation) {
  // Intentionally empty: this observer only tracks connection state and does
  // not act on individual invalidations.
}

std::string AffiliatedInvalidationServiceProviderImpl::
    InvalidationServiceObserver::GetOwnerName() const {
  // Name reported for this handler through InvalidationHandler.
  return std::string("AffiliatedInvalidationService");
}

AffiliatedInvalidationServiceProviderImpl::
    AffiliatedInvalidationServiceProviderImpl()
    : current_invalidation_service_(nullptr),
      consumer_count_(0),
      is_shut_down_(false) {
  // No user Profile may be loaded yet at construction time. Presumably this is
  // so that OnUserProfileLoaded() is seen for every user profile.
  DCHECK(g_browser_process->profile_manager()->GetLoadedProfiles().empty());

  // Watch the session manager; OnUserProfileLoaded() is driven by it.
  auto* const manager = session_manager::SessionManager::Get();
  session_observation_.Observe(manager);
}

AffiliatedInvalidationServiceProviderImpl::
    ~AffiliatedInvalidationServiceProviderImpl() {
  // Shutdown() must have run first; it releases all observers and services.
  DCHECK(is_shut_down_);
}

// Invoked when a user's Profile has loaded. If the Profile supports
// invalidation and belongs to an affiliated user, starts tracking its
// invalidation service (until Shutdown()). If that service is already
// connected, it is offered to consumers via OnInvalidationServiceConnected().
void AffiliatedInvalidationServiceProviderImpl::OnUserProfileLoaded(
    const AccountId& account_id) {
  DCHECK(!is_shut_down_);
  Profile* profile =
      ash::ProfileHelper::Get()->GetProfileByAccountId(account_id);
  if (!profile) {
    // GetProfileByAccountId() can return null. Bail out rather than hand a
    // null Profile to the keyed-service factory and to GetUserByProfile().
    return;
  }

  invalidation::ProfileInvalidationProvider* invalidation_provider =
      GetInvalidationProvider(profile);
  if (!invalidation_provider) {
    // If the Profile does not support invalidation (e.g. guest, incognito),
    // ignore it.
    return;
  }
  const user_manager::User* user =
      ash::ProfileHelper::Get()->GetUserByProfile(profile);
  if (!user || !user->IsAffiliated()) {
    // If the Profile belongs to a user who is not affiliated on the device,
    // ignore it.
    return;
  }

  // This provider only supports the InvalidationService flavor; the
  // InvalidationListener setup is treated as a fatal misconfiguration.
  auto invalidation_service_or_listener =
      invalidation_provider->GetInvalidationServiceOrListener(
          policy::kPolicyFCMInvalidationSenderID,
          invalidation::InvalidationListener::kProjectNumberEnterprise);
  CHECK(std::holds_alternative<invalidation::InvalidationService*>(
      invalidation_service_or_listener))
      << "AffiliatedInvalidationServiceProviderImpl is created with "
         "InvalidationListener setup";
  auto* invalidation_service = std::get<invalidation::InvalidationService*>(
      invalidation_service_or_listener);

  // Create a state observer for the user's invalidation service.
  profile_invalidation_service_observers_.push_back(
      std::make_unique<InvalidationServiceObserver>(this,
                                                    invalidation_service));
  if (profile_invalidation_service_observers_.back()->IsServiceConnected()) {
    // If the invalidation service is connected, check whether to switch to it.
    OnInvalidationServiceConnected(invalidation_service);
  }
}

void AffiliatedInvalidationServiceProviderImpl::RegisterConsumer(
    Consumer* consumer) {
  // Registration is ignored after shutdown and for duplicate consumers.
  if (is_shut_down_) {
    return;
  }
  if (consumers_.HasObserver(consumer)) {
    return;
  }

  consumers_.AddObserver(consumer);
  ++consumer_count_;

  if (!current_invalidation_service_) {
    // Only the first consumer kicks off the search. Later consumers are
    // notified through SetCurrentInvalidationService() once a service becomes
    // available.
    if (consumer_count_ == 1) {
      FindConnectedInvalidationService();
    }
    return;
  }

  // A service is already being offered; hand it to the new consumer directly.
  consumer->OnInvalidationServiceSet(current_invalidation_service_);
}

void AffiliatedInvalidationServiceProviderImpl::UnregisterConsumer(
    Consumer* consumer) {
  // Unknown consumers are ignored.
  if (!consumers_.HasObserver(consumer)) {
    return;
  }

  consumers_.RemoveObserver(consumer);
  --consumer_count_;

  // When the last consumer leaves while a service is being offered, stop
  // offering it and release any device-global service.
  // NOTE(review): if the device-global service exists but has not connected
  // yet (no current service), it is kept alive here. Confirm this is intended.
  const bool no_consumers_left = consumer_count_ == 0;
  if (!no_consumers_left || !current_invalidation_service_) {
    return;
  }
  current_invalidation_service_ = nullptr;
  DestroyDeviceInvalidationService();
}

void AffiliatedInvalidationServiceProviderImpl::Shutdown() {
  is_shut_down_ = true;

  // Stop listening for new profiles and drop every state observer. No further
  // connect/disconnect callbacks reach this provider after this point.
  session_observation_.Reset();
  profile_invalidation_service_observers_.clear();
  device_invalidation_service_observer_ = nullptr;

  if (current_invalidation_service_) {
    // Consumers still hold a service that is going away. Clear it, then
    // broadcast a null service so they learn it is no longer available.
    current_invalidation_service_ = nullptr;
    SetCurrentInvalidationService(nullptr);
  }

  DestroyDeviceInvalidationService();
}

invalidation::InvalidationService*
AffiliatedInvalidationServiceProviderImpl::GetDeviceInvalidationServiceForTest()
    const {
  // Null whenever no device-global service currently exists.
  return device_invalidation_service_.get();
}

void AffiliatedInvalidationServiceProviderImpl::OnInvalidationServiceConnected(
    invalidation::InvalidationService* invalidation_service) {
  DCHECK(!is_shut_down_);

  // Nothing to do in two cases:
  //  - no consumer needs a service;
  //  - no device-global service exists, which implies a connected profile
  //    service is already being offered, and switching away from it would
  //    gain nothing.
  if (consumer_count_ == 0 || !device_invalidation_service_) {
    return;
  }

  // Offer the newly connected service to consumers.
  current_invalidation_service_ = nullptr;
  SetCurrentInvalidationService(invalidation_service);

  // If consumers now use something other than the device-global service, the
  // latter is no longer needed. Both pointers are re-checked, presumably
  // because consumer callbacks above may have changed provider state.
  const bool switched_off_device_service =
      current_invalidation_service_ && device_invalidation_service_ &&
      current_invalidation_service_ != device_invalidation_service_.get();
  if (switched_off_device_service) {
    DestroyDeviceInvalidationService();
  }
}

void AffiliatedInvalidationServiceProviderImpl::
    OnInvalidationServiceDisconnected(
        invalidation::InvalidationService* invalidation_service) {
  DCHECK(!is_shut_down_);

  // Disconnects of services that consumers are not using are irrelevant.
  if (invalidation_service != current_invalidation_service_) {
    return;
  }

  // The service consumers relied on went away: withdraw it and look for a
  // replacement.
  DCHECK(consumer_count_);
  current_invalidation_service_ = nullptr;
  FindConnectedInvalidationService();

  if (current_invalidation_service_) {
    return;
  }
  // No replacement is connected yet; tell consumers explicitly that they no
  // longer have a usable service.
  SetCurrentInvalidationService(nullptr);
}

void AffiliatedInvalidationServiceProviderImpl::
    FindConnectedInvalidationService() {
  DCHECK(!current_invalidation_service_);
  DCHECK(consumer_count_);
  DCHECK(!is_shut_down_);

  // First choice: any connected service of an affiliated, logged-in user.
  // Only such users' services are tracked in this list.
  for (const auto& candidate : profile_invalidation_service_observers_) {
    if (!candidate->IsServiceConnected()) {
      continue;
    }
    DestroyDeviceInvalidationService();
    SetCurrentInvalidationService(candidate->GetInvalidationService());
    return;
  }

  // Fallback: a device-global service, created on demand.
  if (!device_invalidation_service_) {
    device_invalidation_service_ = InitializeDeviceInvalidationService();
    device_invalidation_service_observer_ =
        std::make_unique<InvalidationServiceObserver>(
            this, device_invalidation_service_.get());
  }

  // A device-global service that is not connected yet is offered later, when
  // its observer reports the connection.
  if (!device_invalidation_service_observer_->IsServiceConnected()) {
    return;
  }
  OnInvalidationServiceConnected(device_invalidation_service_.get());
}

void AffiliatedInvalidationServiceProviderImpl::SetCurrentInvalidationService(
    invalidation::InvalidationService* invalidation_service) {
  // Callers must clear the previous service before installing a new one.
  DCHECK(!current_invalidation_service_);
  current_invalidation_service_ = invalidation_service;

  // Broadcast the (possibly null) service to every registered consumer.
  for (auto& consumer : consumers_) {
    consumer.OnInvalidationServiceSet(current_invalidation_service_);
  }
}

void AffiliatedInvalidationServiceProviderImpl::
    DestroyDeviceInvalidationService() {
  // The teardown order matters:
  //  1. stop observing the service;
  //  2. destroy the service;
  //  3. destroy the identity provider and Instance ID driver, which were
  //     handed to the service as raw pointers when it was created.
  device_invalidation_service_observer_ = nullptr;
  device_invalidation_service_ = nullptr;
  device_identity_provider_ = nullptr;
  device_instance_id_driver_ = nullptr;
}

std::unique_ptr<invalidation::InvalidationService>
AffiliatedInvalidationServiceProviderImpl::
    InitializeDeviceInvalidationService() {
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory;
  if (g_browser_process->system_network_context_manager()) {
    // system_network_context_manager() can be null during unit tests.
    url_loader_factory = g_browser_process->system_network_context_manager()
                             ->GetSharedURLLoaderFactory();
    DCHECK(url_loader_factory);
  }

  device_identity_provider_ = std::make_unique<DeviceIdentityProvider>(
      DeviceOAuth2TokenServiceFactory::Get());

  device_instance_id_driver_ = std::make_unique<instance_id::InstanceIDDriver>(
      g_browser_process->gcm_driver());

  DCHECK(device_instance_id_driver_);
  auto invalidation_service_or_listener =
      invalidation::CreateInvalidationServiceOrListener(
          device_identity_provider_.get(), g_browser_process->gcm_driver(),
          device_instance_id_driver_.get(), url_loader_factory,
          g_browser_process->local_state(), kPolicyFCMInvalidationSenderID,
          /*project_number=*/"", /*log_prefix=*/"");
  CHECK(std::holds_alternative<
        std::unique_ptr<invalidation::InvalidationService>>(
      invalidation_service_or_listener))
      << "AffiliatedInvalidationServiceProviderImpl is created with "
         "InvalidationListener setup";
  return std::move(std::get<std::unique_ptr<invalidation::InvalidationService>>(
      invalidation_service_or_listener));
}

}  // namespace policy