chromium/chrome/browser/certificate_provider/certificate_provider_service.cc

// Copyright 2015 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/certificate_provider/certificate_provider_service.h"

#include <stddef.h>

#include <string>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_runner.h"
#include "build/chromeos_buildflags.h"
#include "chrome/browser/certificate_provider/certificate_provider.h"
#include "extensions/common/extension_id.h"
#include "net/base/net_errors.h"

#if BUILDFLAG(IS_CHROMEOS_LACROS)
#include "chromeos/crosapi/mojom/cert_database.mojom.h"
#include "chromeos/lacros/lacros_service.h"
#endif  // IS_CHROMEOS_LACROS

namespace chromeos {

namespace {

// Delivers a signing result asynchronously: binds |error| and a copy of
// |signature| to |callback| and posts the resulting task to the current
// default sequenced task runner instead of running |callback| inline.
void PostSignResult(net::SSLPrivateKey::SignCallback callback,
                    net::Error error,
                    const std::vector<uint8_t>& signature) {
  auto reply_task = base::BindOnce(std::move(callback), error, signature);
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(reply_task));
}

// Delivers a list of client certificate identities asynchronously: moves
// |certs| into a task bound to |callback| and posts it to the current default
// sequenced task runner instead of running |callback| inline.
void PostIdentities(
    base::OnceCallback<void(net::ClientCertIdentityList)> callback,
    net::ClientCertIdentityList certs) {
  auto reply_task = base::BindOnce(std::move(callback), std::move(certs));
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(reply_task));
}

}  // namespace

// CertificateProvider implementation returned by
// CertificateProviderService::CreateCertificateProvider(). It only holds a
// weak pointer to the service, so it tolerates the service being destroyed
// first; in that case GetCertificates() reports an empty list.
class CertificateProviderService::CertificateProviderImpl
    : public CertificateProvider {
 public:
  // This provider must be used on the same thread as the
  // CertificateProviderService.
  explicit CertificateProviderImpl(
      const base::WeakPtr<CertificateProviderService>& service);

  CertificateProviderImpl(const CertificateProviderImpl&) = delete;
  CertificateProviderImpl& operator=(const CertificateProviderImpl&) = delete;

  ~CertificateProviderImpl() override;

  // CertificateProvider:
  // Forwards the request to the service (if it is still alive). |callback| is
  // wrapped so that it is posted rather than run reentrantly.
  void GetCertificates(
      base::OnceCallback<void(net::ClientCertIdentityList)> callback) override;

 private:
  // Weak reference to the owning service; checked for null before every use.
  const base::WeakPtr<CertificateProviderService> service_;
  SEQUENCE_CHECKER(sequence_checker_);
};

// Implements an SSLPrivateKey backed by the signing function exposed by an
// extension through the certificateProvider API.
// Objects of this class must be used on the CertificateProviderService's
// sequence.
class CertificateProviderService::SSLPrivateKey : public net::SSLPrivateKey {
 public:
  // |extension_id| identifies the extension that performs the signing,
  // |cert_info| describes the certificate (including its supported
  // algorithms), and |service| is the weak reference used to dispatch sign
  // requests.
  SSLPrivateKey(const std::string& extension_id,
                const CertificateInfo& cert_info,
                const base::WeakPtr<CertificateProviderService>& service);

  SSLPrivateKey(const SSLPrivateKey&) = delete;
  SSLPrivateKey& operator=(const SSLPrivateKey&) = delete;

  // net::SSLPrivateKey:
  std::string GetProviderName() override;
  std::vector<uint16_t> GetAlgorithmPreferences() override;
  // Forwards the request to the extension via the service. |callback| is
  // wrapped so that it is posted rather than run reentrantly; it receives
  // net::ERR_FAILED if the service is already gone.
  void Sign(uint16_t algorithm,
            base::span<const uint8_t> input,
            SignCallback callback) override;

 private:
  // Private destructor: instances are created through base::MakeRefCounted and
  // are not meant to be deleted directly.
  ~SSLPrivateKey() override;

  const extensions::ExtensionId extension_id_;
  const CertificateInfo cert_info_;
  // Weak reference to the service; checked for null before use in Sign().
  const base::WeakPtr<CertificateProviderService> service_;
  SEQUENCE_CHECKER(sequence_checker_);
};

// net::ClientCertIdentity for a certificate provided by an extension. The
// private key is not stored here; AcquirePrivateKey() looks the certificate up
// in the service's certificate map at call time.
class CertificateProviderService::ClientCertIdentity
    : public net::ClientCertIdentity {
 public:
  // |cert| is the certificate this identity represents; |service| is a weak
  // reference to the service that is consulted when the key is requested.
  ClientCertIdentity(scoped_refptr<net::X509Certificate> cert,
                     base::WeakPtr<CertificateProviderService> service)
      : net::ClientCertIdentity(std::move(cert)), service_(service) {}

  ClientCertIdentity(const ClientCertIdentity&) = delete;
  ClientCertIdentity& operator=(const ClientCertIdentity&) = delete;

  ~ClientCertIdentity() override;

  // net::ClientCertIdentity:
  // Runs |private_key_callback| with nullptr if the service is gone or the
  // certificate is no longer provided; otherwise with a new SSLPrivateKey.
  void AcquirePrivateKey(
      base::OnceCallback<void(scoped_refptr<net::SSLPrivateKey>)>
          private_key_callback) override;

 private:
  SEQUENCE_CHECKER(sequence_checker_);
  // Weak reference to the service; checked for null in AcquirePrivateKey().
  const base::WeakPtr<CertificateProviderService> service_;
};

// Only asserts (in DCHECK builds) that destruction happens on the sequence
// tracked by |sequence_checker_|.
CertificateProviderService::ClientCertIdentity::~ClientCertIdentity() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Resolves the private key for this identity's certificate and hands it to
// |private_key_callback|, which is run synchronously. The callback receives
// nullptr when the service has been destroyed or when the certificate map no
// longer reports the certificate as currently provided; otherwise it receives
// a new SSLPrivateKey bound to the providing extension.
void CertificateProviderService::ClientCertIdentity::AcquirePrivateKey(
    base::OnceCallback<void(scoped_refptr<net::SSLPrivateKey>)>
        private_key_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Stays null unless the lookup below succeeds.
  scoped_refptr<net::SSLPrivateKey> private_key;

  if (service_) {
    bool still_provided = false;
    CertificateInfo cert_info;
    std::string provider_extension_id;
    // TODO(mattm): can the ClientCertIdentity store a handle directly to the
    // extension instead of having to go through service_->certificate_map_ ?
    service_->certificate_map_.LookUpCertificate(
        *certificate(), &still_provided, &cert_info, &provider_extension_id);
    if (still_provided) {
      private_key = base::MakeRefCounted<SSLPrivateKey>(provider_extension_id,
                                                        cert_info, service_);
    }
  }

  std::move(private_key_callback).Run(std::move(private_key));
}

// Stores a copy of the weak pointer to |service|; no other work is done.
CertificateProviderService::CertificateProviderImpl::CertificateProviderImpl(
    const base::WeakPtr<CertificateProviderService>& service)
    : service_(service) {}

// Only asserts (in DCHECK builds) that destruction happens on the sequence
// tracked by |sequence_checker_|.
CertificateProviderService::CertificateProviderImpl::
    ~CertificateProviderImpl() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Requests the certificates provided by extensions. The caller's |callback| is
// first wrapped in PostIdentities() so that, whichever path is taken below, it
// is posted as a task rather than run reentrantly. If the service has been
// destroyed, an empty identity list is reported.
void CertificateProviderService::CertificateProviderImpl::GetCertificates(
    base::OnceCallback<void(net::ClientCertIdentityList)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Avoid running the callback reentrantly.
  base::OnceCallback<void(net::ClientCertIdentityList)> posted_callback =
      base::BindOnce(&PostIdentities, std::move(callback));

  if (service_) {
    service_->GetCertificatesFromExtensions(std::move(posted_callback));
    return;
  }

  // The service is gone: there is nothing to query.
  std::move(posted_callback).Run(net::ClientCertIdentityList());
}

// Copies the extension id, the certificate info and the weak service pointer
// into the corresponding const members; no other work is done.
CertificateProviderService::SSLPrivateKey::SSLPrivateKey(
    const std::string& extension_id,
    const CertificateInfo& cert_info,
    const base::WeakPtr<CertificateProviderService>& service)
    : extension_id_(extension_id), cert_info_(cert_info), service_(service) {}

// Returns a human-readable provider name of the form
// `extension "<extension id>"`.
std::string CertificateProviderService::SSLPrivateKey::GetProviderName() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  std::string provider_name = "extension \"";
  provider_name += extension_id_;
  provider_name += "\"";
  return provider_name;
}

// Returns a copy of the supported algorithms recorded in |cert_info_|.
std::vector<uint16_t>
CertificateProviderService::SSLPrivateKey::GetAlgorithmPreferences() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return cert_info_.supported_algorithms;
}

void CertificateProviderService::SSLPrivateKey::Sign(
    uint16_t algorithm,
    base::span<const uint8_t> input,
    SignCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // The Sign() method should not call its callback reentrantly, so wrap it in
  // PostSignResult().
  callback = base::BindOnce(&PostSignResult, std::move(callback));

  if (!service_) {
    std::move(callback).Run(net::ERR_FAILED, /*signature=*/{});
    return;
  }

  service_->RequestSignatureFromExtension(
      extension_id_, cert_info_.certificate, algorithm, input,
      /*authenticating_user_account_id=*/{}, std::move(callback));
}

// Only asserts (in DCHECK builds) that destruction happens on the sequence
// tracked by |sequence_checker_|.
CertificateProviderService::SSLPrivateKey::~SSLPrivateKey() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Nothing to set up here beyond default member initialization.
CertificateProviderService::CertificateProviderService() = default;

// Only asserts (in DCHECK builds) that destruction happens on the sequence
// tracked by |sequence_checker_|.
CertificateProviderService::~CertificateProviderService() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Installs the delegate used to talk to extensions. DCHECKs that no delegate
// has been set yet and that |delegate| is non-null, then takes ownership.
void CertificateProviderService::SetDelegate(
    std::unique_ptr<Delegate> delegate) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!delegate_);
  DCHECK(delegate);

  delegate_ = std::move(delegate);
}

// Adds |observer| to |observers_|.
void CertificateProviderService::AddObserver(Observer* observer) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  observers_.AddObserver(observer);
}

// Removes |observer| from |observers_|.
void CertificateProviderService::RemoveObserver(Observer* observer) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  observers_.RemoveObserver(observer);
}

// Records |certificate_infos| as the certificates provided by |extension_id|
// and notifies all observers via OnCertificatesUpdated(). In Lacros builds the
// same information is first forwarded to Ash-Chrome through the CertDatabase
// crosapi interface, provided that interface is available and recent enough
// to support SetCertsProvidedByExtension.
void CertificateProviderService::SetCertificatesProvidedByExtension(
    const std::string& extension_id,
    const CertificateInfoList& certificate_infos) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

#if BUILDFLAG(IS_CHROMEOS_LACROS)
  // Synchronize with Ash-Chrome
  using CertDatabase = crosapi::mojom::CertDatabase;
  chromeos::LacrosService* const lacros = chromeos::LacrosService::Get();
  const bool can_forward_to_ash =
      lacros && lacros->IsAvailable<CertDatabase>() &&
      lacros->GetInterfaceVersion<CertDatabase>() >=
          static_cast<int>(CertDatabase::MethodMinVersions::
                               kSetCertsProvidedByExtensionMinVersion);
  if (can_forward_to_ash) {
    lacros->GetRemote<CertDatabase>()->SetCertsProvidedByExtension(
        extension_id, certificate_infos);
  }
#endif  // BUILDFLAG(IS_CHROMEOS_LACROS)

  certificate_map_.UpdateCertificatesForExtension(extension_id,
                                                  certificate_infos);

  for (auto& observer : observers_) {
    observer.OnCertificatesUpdated(extension_id, certificate_infos);
  }
}

bool CertificateProviderService::SetExtensionCertificateReplyReceived(
    const std::string& extension_id,
    int cert_request_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  bool completed = false;
  if (!certificate_requests_.SetExtensionReplyReceived(
          extension_id, cert_request_id, &completed)) {
    DLOG(WARNING) << "Unexpected reply of extension " << extension_id
                  << " to request " << cert_request_id;
    return false;
  }
  if (completed) {
    base::OnceCallback<void(net::ClientCertIdentityList)> callback;
    certificate_requests_.RemoveRequest(cert_request_id, &callback);
    CollectCertificatesAndRun(std::move(callback));
  }
  return true;
}

bool CertificateProviderService::ReplyToSignRequest(
    const std::string& extension_id,
    int sign_request_id,
    const std::vector<uint8_t>& signature) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // TODO(crbug.com/40671053): Remove logging after stabilizing the feature.
  LOG(WARNING) << "Extension " << extension_id
               << " replied to signature request " << sign_request_id
               << ", size " << signature.size();

  scoped_refptr<net::X509Certificate> certificate;
  net::SSLPrivateKey::SignCallback callback;
  if (!sign_requests_.RemoveRequest(extension_id, sign_request_id, &certificate,
                                    &callback)) {
    return false;
  }
  pin_dialog_manager_.RemoveSignRequest(extension_id, sign_request_id);

  const net::Error error_code = signature.empty() ? net::ERR_FAILED : net::OK;
  std::move(callback).Run(error_code, signature);

  if (!signature.empty()) {
    for (auto& observer : observers_)
      observer.OnSignCompleted(certificate, extension_id);
  }
  return true;
}

// Looks |cert| up in |certificate_map_| and forwards its result. |has_extension|
// and |extension_id| are out-parameters filled in by the map; the
// CertificateInfo that the map also produces is discarded.
bool CertificateProviderService::LookUpCertificate(
    const net::X509Certificate& cert,
    bool* has_extension,
    std::string* extension_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Required by the map's interface but not needed by callers of this method.
  CertificateInfo unused_info;
  return certificate_map_.LookUpCertificate(cert, has_extension, &unused_info,
                                            extension_id);
}

// Returns a new CertificateProviderImpl that refers to this service through a
// weak pointer.
std::unique_ptr<CertificateProvider>
CertificateProviderService::CreateCertificateProvider() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  return std::make_unique<CertificateProviderImpl>(weak_factory_.GetWeakPtr());
}

// Removes |extension_id| from |certificate_map_|. Unlike
// OnExtensionUnloaded(), pending certificate and sign requests are left
// untouched.
void CertificateProviderService::OnExtensionUnregistered(
    const std::string& extension_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  certificate_map_.RemoveExtension(extension_id);
}

void CertificateProviderService::OnExtensionUnloaded(
    const std::string& extension_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  certificate_map_.RemoveExtension(extension_id);

  for (const int completed_cert_request_id :
       certificate_requests_.DropExtension(extension_id)) {
    base::OnceCallback<void(net::ClientCertIdentityList)> callback;
    certificate_requests_.RemoveRequest(completed_cert_request_id, &callback);
    CollectCertificatesAndRun(std::move(callback));
  }

  for (auto& callback : sign_requests_.RemoveAllRequests(extension_id))
    std::move(callback).Run(net::ERR_FAILED, std::vector<uint8_t>());

  pin_dialog_manager_.ExtensionUnloaded(extension_id);
}

void CertificateProviderService::RequestSignatureBySpki(
    const std::string& subject_public_key_info,
    uint16_t algorithm,
    base::span<const uint8_t> input,
    const std::optional<AccountId>& authenticating_user_account_id,
    net::SSLPrivateKey::SignCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  bool is_currently_provided = false;
  CertificateInfo info;
  std::string extension_id;
  certificate_map_.LookUpCertificateBySpki(
      subject_public_key_info, &is_currently_provided, &info, &extension_id);
  if (!is_currently_provided) {
    LOG(ERROR) << "no certificate with the specified spki was found";
    std::move(callback).Run(net::ERR_FAILED, std::vector<uint8_t>());
    return;
  }

  RequestSignatureFromExtension(extension_id, info.certificate, algorithm,
                                input, authenticating_user_account_id,
                                std::move(callback));
}

bool CertificateProviderService::LookUpSpki(
    const std::string& subject_public_key_info,
    std::vector<uint16_t>* supported_algorithms,
    std::string* extension_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  bool is_currently_provided = false;
  CertificateInfo info;
  certificate_map_.LookUpCertificateBySpki(
      subject_public_key_info, &is_currently_provided, &info, extension_id);
  if (!is_currently_provided) {
    LOG(ERROR) << "no certificate with the specified spki was found";
    return false;
  }
  *supported_algorithms = info.supported_algorithms;
  return true;
}

void CertificateProviderService::AbortSignatureRequestsForAuthenticatingUser(
    const AccountId& authenticating_user_account_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  using ExtensionNameRequestIdPair =
      certificate_provider::SignRequests::ExtensionNameRequestIdPair;

  const std::vector<ExtensionNameRequestIdPair> sign_requests_to_abort =
      sign_requests_.FindRequestsForAuthenticatingUser(
          authenticating_user_account_id);

  for (const ExtensionNameRequestIdPair& sign_request :
       sign_requests_to_abort) {
    const std::string& extension_id = sign_request.first;
    const int sign_request_id = sign_request.second;

    // TODO(crbug.com/40671053): Remove logging after stabilizing the feature.
    LOG(WARNING) << "Aborting user login signature request from extension "
                 << extension_id << " id " << sign_request_id;

    pin_dialog_manager_.RemoveSignRequest(extension_id, sign_request_id);

    scoped_refptr<net::X509Certificate> certificate;
    net::SSLPrivateKey::SignCallback sign_callback;
    if (sign_requests_.RemoveRequest(extension_id, sign_request_id,
                                     &certificate, &sign_callback)) {
      std::move(sign_callback).Run(net::ERR_FAILED, std::vector<uint8_t>());
    }
  }
}

// Starts a certificate request to all extensions reported by the delegate as
// certificate providers. If there are none, |callback| is run right away with
// the certificates already known. Otherwise a new request is registered in
// |certificate_requests_| (together with a callback that terminates the
// request) and broadcast to the extensions through the delegate.
void CertificateProviderService::GetCertificatesFromExtensions(
    base::OnceCallback<void(net::ClientCertIdentityList)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  const std::vector<std::string> provider_extensions(
      delegate_->CertificateProviderExtensions());

  if (provider_extensions.empty()) {
    DVLOG(2) << "No provider extensions left.";
    // Note that there could still be unfinished requests to extensions that
    // were previously registered.
    CollectCertificatesAndRun(std::move(callback));
    return;
  }

  // NOTE(review): base::Unretained(this) is presumably safe because the
  // callback is stored inside |certificate_requests_|, a member of this
  // object - confirm against CertificateRequests' lifetime guarantees.
  auto terminate_request =
      base::BindOnce(&CertificateProviderService::TerminateCertificateRequest,
                     base::Unretained(this));
  const int cert_request_id = certificate_requests_.AddRequest(
      provider_extensions, std::move(callback), std::move(terminate_request));

  DVLOG(2) << "Start certificate request " << cert_request_id;
  delegate_->BroadcastCertificateRequest(cert_request_id);
}

// Wraps every certificate currently held by |certificate_map_| in a
// ClientCertIdentity (each holding a weak pointer to this service) and runs
// |callback| synchronously with the resulting list.
void CertificateProviderService::CollectCertificatesAndRun(
    base::OnceCallback<void(net::ClientCertIdentityList)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  const std::vector<scoped_refptr<net::X509Certificate>> known_certificates =
      certificate_map_.GetCertificates();
  const base::WeakPtr<CertificateProviderService> weak_this =
      weak_factory_.GetWeakPtr();

  net::ClientCertIdentityList identities;
  for (const auto& known_certificate : known_certificates) {
    identities.push_back(
        std::make_unique<ClientCertIdentity>(known_certificate, weak_this));
  }

  std::move(callback).Run(std::move(identities));
}

void CertificateProviderService::TerminateCertificateRequest(
    int cert_request_id) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  base::OnceCallback<void(net::ClientCertIdentityList)> callback;
  if (!certificate_requests_.RemoveRequest(cert_request_id, &callback)) {
    DLOG(WARNING) << "Request id " << cert_request_id << " unknown.";
    return;
  }

  DVLOG(1) << "Time out certificate request " << cert_request_id;
  CollectCertificatesAndRun(std::move(callback));
}

void CertificateProviderService::RequestSignatureFromExtension(
    const std::string& extension_id,
    const scoped_refptr<net::X509Certificate>& certificate,
    uint16_t algorithm,
    base::span<const uint8_t> input,
    const std::optional<AccountId>& authenticating_user_account_id,
    net::SSLPrivateKey::SignCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  const int sign_request_id = sign_requests_.AddRequest(
      extension_id, certificate, authenticating_user_account_id,
      std::move(callback));

  // TODO(crbug.com/40671053): Remove logging after stabilizing the feature.
  LOG(WARNING) << "Starting signature request to extension " << extension_id
               << " id " << sign_request_id;

  pin_dialog_manager_.AddSignRequestId(extension_id, sign_request_id,
                                       authenticating_user_account_id);
  if (!delegate_->DispatchSignRequestToExtension(
          extension_id, sign_request_id, algorithm, certificate, input)) {
    // TODO(crbug.com/40671053): Remove logging after stabilizing the feature.
    LOG(WARNING) << "Failed to dispatch signature request to extension "
                 << extension_id << " id " << sign_request_id;
    scoped_refptr<net::X509Certificate> local_certificate;
    sign_requests_.RemoveRequest(extension_id, sign_request_id,
                                 &local_certificate, &callback);
    pin_dialog_manager_.RemoveSignRequest(extension_id, sign_request_id);
    std::move(callback).Run(net::ERR_FAILED, std::vector<uint8_t>());
  }
}

}  // namespace chromeos