chromium/chrome/browser/certificate_provider/test_certificate_provider_extension.cc

// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/certificate_provider/test_certificate_provider_extension.h"

#include <cstdint>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

#include "base/containers/span.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/path_service.h"
#include "base/strings/string_util.h"
#include "base/threading/thread_restrictions.h"
#include "base/values.h"
#include "chrome/common/chrome_paths.h"
#include "chrome/common/extensions/api/certificate_provider.h"
#include "content/public/browser/browser_context.h"
#include "crypto/rsa_private_key.h"
#include "extensions/browser/api/test/test_api.h"
#include "extensions/browser/event_router.h"
#include "extensions/common/api/test.h"
#include "net/cert/asn1_util.h"
#include "net/cert/x509_certificate.h"
#include "net/cert/x509_util.h"
#include "net/test/cert_test_util.h"
#include "net/test/test_data_directory.h"
#include "third_party/boringssl/src/include/openssl/rsa.h"
#include "third_party/boringssl/src/include/openssl/ssl.h"

namespace ash {

namespace {

constexpr char kExtensionId[] = "ecmhnokcdiianioonpgakiooenfnonid";
// Paths relative to |chrome::DIR_TEST_DATA|:
constexpr base::FilePath::CharType kExtensionPath[] =
    FILE_PATH_LITERAL("extensions/test_certificate_provider/extension/");
constexpr base::FilePath::CharType kExtensionPemPath[] =
    FILE_PATH_LITERAL("extensions/test_certificate_provider/extension.pem");

// List of algorithms that the extension claims to support for the returned
// certificates.
constexpr extensions::api::certificate_provider::Algorithm
    kSupportedAlgorithms[] = {
        extensions::api::certificate_provider::Algorithm::
            kRsassaPkcs1V1_5Sha256,
        extensions::api::certificate_provider::Algorithm::kRsassaPkcs1V1_5Sha1};

base::Value ConvertBytesToValue(base::span<const uint8_t> bytes) {
  base::Value::List value;
  for (auto byte : bytes)
    value.Append(byte);
  return base::Value(std::move(value));
}

std::vector<uint8_t> ExtractBytesFromValue(const base::Value& value) {
  std::vector<uint8_t> bytes;
  for (const base::Value& item_value : value.GetList())
    bytes.push_back(base::checked_cast<uint8_t>(item_value.GetInt()));
  return bytes;
}

base::span<const uint8_t> GetCertDer(const net::X509Certificate& certificate) {
  return base::as_bytes(base::make_span(
      net::x509_util::CryptoBufferAsStringPiece(certificate.cert_buffer())));
}

base::Value MakeClientCertificateInfoValue(
    const net::X509Certificate& certificate) {
  base::Value::Dict cert_info_value;
  base::Value::List certificate_chain;
  certificate_chain.Append(ConvertBytesToValue(GetCertDer(certificate)));
  cert_info_value.Set("certificateChain", std::move(certificate_chain));
  base::Value::List supported_algorithms_value;
  for (auto supported_algorithm : kSupportedAlgorithms) {
    supported_algorithms_value.Append(
        extensions::api::certificate_provider::ToString(supported_algorithm));
  }
  cert_info_value.Set("supportedAlgorithms",
                      std::move(supported_algorithms_value));
  return base::Value(std::move(cert_info_value));
}

std::string ConvertValueToJson(const base::Value& value) {
  std::string json;
  CHECK(base::JSONWriter::Write(value, &json));
  return json;
}

base::Value ParseJsonToValue(const std::string& json) {
  std::optional<base::Value> value = base::JSONReader::Read(json);
  CHECK(value);
  return std::move(*value);
}

bool RsaSignRawData(crypto::RSAPrivateKey* key,
                    uint16_t openssl_signature_algorithm,
                    const std::vector<uint8_t>& input,
                    std::vector<uint8_t>* signature) {
  const EVP_MD* const digest_algorithm =
      SSL_get_signature_algorithm_digest(openssl_signature_algorithm);
  bssl::ScopedEVP_MD_CTX ctx;
  EVP_PKEY_CTX* pkey_ctx = nullptr;
  if (!EVP_DigestSignInit(ctx.get(), &pkey_ctx, digest_algorithm,
                          /*ENGINE* e=*/nullptr, key->key()))
    return false;
  if (SSL_is_signature_algorithm_rsa_pss(openssl_signature_algorithm)) {
    // For RSA-PSS, configure the special padding and set the salt length to be
    // equal to the hash size.
    if (!EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING) ||
        !EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, /*salt_len=*/-1)) {
      return false;
    }
  }
  size_t sig_len = 0;
  // Determine the signature length for the buffer.
  if (!EVP_DigestSign(ctx.get(), /*out_sig=*/nullptr, &sig_len, input.data(),
                      input.size()))
    return false;
  signature->resize(sig_len);
  return EVP_DigestSign(ctx.get(), signature->data(), &sig_len, input.data(),
                        input.size()) != 0;
}

void SendReplyToJs(ExtensionTestMessageListener* message_listener,
                   const base::Value& response) {
  message_listener->Reply(ConvertValueToJson(response));
  message_listener->Reset();
}

std::unique_ptr<crypto::RSAPrivateKey> LoadPrivateKeyFromFile(
    const base::FilePath& path) {
  std::string key_pk8;
  {
    base::ScopedAllowBlockingForTesting allow_io;
    EXPECT_TRUE(base::ReadFileToString(path, &key_pk8));
  }
  return crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(
      base::as_bytes(base::make_span(key_pk8)));
}

}  // namespace

// static
extensions::ExtensionId TestCertificateProviderExtension::extension_id() {
  return kExtensionId;
}

// static
base::FilePath TestCertificateProviderExtension::GetExtensionSourcePath() {
  return base::PathService::CheckedGet(chrome::DIR_TEST_DATA)
      .Append(kExtensionPath);
}

// static
base::FilePath TestCertificateProviderExtension::GetExtensionPemPath() {
  return base::PathService::CheckedGet(chrome::DIR_TEST_DATA)
      .Append(kExtensionPemPath);
}

// static
scoped_refptr<net::X509Certificate>
TestCertificateProviderExtension::GetCertificate() {
  return net::ImportCertFromFile(net::GetTestCertsDirectory(), "client_1.pem");
}

// static
std::string TestCertificateProviderExtension::GetCertificateSpki() {
  const scoped_refptr<net::X509Certificate> certificate = GetCertificate();
  std::string_view spki_bytes;
  if (!net::asn1::ExtractSPKIFromDERCert(
          net::x509_util::CryptoBufferAsStringPiece(certificate->cert_buffer()),
          &spki_bytes)) {
    return {};
  }
  return std::string(spki_bytes);
}

TestCertificateProviderExtension::TestCertificateProviderExtension(
    content::BrowserContext* browser_context)
    : browser_context_(browser_context),
      certificate_(GetCertificate()),
      private_key_(LoadPrivateKeyFromFile(net::GetTestCertsDirectory().Append(
          FILE_PATH_LITERAL("client_1.pk8")))),
      message_listener_(ReplyBehavior::kWillReply) {
  DCHECK(browser_context_);
  CHECK(certificate_);
  CHECK(private_key_);
  // Ignore messages targeted to other extensions or browser contexts.
  message_listener_.set_extension_id(kExtensionId);
  message_listener_.set_browser_context(browser_context);
  message_listener_.SetOnRepeatedlySatisfied(
      base::BindRepeating(&TestCertificateProviderExtension::HandleMessage,
                          base::Unretained(this)));
}

TestCertificateProviderExtension::~TestCertificateProviderExtension() = default;

void TestCertificateProviderExtension::TriggerSetCertificates() {
  base::Value::Dict message_data;
  message_data.Set("name", "setCertificates");
  base::Value::List cert_info_values;
  if (should_provide_certificates_)
    cert_info_values.Append(MakeClientCertificateInfoValue(*certificate_));
  message_data.Set("certificateInfoList", std::move(cert_info_values));

  base::Value::List message;
  message.Append(std::move(message_data));
  auto event = std::make_unique<extensions::Event>(
      extensions::events::FOR_TEST,
      extensions::api::test::OnMessage::kEventName, std::move(message),
      browser_context_);
  extensions::EventRouter::Get(browser_context_)
      ->DispatchEventToExtension(extension_id(), std::move(event));
}

void TestCertificateProviderExtension::HandleMessage(
    const std::string& message) {
  // Handle the request and reply to it (possibly, asynchronously).
  base::Value message_value = ParseJsonToValue(message);
  CHECK(message_value.is_list());
  base::Value::List& message_list = message_value.GetList();
  CHECK(message_list.size());
  CHECK(message_list[0].is_string());
  const std::string& request_type = message_list[0].GetString();
  ReplyToJsCallback send_reply_to_js_callback =
      base::BindOnce(&SendReplyToJs, &message_listener_);
  if (request_type == "getCertificates") {
    CHECK_EQ(message_list.size(), 1U);
    HandleCertificatesRequest(std::move(send_reply_to_js_callback));
  } else if (request_type == "onSignatureRequested") {
    CHECK_EQ(message_list.size(), 4U);
    HandleSignatureRequest(
        /*sign_request=*/message_list[1],
        /*pin_status=*/message_list[2],
        /*pin=*/message_list[3], std::move(send_reply_to_js_callback));
  } else {
    LOG(FATAL) << "Unexpected JS message type: " << request_type;
  }
}

void TestCertificateProviderExtension::HandleCertificatesRequest(
    ReplyToJsCallback callback) {
  ++certificate_request_count_;
  base::Value::List cert_info_values;
  if (should_provide_certificates_)
    cert_info_values.Append(MakeClientCertificateInfoValue(*certificate_));
  std::move(callback).Run(base::Value(std::move(cert_info_values)));
}

void TestCertificateProviderExtension::HandleSignatureRequest(
    const base::Value& sign_request,
    const base::Value& pin_status,
    const base::Value& pin,
    ReplyToJsCallback callback) {
  CHECK_EQ(*sign_request.GetDict().Find("certificate"),
           ConvertBytesToValue(GetCertDer(*certificate_)));
  const std::string pin_status_string = pin_status.GetString();
  const std::string pin_string = pin.GetString();

  const int sign_request_id =
      sign_request.GetDict().FindInt("signRequestId").value();
  const std::vector<uint8_t> input =
      ExtractBytesFromValue(*sign_request.GetDict().Find("input"));

  const extensions::api::certificate_provider::Algorithm algorithm =
      extensions::api::certificate_provider::ParseAlgorithm(
          *sign_request.GetDict().FindString("algorithm"));
  int openssl_signature_algorithm = 0;
  if (algorithm == extensions::api::certificate_provider::Algorithm::
                       kRsassaPkcs1V1_5Sha256) {
    openssl_signature_algorithm = SSL_SIGN_RSA_PKCS1_SHA256;
  } else if (algorithm == extensions::api::certificate_provider::Algorithm::
                              kRsassaPkcs1V1_5Sha1) {
    openssl_signature_algorithm = SSL_SIGN_RSA_PKCS1_SHA1;
  } else {
    LOG(FATAL) << "Unexpected signature request algorithm: "
               << extensions::api::certificate_provider::ToString(algorithm);
  }

  if (should_fail_sign_digest_requests_) {
    // Simulate a failure.
    std::move(callback).Run(/*response=*/base::Value());
    return;
  }

  base::Value::Dict response;
  if (required_pin_.has_value()) {
    if (pin_status_string == "not_requested") {
      // The PIN is required but not specified yet, so request it via the JS
      // side before generating the signature.
      base::Value::Dict pin_request_parameters;
      pin_request_parameters.Set("signRequestId", sign_request_id);
      if (remaining_pin_attempts_ == 0) {
        pin_request_parameters.Set("errorType", "MAX_ATTEMPTS_EXCEEDED");
      }
      response.Set("requestPin", std::move(pin_request_parameters));
      std::move(callback).Run(base::Value(std::move(response)));
      return;
    }
    if (remaining_pin_attempts_ == 0) {
      // The error about the lockout is already displayed, so fail immediately.
      std::move(callback).Run(/*response=*/base::Value());
      return;
    }
    if (pin_status_string == "canceled" ||
        base::StartsWith(pin_status_string,
                         "failed:", base::CompareCase::SENSITIVE)) {
      // The PIN request failed.
      LOG(WARNING) << "PIN request failed: " << pin_status_string;
      // Respond with a failure.
      std::move(callback).Run(/*response=*/base::Value());
      return;
    }
    DCHECK_EQ(pin_status_string, "ok");
    if (pin_string != *required_pin_) {
      // The entered PIN is wrong, so decrement the remaining attempt count, and
      // update the PIN dialog with displaying an error.
      if (remaining_pin_attempts_ > 0)
        --remaining_pin_attempts_;
      base::Value::Dict pin_request_parameters;
      pin_request_parameters.Set("signRequestId", sign_request_id);
      pin_request_parameters.Set("errorType", remaining_pin_attempts_ == 0
                                                  ? "MAX_ATTEMPTS_EXCEEDED"
                                                  : "INVALID_PIN");
      if (remaining_pin_attempts_ > 0) {
        pin_request_parameters.Set("attemptsLeft", remaining_pin_attempts_);
      }
      response.Set("requestPin", std::move(pin_request_parameters));
      std::move(callback).Run(base::Value(std::move(response)));
      return;
    }
    // The entered PIN is correct. Stop the PIN request and proceed to
    // generating the signature.
    base::Value::Dict stop_pin_request_parameters;
    stop_pin_request_parameters.Set("signRequestId", sign_request_id);
    response.Set("stopPinRequest", std::move(stop_pin_request_parameters));
  }
  // Generate and return a valid signature.
  std::vector<uint8_t> signature;
  CHECK(RsaSignRawData(private_key_.get(), openssl_signature_algorithm, input,
                       &signature));
  response.Set("signature", ConvertBytesToValue(signature));
  std::move(callback).Run(base::Value(std::move(response)));
}

}  // namespace ash