chromium/chrome/browser/ash/net/client_cert_store_kcer.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/net/client_cert_store_kcer.h"

#include <algorithm>
#include <iterator>
#include <utility>

#include "ash/constants/ash_features.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/strings/stringprintf.h"
#include "base/task/thread_pool.h"
#include "base/threading/scoped_blocking_call.h"
#include "chrome/browser/ash/net/client_cert_filter.h"
#include "chrome/browser/certificate_provider/certificate_provider.h"
#include "chromeos/components/kcer/client_cert_identity_kcer.h"
#include "chromeos/components/kcer/kcer.h"
#include "net/ssl/client_cert_store_nss.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_private_key.h"

namespace ash {
namespace {
// Adapter that lets the NSS store's filtering routine be used as a
// value-returning task: the list is handed to
// net::ClientCertStoreNSS::FilterCertsOnWorkerThread() by pointer together
// with the dereferenced `request`, and is then returned to the caller.
net::ClientCertIdentityList FilterCertsOnWorkerThread(
    scoped_refptr<const net::SSLCertRequestInfo> request,
    net::ClientCertIdentityList client_certs) {
  const net::SSLCertRequestInfo& request_info = *request;
  net::ClientCertStoreNSS::FilterCertsOnWorkerThread(&client_certs,
                                                     request_info);
  return client_certs;
}
}  // namespace

//==============================================================================

ClientCertStoreKcer::ClientCertStoreKcer(
    std::unique_ptr<chromeos::CertificateProvider> cert_provider,
    base::WeakPtr<kcer::Kcer> kcer)
    : cert_provider_(std::move(cert_provider)), kcer_(std::move(kcer)) {}

// No custom teardown is needed; members clean up after themselves.
ClientCertStoreKcer::~ClientCertStoreKcer() = default;

// Entry point of the certificate lookup chain. Certificates from the
// certificate provider (if one was supplied) are requested first; whatever
// comes back is forwarded to GetKcerCerts() as the initial list.
void ClientCertStoreKcer::GetClientCerts(
    scoped_refptr<const net::SSLCertRequestInfo> cert_request_info,
    ClientCertListCallback callback) {
  // Next step of the chain, bound to a weak pointer so that it is dropped if
  // this store is destroyed before the step runs.
  auto continue_with_kcer = base::BindOnce(
      &ClientCertStoreKcer::GetKcerCerts, weak_factory_.GetWeakPtr(),
      std::move(cert_request_info), std::move(callback));

  if (!cert_provider_) {
    // No provider: proceed right away with an empty list.
    std::move(continue_with_kcer).Run(net::ClientCertIdentityList());
    return;
  }

  cert_provider_->GetCertificates(std::move(continue_with_kcer));
}

// Second step: `additional_certs` holds the certificates collected so far.
// When Kcer is still alive, asks it which tokens are available and continues
// in GotKcerTokens(); otherwise finishes with the certificates already
// gathered.
void ClientCertStoreKcer::GetKcerCerts(
    scoped_refptr<const net::SSLCertRequestInfo> request,
    ClientCertListCallback callback,
    net::ClientCertIdentityList additional_certs) {
  if (kcer_) {
    // Fetch all tokens that are available in the current context.
    kcer_->GetAvailableTokens(base::BindOnce(
        &ClientCertStoreKcer::GotKcerTokens, weak_factory_.GetWeakPtr(),
        std::move(request), std::move(callback), std::move(additional_certs)));
    return;
  }

  // Kcer is gone; skip straight to the final filtering step.
  GotAllCerts(std::move(request), std::move(callback),
              std::move(additional_certs));
}

// Third step: receives the set of available `tokens` and asks Kcer to list
// the certificates stored on them, continuing in GotKcerCerts(). If Kcer was
// destroyed in the meantime, finishes with `additional_certs` only.
void ClientCertStoreKcer::GotKcerTokens(
    scoped_refptr<const net::SSLCertRequestInfo> request,
    ClientCertListCallback callback,
    net::ClientCertIdentityList additional_certs,
    base::flat_set<kcer::Token> tokens) {
  if (kcer_) {
    auto on_certs_listed = base::BindOnce(
        &ClientCertStoreKcer::GotKcerCerts, weak_factory_.GetWeakPtr(),
        std::move(request), std::move(callback), std::move(additional_certs));
    kcer_->ListCerts(std::move(tokens), std::move(on_certs_listed));
    return;
  }

  // Kcer is gone; skip straight to the final filtering step.
  GotAllCerts(std::move(request), std::move(callback),
              std::move(additional_certs));
}

// Fourth step: receives the certificates listed by Kcer together with the
// per-token errors. Errors are logged and otherwise ignored; every usable
// certificate is wrapped in a kcer::ClientCertIdentityKcer and appended to
// `additional_certs`, which is then passed on to GotAllCerts(). If Kcer was
// destroyed in the meantime, `kcer_certs` is dropped and only
// `additional_certs` is passed on.
void ClientCertStoreKcer::GotKcerCerts(
    scoped_refptr<const net::SSLCertRequestInfo> request,
    ClientCertListCallback callback,
    net::ClientCertIdentityList additional_certs,
    std::vector<scoped_refptr<const kcer::Cert>> kcer_certs,
    base::flat_map<kcer::Token, kcer::Error> kcer_errors) {
  if (!kcer_) {
    return GotAllCerts(std::move(request), std::move(callback),
                       std::move(additional_certs));
  }

  // Log errors, if any, and continue as usual with the remaining certs.
  // The values are streamed directly instead of being formatted with "%d",
  // which did not match the unsigned type of the arguments.
  for (const auto& [token, error] : kcer_errors) {
    LOG(ERROR) << "Failed to get certs from token:"
               << static_cast<uint32_t>(token)
               << ", error:" << static_cast<uint32_t>(error);
  }

  additional_certs.reserve(additional_certs.size() + kcer_certs.size());
  for (scoped_refptr<const kcer::Cert>& cert : kcer_certs) {
    if (!cert || !cert->GetX509Cert()) {
      // Probably shouldn't happen, but double check just in case.
      continue;
    }

    // `cert` is not used again after this point, so it can be moved from.
    additional_certs.push_back(
        std::make_unique<kcer::ClientCertIdentityKcer>(kcer_, std::move(cert)));
  }

  GotAllCerts(std::move(request), std::move(callback),
              std::move(additional_certs));
}

// Final gathering step: posts FilterCertsOnWorkerThread() with `request` and
// `certs` to the thread pool and routes its result to ReturnClientCerts().
// The reply is bound to a weak pointer, so it is dropped if this store has
// been destroyed by the time the task finishes.
void ClientCertStoreKcer::GotAllCerts(
    scoped_refptr<const net::SSLCertRequestInfo> request,
    ClientCertListCallback callback,
    net::ClientCertIdentityList certs) {
  auto filter_task = base::BindOnce(&FilterCertsOnWorkerThread,
                                    std::move(request), std::move(certs));
  auto deliver_reply =
      base::BindOnce(&ClientCertStoreKcer::ReturnClientCerts,
                     weak_factory_.GetWeakPtr(), std::move(callback));

  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE,
      {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
      std::move(filter_task), std::move(deliver_reply));
}

// Last link of the chain: hands `identities` to the caller-supplied
// `callback`. GotAllCerts() binds this method as the reply to the
// thread-pool filtering task.
void ClientCertStoreKcer::ReturnClientCerts(
    ClientCertListCallback callback,
    net::ClientCertIdentityList identities) {
  // `callback` is consumed by running it.
  std::move(callback).Run(std::move(identities));
}

}  // namespace ash