/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/SSLContext.h>
#include <folly/FileUtil.h>
#include <folly/io/async/test/SSLUtil.h>
#include <folly/portability/GTest.h>
#include <folly/portability/OpenSSL.h>
#include <folly/ssl/OpenSSLCertUtils.h>
#include <folly/ssl/OpenSSLKeyUtils.h>
#include <folly/ssl/OpenSSLPtrTypes.h>
#include <folly/testing/TestUtil.h>
#if !defined(FOLLY_CERTS_DIR)
#define FOLLY_CERTS_DIR "folly/io/async/test/certs"
#endif
using namespace std;
using folly::test::find_resource;
namespace folly {
class SSLContextTest : public testing::Test {
public:
SSLContext ctx;
void verifySSLCipherList(const vector<string>& ciphers);
void verifySSLCiphersuites(const vector<string>& ciphersuites);
};
void SSLContextTest::verifySSLCipherList(const vector<string>& ciphers) {
ssl::SSLUniquePtr ssl(ctx.createSSL());
EXPECT_EQ(ciphers, test::getNonTLS13CipherList(ssl.get()));
}
void SSLContextTest::verifySSLCiphersuites(const vector<string>& ciphersuites) {
ssl::SSLUniquePtr ssl(ctx.createSSL());
EXPECT_EQ(ciphersuites, test::getTLS13Ciphersuites(ssl.get()));
}
TEST_F(SSLContextTest, TestSetCipherString) {
  // A colon-separated cipher string should yield exactly these ciphers, in
  // this order.
  constexpr auto kCipherString = "AES128-SHA:ECDHE-RSA-AES256-SHA384";
  ctx.ciphers(kCipherString);
  verifySSLCipherList({"AES128-SHA", "ECDHE-RSA-AES256-SHA384"});
}
TEST_F(SSLContextTest, TestSetCipherList) {
  // Configuring from a vector must round-trip the same names in the same order.
  const vector<string> expectedCiphers{"ECDHE-RSA-AES128-SHA", "AES256-SHA"};
  ctx.setCipherList(expectedCiphers);
  verifySSLCipherList(expectedCiphers);
}
TEST_F(SSLContextTest, TestCipherRemoval) {
  // Reports whether an SSL freshly created from the fixture context lists
  // AES256-SHA among its ciphers.
  auto offersAes256Sha = [this] {
    ssl::SSLUniquePtr probe(ctx.createSSL());
    const auto offered = test::getCiphersFromSSL(probe.get());
    return std::find(offered.begin(), offered.end(), "AES256-SHA") !=
        offered.end();
  };

  // Present while it is part of the configured list...
  ctx.setCipherList({"ECDHE-RSA-AES128-SHA", "AES256-SHA"});
  EXPECT_TRUE(offersAes256Sha());

  // ...and gone once the list is replaced without it.
  ctx.setCipherList({"ECDHE-RSA-AES128-SHA"});
  EXPECT_FALSE(offersAes256Sha());
}
// Exercises every way of loading a certificate and private key into an
// SSLContext (PEM buffers, files, combined helpers), for both a matching pair
// and a mismatched key.
TEST_F(SSLContextTest, TestLoadCertKey) {
  std::string certData;
  std::string keyData;
  std::string anotherKeyData;
  const char* certPath = FOLLY_CERTS_DIR "/tests-cert.pem";
  const char* keyPath = FOLLY_CERTS_DIR "/tests-key.pem";
  // A key that does not belong to tests-cert.pem; used for mismatch cases.
  const char* anotherKeyPath = FOLLY_CERTS_DIR "/client_key.pem";
  // Fail fast if a fixture file cannot be read. Otherwise every case below
  // would run against empty PEM data and fail with misleading errors.
  ASSERT_TRUE(folly::readFile(find_resource(certPath).c_str(), certData));
  ASSERT_TRUE(folly::readFile(find_resource(keyPath).c_str(), keyData));
  ASSERT_TRUE(
      folly::readFile(find_resource(anotherKeyPath).c_str(), anotherKeyData));
  {
    SCOPED_TRACE("Valid cert/key pair from buffer");
    SSLContext tmpCtx;
    tmpCtx.loadCertificateFromBufferPEM(certData);
    tmpCtx.loadPrivateKeyFromBufferPEM(keyData);
    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
  }
  {
    SCOPED_TRACE("Valid cert/key pair from files");
    SSLContext tmpCtx;
    tmpCtx.loadCertificate(find_resource(certPath).c_str());
    tmpCtx.loadPrivateKey(find_resource(keyPath).c_str());
    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
  }
  {
    SCOPED_TRACE("Invalid cert/key pair from file. Load cert first");
    SSLContext tmpCtx;
    tmpCtx.loadCertificate(find_resource(certPath).c_str());
    // With a certificate already loaded, a non-matching key is rejected.
    EXPECT_THROW(
        tmpCtx.loadPrivateKey(find_resource(anotherKeyPath).c_str()),
        std::runtime_error);
  }
  {
    SCOPED_TRACE("Invalid cert/key pair from file. Load key first");
    SSLContext tmpCtx;
    // Loading the key first does not throw here; the mismatch is only
    // visible through isCertKeyPairValid().
    tmpCtx.loadPrivateKey(find_resource(anotherKeyPath).c_str());
    tmpCtx.loadCertificate(find_resource(certPath).c_str());
    EXPECT_FALSE(tmpCtx.isCertKeyPairValid());
  }
  {
    SCOPED_TRACE("Invalid key/cert pair from buf. Load cert first");
    SSLContext tmpCtx;
    tmpCtx.loadCertificateFromBufferPEM(certData);
    EXPECT_THROW(
        tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData), std::runtime_error);
  }
  {
    SCOPED_TRACE("Invalid key/cert pair from buf. Load key first");
    SSLContext tmpCtx;
    tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData);
    tmpCtx.loadCertificateFromBufferPEM(certData);
    EXPECT_FALSE(tmpCtx.isCertKeyPairValid());
  }
  {
    SCOPED_TRACE(
        "loadCertKeyPairFromBufferPEM() must throw when cert/key mismatch");
    SSLContext tmpCtx;
    EXPECT_THROW(
        tmpCtx.loadCertKeyPairFromBufferPEM(certData, anotherKeyData),
        std::runtime_error);
  }
  {
    SCOPED_TRACE(
        "loadCertKeyPairFromBufferPEM() must succeed when cert/key match");
    SSLContext tmpCtx;
    // Report an unexpected exception as a test failure with this trace.
    EXPECT_NO_THROW(tmpCtx.loadCertKeyPairFromBufferPEM(certData, keyData));
    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
  }
  {
    SCOPED_TRACE(
        "loadCertKeyPairFromFiles() must throw when cert/key mismatch");
    SSLContext tmpCtx;
    EXPECT_THROW(
        tmpCtx.loadCertKeyPairFromFiles(
            find_resource(certPath).c_str(),
            find_resource(anotherKeyPath).c_str()),
        std::runtime_error);
  }
  {
    SCOPED_TRACE("loadCertKeyPairFromFiles() must succeed when cert/key match");
    SSLContext tmpCtx;
    EXPECT_NO_THROW(tmpCtx.loadCertKeyPairFromFiles(
        find_resource(certPath).c_str(), find_resource(keyPath).c_str()));
    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
  }
}
// Loads client_chain.pem both from disk and from an in-memory PEM buffer.
// In both cases the context is expected to carry exactly one extra chain
// certificate.
TEST_F(SSLContextTest, TestLoadCertificateChain) {
  constexpr auto kCertChainPath = FOLLY_CERTS_DIR "/client_chain.pem";
  const auto chainFile = find_resource(kCertChainPath);
  std::string pemContents;
  EXPECT_TRUE(folly::readFile(chainFile.c_str(), pemContents));

  {
    // Path-based loading.
    SSLContext fileCtx;
    fileCtx.loadCertificate(chainFile.c_str(), "PEM");
    STACK_OF(X509)* extraCerts = nullptr;
    SSL_CTX_get0_chain_certs(fileCtx.getSSLCtx(), &extraCerts);
    ASSERT_NE(extraCerts, nullptr);
    EXPECT_EQ(1, sk_X509_num(extraCerts));
  }

  {
    // Buffer-based loading of the very same PEM data.
    SSLContext bufferCtx;
    bufferCtx.loadCertificateFromBufferPEM(pemContents);
    STACK_OF(X509)* extraCerts = nullptr;
    SSL_CTX_get0_chain_certs(bufferCtx.getSSLCtx(), &extraCerts);
    ASSERT_NE(extraCerts, nullptr);
    EXPECT_EQ(1, sk_X509_num(extraCerts));
  }
}
// setCertChainKeyPair() should accept a certificate chain together with its
// matching private key, and throw when given a key that does not match.
TEST_F(SSLContextTest, TestSetCertificateChainKeyPair) {
  constexpr auto kCertChainPath = FOLLY_CERTS_DIR "/client_chain.pem";
  // Key expected to match the leaf of client_chain.pem. The path previously
  // read "clienti_key.pem", a nonexistent file.
  constexpr auto kKeyPath = FOLLY_CERTS_DIR "/client_key.pem";
  // Key that does not match client_chain.pem.
  constexpr auto kAnotherKeyPath = FOLLY_CERTS_DIR "/tests-key.pem";
  std::string certChainData;
  std::string keyData;
  std::string anotherKeyData;
  // Fail fast on missing fixtures instead of parsing empty buffers below.
  ASSERT_TRUE(
      folly::readFile(find_resource(kCertChainPath).c_str(), certChainData));
  ASSERT_TRUE(folly::readFile(find_resource(kKeyPath).c_str(), keyData));
  ASSERT_TRUE(
      folly::readFile(find_resource(kAnotherKeyPath).c_str(), anotherKeyData));
  {
    SCOPED_TRACE("Set valid cert chain and key pair.");
    auto certChain =
        ssl::OpenSSLCertUtils::readCertsFromBuffer(ByteRange(certChainData));
    auto pKey =
        ssl::OpenSSLKeyUtils::readPrivateKeyFromBuffer(ByteRange(keyData));
    SSLContext tmpCtx;
    tmpCtx.setCertChainKeyPair(std::move(certChain), std::move(pKey));
    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
  }
  {
    // Declared first so the trace also covers parsing of the inputs.
    SCOPED_TRACE("setCertChainKeyPair() must throw when cert/key mismatch");
    auto certChain =
        ssl::OpenSSLCertUtils::readCertsFromBuffer(ByteRange(certChainData));
    auto anotherPKey = ssl::OpenSSLKeyUtils::readPrivateKeyFromBuffer(
        ByteRange(anotherKeyData));
    SSLContext tmpCtx;
    EXPECT_THROW(
        tmpCtx.setCertChainKeyPair(
            std::move(certChain), std::move(anotherPKey)),
        std::runtime_error);
  }
}
// Loading client CA names from client_chain.pem should expose both subject
// names, leaf first, through the SSL_CTX client CA list.
TEST_F(SSLContextTest, TestSetSupportedClientCAs) {
  constexpr auto kCertChainPath = FOLLY_CERTS_DIR "/client_chain.pem";
  ctx.setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kCertChainPath).c_str());
  STACK_OF(X509_NAME)* names = SSL_CTX_get_client_CA_list(ctx.getSSLCtx());
  static const char* kExpectedCNs[] = {"Leaf Certificate", "Intermediate CA"};
  // Fatal check: the loop below indexes kExpectedCNs by position. A
  // non-fatal check would let it read past the end of the array when more
  // names are returned.
  ASSERT_EQ(2, sk_X509_NAME_num(names));
  for (int i = 0; i < sk_X509_NAME_num(names); i++) {
    auto name = sk_X509_NAME_value(names, i);
    int indexCN = X509_NAME_get_index_by_NID(name, NID_commonName, -1);
    ASSERT_NE(indexCN, -1);
    auto entry = X509_NAME_get_entry(name, indexCN);
    ASSERT_NE(entry, nullptr);
    auto asnStringCN = X509_NAME_ENTRY_get_data(entry);
    std::string commonName(
        reinterpret_cast<const char*>(ASN1_STRING_get0_data(asnStringCN)),
        ASN1_STRING_length(asnStringCN));
    EXPECT_EQ(commonName, std::string(kExpectedCNs[i]));
  }
}
TEST_F(SSLContextTest, TestGetFromSSLCtx) {
  // The fixture's SSL_CTX was obtained from `ctx`, so the reverse lookup
  // must yield that same object.
  EXPECT_EQ(SSLContext::getFromSSLCtx(ctx.getSSLCtx()), &ctx);

  // An SSL_CTX created directly through OpenSSL has no owning SSLContext.
  SSL_CTX* foreignCtx = SSL_CTX_new(TLS_method());
  SSLContext* owner = SSLContext::getFromSSLCtx(foreignCtx);
  EXPECT_EQ(nullptr, owner);
  SSL_CTX_free(foreignCtx);
}
TEST_F(SSLContextTest, TestInvalidSigAlgThrows) {
  // Each entry is expected to be rejected. The first is empty; the second
  // contains "RSA+HA256" as its last element.
  for (const char* sigAlgs :
       {"", "rsa_pss_rsae_sha512:ECDSA+SHA256:RSA+HA256"}) {
    SCOPED_TRACE(sigAlgs);
    SSLContext freshCtx;
    EXPECT_THROW(freshCtx.setSigAlgsOrThrow(sigAlgs), std::runtime_error);
  }
}
TEST_F(SSLContextTest, TestSetCiphersuites) {
  // TLS 1.3 ciphersuites are configured as one colon-joined string and read
  // back individually from an SSL created from the context.
  const std::vector<std::string> expectedSuites = {
      "TLS_AES_128_CCM_SHA256",
      "TLS_AES_128_GCM_SHA256",
  };
  const std::string joinedSuites = folly::join(":", expectedSuites);
  ctx.setCiphersuitesOrThrow(joinedSuites);
  verifySSLCiphersuites(expectedSuites);
}
TEST_F(SSLContextTest, TestSetInvalidCiphersuite) {
  // This name is not a TLS 1.3 ciphersuite (compare the TLS_AES_* names used
  // in TestSetCiphersuites), so it is expected to be rejected.
  constexpr auto kNonTls13Name = "ECDHE-ECDSA-AES256-GCM-SHA384";
  EXPECT_THROW(ctx.setCiphersuitesOrThrow(kNonTls13Name), std::runtime_error);
}
TEST_F(SSLContextTest, TestTLS13MinVersion) {
  // Constructing with SSLVersion::TLSv1_3 should pin the minimum protocol
  // version of the underlying SSL_CTX to TLS 1.3.
  SSLContext tls13Only(SSLContext::SSLVersion::TLSv1_3);
  const int minVersion = SSL_CTX_get_min_proto_version(tls13Only.getSSLCtx());
  EXPECT_EQ(minVersion, TLS1_3_VERSION);
}
TEST_F(SSLContextTest, AdvertisedNextProtocols) {
  // Nothing is advertised until protocols are configured.
  EXPECT_EQ("", ctx.getAdvertisedNextProtocols());

  // A single protocol is reported as-is.
  ctx.setAdvertisedNextProtocols({"blub"});
  EXPECT_EQ("blub", ctx.getAdvertisedNextProtocols());

  // A new list replaces the old one and is reported comma-separated, in order.
  ctx.setAdvertisedNextProtocols({"foo", "bar", "baz"});
  EXPECT_EQ("foo,bar,baz", ctx.getAdvertisedNextProtocols());
}
} // namespace folly