/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <initializer_list>
#include <sstream>
#include <string_view>
#include <type_traits>
#include <folly/detail/base64_detail/Base64Scalar.h>
#include <folly/detail/base64_detail/Base64Simd.h>
#include <folly/detail/base64_detail/Base64_SSE4_2.h>
#include <folly/portability/Constexpr.h>
#include <folly/portability/GTest.h>
namespace folly::detail::base64_detail {
namespace {
struct TestCase {
std::string_view data;
std::string_view encodedStd;
std::string_view encodedURL;
};
// The types are weird because we need to do constexpr tests
struct TestCaseOnStack : TestCase {
constexpr TestCaseOnStack(
std::initializer_list<char> dataInit, // initializer list makes it easier
// to put non printable characters
// compared to strings
std::string_view encodedStdParam,
std::string_view encodedURLParam)
: TestCase{{}, encodedStdParam, encodedURLParam}, dataBuf{} {
// std::copy not constexpr
auto f = dataInit.begin(), l = dataInit.end();
auto o = dataBuf.begin();
while (f != l) {
*o++ = *f++;
}
data = std::string_view{
dataBuf.data(), static_cast<size_t>(o - dataBuf.begin())};
}
std::array<char, 200> dataBuf;
} constexpr staticTestCases[] = {
// clang-format off
{
std::initializer_list<char>{}, "", ""
},
{
{ 0, 0, 0 }, "AAAA", "AAAA"
},
{
{1}, "AQ==", "AQ"
},
{
{1, 0}, "AQA=", "AQA",
},
{
{1, 0, 0}, "AQAA", "AQAA"
},
{
{'a','b','c','d'},
"YWJjZA==",
"YWJjZA"
},
{
{'a', 'b', 'c'},
"YWJj",
"YWJj"
},
{
{'l','e','s','s',' ',
'i','s',' ',
'm','o','r','e',' ',
't','h','a','n',' ',
'm','o','r','e'},
"bGVzcyBpcyBtb3JlIHRoYW4gbW9yZQ==",
"bGVzcyBpcyBtb3JlIHRoYW4gbW9yZQ",
},
{
{'<','>', '?','s','u'},
"PD4/c3U=",
"PD4_c3U"
},
// clang-format on
};
std::string byteRangeToString(std::string_view data) {
std::stringstream res;
res << '{';
auto f = data.begin();
auto l = data.end();
if (f != l) {
res << static_cast<std::int32_t>(*f);
while (++f != l) {
res << ", " << static_cast<std::int32_t>(*f);
}
}
res << '}';
return std::move(res).str();
}
template <typename TestRunner>
constexpr bool staticTests(TestRunner testRunner) {
for (const auto& test : staticTestCases) {
if (!testRunner(TestCase{test})) {
return false;
}
}
return true;
}
template <typename I, typename N, typename V>
constexpr I fill_n(I f, N n, V v) {
while (n--) {
*f++ = v;
}
return f;
}
template <typename TestRunner>
constexpr bool manyZeroesTests(TestRunner testRunner) {
for (std::size_t inSize = 0; inSize != 129; ++inSize) {
TestCase test;
// Populate input
std::array<char, 256> buf = {}; // fill in 0s
buf[inSize + 1] = 15; // messing with the input
test.data = {buf.data(), inSize};
// Populate expected
std::array<char, 256> expectedBuf = {};
char* expectedL = fill_n(expectedBuf.begin(), inSize / 3 * 4, 'A');
char* expectedURLL = expectedL;
if (inSize % 3 == 2) {
*expectedL++ = 'A';
*expectedL++ = 'A';
*expectedL++ = 'A';
expectedURLL = expectedL;
*expectedL++ = '=';
} else if (inSize % 3 == 1) {
*expectedL++ = 'A';
*expectedL++ = 'A';
expectedURLL = expectedL;
*expectedL++ = '=';
*expectedL++ = '=';
}
test.encodedStd =
std::string_view(expectedBuf.data(), expectedL - expectedBuf.data());
test.encodedURL =
std::string_view(expectedBuf.data(), expectedURLL - expectedBuf.data());
// Run test
if (!testRunner(test))
return false;
}
return true;
}
template <typename TestRunner>
constexpr bool runEncodeTests(TestRunner testRunner) {
return staticTests(testRunner) && manyZeroesTests(testRunner);
}
// In constexpr we can have a non constexpr expression, as long
// as it is not evaluated.
//
// There was a gcc bug with respect to it. Luckily this just affects how
// much useful information will be output in case of a test failure.
#if defined(__GNUC__) && !defined(__clang__)
#define GCC_CONSTEXPR_BUG_ACTIVE
#endif
struct ConstexprTester {
constexpr bool encodeTest(TestCase test) const {
std::array<char, 1000> buf = {};
char* end =
base64EncodeScalar(test.data.begin(), test.data.end(), buf.data());
std::string_view actual(buf.data(), end - buf.data());
if (test.encodedStd == actual) {
return true;
}
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(test.encodedStd, actual)
<< "Regular encoding mismatch. Input data:\n"
<< byteRangeToString(test.data);
#endif
return false;
}
constexpr bool decodeTest(TestCase test) const {
std::array<char, 1000> buf = {};
auto res = base64DecodeScalar(
test.encodedStd.data(),
test.encodedStd.data() + test.encodedStd.size(),
buf.begin());
std::string_view decoded(buf.begin(), res.o - buf.begin());
if (res.isSuccess && test.data == decoded) {
return true;
}
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_TRUE(res.isSuccess) << "encoded: " << test.encodedStd;
EXPECT_EQ(test.data, decoded) << "encoded: " << test.encodedStd;
#endif
return false;
}
constexpr bool encodeURLTest(TestCase test) const {
std::array<char, 1000> buf = {};
char* end =
base64URLEncodeScalar(test.data.begin(), test.data.end(), buf.data());
std::string_view actual(buf.data(), end - buf.data());
if (test.encodedURL == actual) {
return true;
}
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(test.encodedURL, actual) << "URL encoding mismatch. Input data:\n"
<< byteRangeToString(test.data);
#endif
return false;
}
constexpr bool decodeURLTest(TestCase test) const {
auto oneInput = [&](std::string_view encoded) {
std::array<char, 1000> buf = {};
auto res = base64URLDecodeScalar(
encoded.data(), encoded.data() + encoded.size(), buf.begin());
std::string_view decoded(buf.begin(), res.o - buf.begin());
if (res.isSuccess && test.data == decoded) {
return true;
}
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_TRUE(res.isSuccess) << "encoded: " << encoded;
EXPECT_EQ(test.data, decoded) << "encoded: " << encoded;
#endif
return false;
};
return oneInput(test.encodedStd) && oneInput(test.encodedURL);
}
constexpr bool sizeTests(TestCase test) const {
std::size_t encodedSize = base64EncodedSize(test.data.size());
std::size_t encodedURLSize = base64URLEncodedSize(test.data.size());
std::size_t decodedSize = base64DecodedSize(
test.encodedStd.data(),
test.encodedStd.data() + test.encodedStd.size());
std::size_t decodedURLSize = base64URLDecodedSize(
test.encodedURL.data(),
test.encodedURL.data() + test.encodedURL.size());
std::size_t decodedStdWithURlSize = base64URLDecodedSize(
test.encodedStd.data(),
test.encodedStd.data() + test.encodedStd.size());
if (encodedSize == test.encodedStd.size() &&
encodedURLSize == test.encodedURL.size() &&
decodedSize == test.data.size() && decodedURLSize == test.data.size() &&
decodedStdWithURlSize == test.data.size()) {
return true;
}
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(test.encodedStd.size(), encodedSize) << test.encodedStd;
EXPECT_EQ(test.encodedURL.size(), encodedURLSize) << test.encodedURL;
EXPECT_EQ(test.data.size(), decodedSize) << test.encodedStd;
EXPECT_EQ(test.data.size(), decodedStdWithURlSize) << test.encodedStd;
EXPECT_EQ(test.data.size(), decodedURLSize) << test.encodedURL;
#endif
return false;
}
constexpr bool operator()(TestCase test) const {
return sizeTests(test) && encodeTest(test) && encodeURLTest(test) &&
decodeTest(test) && decodeURLTest(test);
}
};
struct SimdTester {
using Encode = char* (*)(const char*, const char*, char*);
using Decode = Base64DecodeResult (*)(const char*, const char*, char*);
Encode encode;
Encode encodeURL;
Decode decode;
Decode decodeURL;
bool encodeTest(TestCase test) const {
std::string actual(base64EncodedSize(test.data.size()), 0);
encode(test.data.begin(), test.data.end(), actual.data());
if (test.encodedStd == actual) {
return true;
}
EXPECT_EQ(test.encodedStd, actual)
<< "Regular encoding mismatch. Input data:\n"
<< byteRangeToString(test.data);
return false;
}
bool encodeURLTest(TestCase test) const {
std::string actual(base64URLEncodedSize(test.data.size()), 0);
encodeURL(test.data.begin(), test.data.end(), actual.data());
if (test.encodedURL == actual) {
return true;
}
EXPECT_EQ(test.encodedStd, actual) << "URL encoding mismatch. Input data:\n"
<< byteRangeToString(test.data);
return false;
}
bool decodeTest(TestCase test) const {
std::string actual(
base64DecodedSize(
test.encodedStd.data(),
test.encodedStd.data() + test.encodedStd.size()),
0);
auto decodedResult = decode(
test.encodedStd.data(),
test.encodedStd.data() + test.encodedStd.size(),
actual.data());
if (decodedResult.isSuccess && test.data == actual) {
return true;
}
EXPECT_TRUE(decodedResult.isSuccess) << byteRangeToString(test.data);
EXPECT_EQ(test.data, actual) << byteRangeToString(test.data);
return false;
}
bool decodeURLTest(TestCase test) const {
auto oneInput = [&](std::string_view encoded) {
std::string decoded(
base64URLDecodedSize(encoded.data(), encoded.data() + encoded.size()),
0);
auto res = decodeURL(
encoded.data(), encoded.data() + encoded.size(), decoded.data());
if (res.isSuccess && test.data == decoded) {
return true;
}
EXPECT_TRUE(res.isSuccess) << "encoded: " << encoded;
EXPECT_EQ(test.data, decoded) << "encoded: " << encoded;
return false;
};
return oneInput(test.encodedStd) && oneInput(test.encodedURL);
}
bool operator()(TestCase test) const {
return encodeTest(test) && encodeURLTest(test) && decodeTest(test) &&
decodeURLTest(test);
}
};
TEST(Base64, ConstexprTests) {
// Comment out the static assert to debug
static_assert(runEncodeTests(ConstexprTester{}));
ASSERT_TRUE(runEncodeTests(ConstexprTester{}));
}
TEST(Base64, SpecialCases) {
ASSERT_TRUE(runEncodeTests(SimdTester{
base64EncodeScalar,
base64URLEncodeScalar,
base64DecodeSWAR,
base64URLDecodeSWAR}));
#if FOLLY_SSE_PREREQ(4, 2)
ASSERT_TRUE(runEncodeTests(SimdTester{
base64Encode_SSE4_2,
base64URLEncode_SSE4_2,
base64Decode_SSE4_2,
base64URLDecodeSWAR}));
#endif
}
constexpr char kHasNegative0[] = {'A', 'b', 'c', -15, '\0'};
constexpr char kHasNegative1[] = {'a', 'b', 'c', 'd', 'a', -15, 'c',
'd', 'a', 'b', 'c', 'd', 'a', 'b',
'c', 'd', 'a', 'b', 'c', 'd', '\0'};
struct DecodingErrorDetectionTest {
bool isSuccess;
std::string_view input;
} constexpr kDecodingErrorDection[] = {
// clang-format off
{ true, "" },
{ false, "=" },
{ false, "==" },
{ false, "A" },
{ false, "B=" },
{ false, "ba=" },
{ true, "0w==" },
{ true, "000=" },
{ false, "===" },
{ false, "0===" },
{ false, "aa=0" },
{ false, "aaaa""aaaa""aaaa""aaaa""0" },
{ true, "aaaa""aaaa""aaaa""aaaa""0w==" },
{ true, "0aaa""aaaa""aaaa""aaaa""aaaa""aaaa" },
{ false, "$aaa""aaaa""aaaa""aaaa""aaaa""aaaa" },
{ false, "aaaa""aa$a""aaaa""aaaa""aaaa""aaaa" },
{ false, "aaaaa"},
{ false, kHasNegative0 },
{ false, kHasNegative1 },
// clang-format on
};
constexpr std::string_view kDecodingOnlyURLValid[] = {
"ba",
"ba__",
"ba__ba--ba__",
"bA_/0a--ba+_",
"_-==",
"iZ==",
"00==",
"997=",
"+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-",
"bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_"
"bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_",
"bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba+_bA_/0a--ba==",
};
template <bool isURLDecoder>
constexpr size_t decodedSize(std::string_view in) {
const char* f = in.data();
const char* l = in.data() + in.size();
if constexpr (isURLDecoder) {
return base64URLDecodedSize(f, l);
} else {
return base64DecodedSize(f, l);
}
}
template <bool isURLDecoder, typename Decoder>
void triggerASANOnBadDecode(std::string_view in, Decoder decoder) {
std::vector<char> buf(decodedSize<isURLDecoder>(in));
decoder(in.data(), in.data() + in.size(), buf.data());
}
template <bool isURLDecoder, typename Decoder>
constexpr bool decodingErrorDectionTest(Decoder decoder) {
std::array<char, 1000> buf = {};
auto sizeTest = [&](std::string_view in, Base64DecodeResult r) {
std::size_t allocatedSize = decodedSize<isURLDecoder>(in);
std::size_t usedSize = static_cast<std::size_t>(r.o - buf.data());
if (usedSize == allocatedSize) {
return true;
}
if (r.isSuccess) {
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(usedSize, allocatedSize) << in << " isURL: " << isURLDecoder;
#endif
return false;
}
if (allocatedSize > 1000 || // overflow
usedSize > allocatedSize) {
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_LE(usedSize, allocatedSize) << in << " isURL: " << isURLDecoder;
#endif
return false;
}
return true;
};
for (const auto& test : kDecodingErrorDection) {
if (!folly::is_constant_evaluated_or(true)) {
triggerASANOnBadDecode<isURLDecoder>(test.input, decoder);
}
auto r = decoder(
test.input.data(), test.input.data() + test.input.size(), buf.data());
if (test.isSuccess != r.isSuccess) {
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(test.isSuccess, r.isSuccess) << test.input;
#endif
return false;
}
if (!sizeTest(test.input, r)) {
return false;
}
}
for (std::string_view URLOnly : kDecodingOnlyURLValid) {
if (!folly::is_constant_evaluated_or(true)) {
triggerASANOnBadDecode<isURLDecoder>(URLOnly, decoder);
}
auto r =
decoder(URLOnly.data(), URLOnly.data() + URLOnly.size(), buf.data());
if (isURLDecoder != r.isSuccess) {
#ifndef GCC_CONSTEXPR_BUG_ACTIVE
EXPECT_EQ(isURLDecoder, r.isSuccess) << URLOnly;
#endif
return false;
}
if (!sizeTest(URLOnly, r)) {
return false;
}
}
return true;
}
TEST(Base64, DecodingErrorDeteciton) {
static_assert(decodingErrorDectionTest<false>(base64DecodeScalar));
static_assert(decodingErrorDectionTest<true>(base64URLDecodeScalar));
ASSERT_TRUE(decodingErrorDectionTest<false>(base64DecodeScalar));
ASSERT_TRUE(decodingErrorDectionTest<true>(base64URLDecodeScalar));
ASSERT_TRUE(decodingErrorDectionTest<false>(base64DecodeSWAR));
ASSERT_TRUE(decodingErrorDectionTest<true>(base64URLDecodeSWAR));
#if FOLLY_SSE_PREREQ(4, 2)
ASSERT_TRUE(decodingErrorDectionTest<false>(base64Decode_SSE4_2));
#endif
}
} // namespace
} // namespace folly::detail::base64_detail