folly/folly/detail/base64_detail/tests/Base64PlatformTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <numeric>
#include <vector>
#include <folly/portability/GTest.h>

#include <folly/detail/base64_detail/Base64_SSE4_2_Platform.h>

namespace folly::detail::base64_detail {
namespace {
#if FOLLY_SSE_PREREQ(4, 2)

// Scalar reference used to check platform::encodeToIndexes.
//
// Splits every 3-byte group of `in` into four 6-bit values, each stored in
// its own output byte (most significant bits of the group first). Filling
// all 16 output bytes takes 4 groups, so only the first 12 bytes of `in` are
// read; the remaining 4 are ignored.
std::array<std::uint8_t, 16> expectedEncodeToIndexes(
    std::array<std::uint8_t, 16> in) {
  std::array<std::uint8_t, 16> indexes{};

  for (std::size_t group = 0; group * 4 != indexes.size(); ++group) {
    const std::uint8_t* const src = in.data() + group * 3;
    std::uint8_t* const dst = indexes.data() + group * 4;

    const std::uint8_t first = src[0];
    const std::uint8_t second = src[1];
    const std::uint8_t third = src[2];

    // Top 6 bits of the first byte.
    dst[0] = static_cast<std::uint8_t>(first >> 2);
    // Low 2 bits of the first byte followed by the top 4 of the second.
    dst[1] = static_cast<std::uint8_t>(((first << 4) | (second >> 4)) & 0x3f);
    // Low 4 bits of the second byte followed by the top 2 of the third.
    dst[2] = static_cast<std::uint8_t>(((second << 2) | (third >> 6)) & 0x3f);
    // Low 6 bits of the third byte.
    dst[3] = static_cast<std::uint8_t>(third & 0x3f);
  }

  return indexes;
}

// Scalar reference used to check platform::packIndexesToBytes.
//
// Inverse of the 3-bytes-to-4-indexes split: every group of four input bytes
// (each expected to hold a 6-bit value) is packed into three output bytes.
// The 16 inputs produce 12 packed bytes; the last 4 output bytes stay zero.
std::array<std::uint8_t, 16> expectedPackIndexesToBytes(
    std::array<std::uint8_t, 16> in) {
  // Value-initialization already zeroes the tail that is never written.
  std::array<std::uint8_t, 16> packed{};

  for (std::size_t group = 0; group * 4 != in.size(); ++group) {
    const std::uint8_t* const src = in.data() + group * 4;
    std::uint8_t* const dst = packed.data() + group * 3;

    const std::uint8_t i0 = src[0];
    const std::uint8_t i1 = src[1];
    const std::uint8_t i2 = src[2];
    const std::uint8_t i3 = src[3];

    // The casts drop any bits shifted above bit 7, exactly as storing the
    // int-promoted expression into a std::uint8_t does.
    dst[0] = static_cast<std::uint8_t>((i0 << 2) | (i1 >> 4));
    dst[1] = static_cast<std::uint8_t>((i1 << 4) | (i2 >> 2));
    dst[2] = static_cast<std::uint8_t>((i2 << 6) | i3);
  }

  return packed;
}

// Reference lookup table for the tests below: the 64 base64 alphabet
// characters at indexes 0..63, then '=' at index 64. Being a string literal,
// the array also holds the terminating '\0' at index 65.
constexpr char kBase64EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
    "="; // \0 is also at the end.

// Scalar reference used to check platform::lookupByIndex.
//
// Maps every byte of `in` to `sampleTable[byte]`. No bounds checking is done:
// the caller must make sure each byte is a valid offset into `sampleTable`.
std::array<std::uint8_t, 16> expectedLookupByIndex(
    std::array<std::uint8_t, 16> in, const char* sampleTable) {
  std::array<std::uint8_t, 16> looked{};

  auto dst = looked.begin();
  for (const std::uint8_t index : in) {
    *dst = static_cast<std::uint8_t>(sampleTable[index]);
    ++dst;
  }

  return looked;
}

// Scalar reference used to check platform::decodeToIndex on valid input.
//
// Converts every byte of `in` from a base64 alphabet character to its 6-bit
// index ('A'-'Z' -> 0..25, 'a'-'z' -> 26..51, '0'-'9' -> 52..61, '+' -> 62,
// '/' -> 63). If any byte is outside that alphabet, an all-zero array is
// returned instead of a partial result.
std::array<std::uint8_t, 16> expectedSuccessfullDecodeToIndex(
    std::array<std::uint8_t, 16> in) {
  // Returns the index for `sym`, or a negative value when `sym` is not part
  // of the alphabet.
  const auto toIndex = [](std::uint8_t sym) -> int {
    if (sym >= 'A' && sym <= 'Z') {
      return sym - 'A';
    }
    if (sym >= 'a' && sym <= 'z') {
      return sym - 'a' + 26;
    }
    if (sym >= '0' && sym <= '9') {
      return sym - '0' + 26 * 2;
    }
    if (sym == '+') {
      return 62;
    }
    if (sym == '/') {
      return 63;
    }
    return -1;
  };

  std::array<std::uint8_t, 16> indexes{};
  for (std::size_t pos = 0; pos != in.size(); ++pos) {
    const int index = toIndex(in[pos]);
    if (index < 0) {
      return {};
    }
    indexes[pos] = static_cast<std::uint8_t>(index);
  }
  return indexes;
}

// Typed test fixture parameterised on the platform implementation under test.
//
// Each `actual*` helper loads the bytes of `from` into a platform register
// with platform::loadu, applies one platform operation, and stores the
// resulting register back into a byte array with platform::storeu so that it
// can be compared against the scalar reference functions.
template <typename Platform>
struct Base64PlatformTest : ::testing::Test {
  using platform = Platform;
  using RegBytesArray = std::array<std::uint8_t, platform::kRegisterSize>;

  // Result of platform::encodeToIndexes applied to `from`.
  static RegBytesArray actualEncodeToIndexes(RegBytesArray from) {
    auto encoded = platform::encodeToIndexes(platform::loadu(from.data()));
    RegBytesArray out;
    platform::storeu(out.data(), encoded);
    return out;
  }

  // Result of platform::lookupByIndex applied to `from`, using the library's
  // constants::kEncodeTable as the lookup table.
  static RegBytesArray actualLookupByIndex(RegBytesArray from) {
    auto looked = platform::lookupByIndex(
        platform::loadu(from.data()), constants::kEncodeTable.data());
    RegBytesArray out;
    platform::storeu(out.data(), looked);
    return out;
  }

  // Result of platform::decodeToIndex applied to `from`. Records a
  // (non-fatal) test failure if the platform reports a decoding error.
  static RegBytesArray actualSuccesfullDecodeToIndex(RegBytesArray from) {
    auto errorAccum = platform::initError();
    auto decoded =
        platform::decodeToIndex(platform::loadu(from.data()), errorAccum);
    EXPECT_FALSE(platform::hasErrors(errorAccum));
    RegBytesArray out;
    platform::storeu(out.data(), decoded);
    return out;
  }

  // Result of platform::packIndexesToBytes applied to `from`.
  static RegBytesArray actualPackIndexesToBytes(RegBytesArray from) {
    auto packed = platform::packIndexesToBytes(platform::loadu(from.data()));
    RegBytesArray out;
    platform::storeu(out.data(), packed);
    return out;
  }
};

// Run every TYPED_TEST of Base64PlatformTest once per listed platform type;
// the list currently holds only the SSE4.2 implementation.
TYPED_TEST_SUITE(Base64PlatformTest, ::testing::Types<Base64_SSE4_2_Platform>);

// Checks platform::encodeToIndexes against the scalar reference on registers
// filled with consecutive byte values (wrapping modulo 256), starting at
// 0, 8, 16, ..., 248.
TYPED_TEST(Base64PlatformTest, EncodeToIndexes) {
  using RegBytes = typename TestFixture::RegBytesArray;

  for (unsigned first = 0; first < 256; first += 8) {
    RegBytes in;
    std::iota(in.begin(), in.end(), static_cast<std::uint8_t>(first));

    const RegBytes expected = expectedEncodeToIndexes(in);
    const RegBytes actual = TestFixture::actualEncodeToIndexes(in);

    // Non-fatal on purpose: keep sweeping the remaining start values.
    EXPECT_EQ(expected, actual);
  }
}

// Checks platform::lookupByIndex against a scalar lookup into
// kBase64EncodeTable.
//
// A register-wide window of consecutive indexes is slid over every position
// at which it still fits inside the table, so each table entry (including
// '=' and the terminating '\0') is looked up at least once.
TYPED_TEST(Base64PlatformTest, IndexLookup) {
  using RegBytes = typename TestFixture::RegBytesArray;

  // sizeof counts the string literal's terminating '\0', which is a valid
  // lookup target as well. Computed at compile time, with no narrowing.
  constexpr std::size_t kTableSize = sizeof(kBase64EncodeTable);

  for (std::size_t first = 0; first + RegBytes{}.size() <= kTableSize;
       ++first) {
    RegBytes in;
    // `first` stays below kTableSize, so the cast cannot lose information.
    std::iota(in.begin(), in.end(), static_cast<std::uint8_t>(first));

    RegBytes expected = expectedLookupByIndex(in, kBase64EncodeTable);
    RegBytes actual = TestFixture::actualLookupByIndex(in);
    ASSERT_EQ(expected, actual);
  }
}

// Checks that platform::decodeToIndex flags an error exactly when the input
// holds a byte outside the base64 alphabet ('+', '/', '0'-'9', 'A'-'Z',
// 'a'-'z').
//
// Every one of the 256 byte values is placed, one position at a time, into a
// register otherwise filled with a known-valid character, so both the value
// and the byte position that trips (or fails to trip) the check are covered.
TYPED_TEST(Base64PlatformTest, errorDetection) {
  using RegBytes = typename TestFixture::RegBytesArray;
  using pl = typename TestFixture::platform;

  // True when decoding `arr` leaves the platform's error accumulator set.
  auto anyErrors = [](const RegBytes& arr) {
    auto errorAccum = pl::initError();
    pl::decodeToIndex(pl::loadu(arr.data()), errorAccum);
    return pl::hasErrors(errorAccum);
  };

  constexpr char kValidChar = 'A';
  // One past the largest byte value. Held in a type wider than std::uint8_t
  // so the loop below can terminate.
  constexpr std::size_t kByteValues =
      static_cast<std::size_t>(std::numeric_limits<std::uint8_t>::max()) + 1;

  RegBytes in;
  in.fill(kValidChar);
  ASSERT_FALSE(anyErrors(in));

  for (std::size_t sym = 0; sym != kByteValues; ++sym) {
    const bool isValid = //
        (sym == '+') || //
        (sym == '/') || //
        ('0' <= sym && sym <= '9') || //
        ('A' <= sym && sym <= 'Z') || //
        ('a' <= sym && sym <= 'z');

    for (std::size_t pos = 0; pos != in.size(); ++pos) {
      in[pos] = static_cast<std::uint8_t>(sym);

      ASSERT_EQ(anyErrors(in), !isValid)
          << "symbol 0x" << std::hex << sym << std::dec << " at byte " << pos;

      // Restore the all-valid baseline before probing the next position.
      in[pos] = kValidChar;
    }
  }
}

// Checks platform::decodeToIndex on valid input against the scalar reference.
//
// For each of the 256 start values a register of consecutive indexes
// (reduced modulo 64) is turned into base64 characters with the reference
// table, and those characters are then decoded by both implementations.
TYPED_TEST(Base64PlatformTest, decodeToIndexSuccess) {
  using RegBytes = typename TestFixture::RegBytesArray;

  for (unsigned first = 0; first != 256; ++first) {
    RegBytes in;
    std::iota(in.begin(), in.end(), static_cast<std::uint8_t>(first));

    // Clamp every value into the 0..63 range of real alphabet indexes.
    for (std::uint8_t& index : in) {
      index %= 64;
    }

    const RegBytes encoded = expectedLookupByIndex(in, kBase64EncodeTable);
    const RegBytes expected = expectedSuccessfullDecodeToIndex(encoded);
    const RegBytes actual = TestFixture::actualSuccesfullDecodeToIndex(encoded);

    ASSERT_EQ(expected, actual) << first;
  }
}

// Checks platform::packIndexesToBytes against the scalar reference.
//
// Only the first three quarters of the register receive consecutive values
// (reduced modulo 64); the rest stays zero. The reference packing is first
// verified to round-trip through expectedEncodeToIndexes before the platform
// result is compared with it.
TYPED_TEST(Base64PlatformTest, packIndexesToBytes) {
  using RegBytes = typename TestFixture::RegBytesArray;

  for (unsigned first = 0; first != 256; ++first) {
    RegBytes in;
    in.fill(0);

    const auto filledEnd = in.begin() + in.size() / 4 * 3;
    std::iota(in.begin(), filledEnd, static_cast<std::uint8_t>(first));

    // Keep every value a legal 6-bit index (zeros are unaffected).
    for (std::uint8_t& index : in) {
      index %= 64;
    }

    const RegBytes expected = expectedPackIndexesToBytes(in);
    ASSERT_EQ(in, expectedEncodeToIndexes(expected)) << "sanity check";

    const RegBytes actual = TestFixture::actualPackIndexesToBytes(in);
    ASSERT_EQ(expected, actual);
  }
}
#endif // FOLLY_SSE_PREREQ(4, 2)

} // namespace
} // namespace folly::detail::base64_detail