chromium/third_party/shell-encryption/src/symmetric_encryption_with_prng_test.cc

/*
 * Copyright 2018 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "symmetric_encryption_with_prng.h"

#include <cstddef>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "context.h"
#include "montgomery.h"
#include "ntt_parameters.h"
#include "polynomial.h"
#include "prng/integral_prng_types.h"
#include "status_macros.h"
#include "testing/parameters.h"
#include "testing/status_matchers.h"
#include "testing/status_testing.h"
#include "testing/testing_utils.h"

namespace rlwe {
namespace {

// Number of times each randomized test body is repeated per parameter set.
constexpr unsigned int kTestingRounds{10};

// Typed test fixture exercising EncryptWithPrng / ExpandFromPrng for a given
// modular integer type `ModularInt`.
template <typename ModularInt>
class SymmetricEncryptionWithPrngTest : public ::testing::Test {
 public:
  using Key = SymmetricRlweKey<ModularInt>;

  // Samples a random key from the parameters exposed by `context`, using a
  // freshly seeded SingleThreadPrng. Any error from seed generation, PRNG
  // creation or Key::Sample is propagated to the caller.
  rlwe::StatusOr<Key> SampleKey(const rlwe::RlweContext<ModularInt>* context) {
    RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
                          rlwe::SingleThreadPrng::GenerateSeed());
    RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
    return Key::Sample(context->GetLogN(), context->GetVariance(),
                       context->GetLogT(), context->GetModulusParams(),
                       context->GetNttParams(), prng.get());
  }

  // Converts each coefficient vector in `coeffs` to Montgomery form and then
  // to an NTT-domain polynomial, using the modulus and NTT parameters of
  // `context`. The output has one polynomial per input vector, in the same
  // order. Returns the first conversion error encountered, if any.
  rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> ConvertPlaintextsToNtt(
      const std::vector<std::vector<typename ModularInt::Int>>& coeffs,
      const rlwe::RlweContext<ModularInt>* context) {
    std::vector<Polynomial<ModularInt>> ntt_plaintexts;
    ntt_plaintexts.reserve(coeffs.size());
    for (const auto& plaintext_coeffs : coeffs) {
      RLWE_ASSIGN_OR_RETURN(auto mont,
                            rlwe::testing::ConvertToMontgomery<ModularInt>(
                                plaintext_coeffs, context->GetModulusParams()));
      ntt_plaintexts.push_back(Polynomial<ModularInt>::ConvertToNtt(
          mont, context->GetNttParams(), context->GetModulusParams()));
    }
    return ntt_plaintexts;
  }

  // Encrypts `plaintexts` into compressed ciphertexts with EncryptWithPrng,
  // expands them with ExpandFromPrng, and checks that every expanded
  // ciphertext reports the fresh-encryption error bound and decrypts back to
  // the corresponding plaintext. Failures are reported through gtest
  // assertions; fatal ones abort this helper early.
  void TestCompressedEncryptionDecryption(
      const std::vector<std::vector<typename ModularInt::Int>>& plaintexts,
      const rlwe::RlweContext<ModularInt>* context) {
    ASSERT_OK_AND_ASSIGN(auto key, SampleKey(context));
    ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                         SingleThreadPrng::GenerateSeed());
    ASSERT_OK_AND_ASSIGN(auto prng, SingleThreadPrng::Create(prng_seed));
    ASSERT_OK_AND_ASSIGN(std::string prng_encryption_seed,
                         SingleThreadPrng::GenerateSeed());
    ASSERT_OK_AND_ASSIGN(auto prng_encryption,
                         SingleThreadPrng::Create(prng_encryption_seed));
    ASSERT_OK_AND_ASSIGN(std::vector<Polynomial<ModularInt>> ntt_plaintexts,
                         ConvertPlaintextsToNtt(plaintexts, context));
    ASSERT_OK_AND_ASSIGN(
        auto compressed_ciphertexts,
        EncryptWithPrng<ModularInt>(key, ntt_plaintexts, prng.get(),
                                    prng_encryption.get()));
    EXPECT_EQ(plaintexts.size(), compressed_ciphertexts.size());

    // Expansion is handed a new PRNG built from the same seed as `prng`,
    // rather than the instance already consumed during encryption.
    ASSERT_OK_AND_ASSIGN(auto another_prng,
                         SingleThreadPrng::Create(prng_seed));
    ASSERT_OK_AND_ASSIGN(auto ciphertexts,
                         ExpandFromPrng<ModularInt>(compressed_ciphertexts,
                                                    context->GetModulusParams(),
                                                    context->GetNttParams(),
                                                    context->GetErrorParams(),
                                                    another_prng.get()));

    // Fatal check: the loop below indexes `plaintexts` with ciphertext
    // indices, so a size mismatch must stop the test instead of reading out of
    // bounds.
    ASSERT_EQ(plaintexts.size(), ciphertexts.size());
    for (size_t i = 0; i < ciphertexts.size(); ++i) {
      // Expect that the error of an expanded ciphertext is of a fresh
      // encryption.
      EXPECT_EQ(ciphertexts[i].Error(),
                context->GetErrorParams()->B_encryption());
      ASSERT_OK_AND_ASSIGN(auto decrypted,
                           Decrypt<ModularInt>(key, ciphertexts[i]));
      EXPECT_EQ(plaintexts[i], decrypted);
    }
  }
};
TYPED_TEST_SUITE(SymmetricEncryptionWithPrngTest,
                 rlwe::testing::ModularIntTypes);

// Round-trips a batch containing exactly one randomly sampled plaintext
// through compressed encryption, expansion and decryption, for every context
// parameter set and kTestingRounds times each.
TYPED_TEST(SymmetricEncryptionWithPrngTest, EncryptDecryptSingleCompressed) {
  using PlaintextBatch = std::vector<std::vector<typename TypeParam::Int>>;

  for (const auto& context_params :
       rlwe::testing::ContextParameters<TypeParam>::Value()) {
    ASSERT_OK_AND_ASSIGN(auto context,
                         rlwe::RlweContext<TypeParam>::Create(context_params));
    unsigned int round = 0;
    while (round < kTestingRounds) {
      const PlaintextBatch single_plaintext = {
          rlwe::testing::SamplePlaintext<TypeParam>(context->GetN(),
                                                    context->GetT())};
      this->TestCompressedEncryptionDecryption(single_plaintext,
                                               context.get());
      ++round;
    }
  }
}

// Ensure that the encryption scheme can encrypt and decrypt multiple compressed
// ciphertexts. Round i (0-based) uses a batch of i + 2 randomly sampled
// plaintexts, so batch sizes range from 2 to kTestingRounds + 1.
TYPED_TEST(SymmetricEncryptionWithPrngTest, EncryptDecryptMultipleCompressed) {
  for (const auto& params :
       rlwe::testing::ContextParameters<TypeParam>::Value()) {
    ASSERT_OK_AND_ASSIGN(auto context,
                         rlwe::RlweContext<TypeParam>::Create(params));
    for (unsigned int i = 0; i < kTestingRounds; ++i) {
      const unsigned int num_plaintexts = i + 2;
      std::vector<std::vector<typename TypeParam::Int>> plaintexts;
      plaintexts.reserve(num_plaintexts);
      // The index is unsigned to match `num_plaintexts` and avoid a
      // signed/unsigned comparison.
      for (unsigned int j = 0; j < num_plaintexts; ++j) {
        plaintexts.push_back(rlwe::testing::SamplePlaintext<TypeParam>(
            context->GetN(), context->GetT()));
      }
      this->TestCompressedEncryptionDecryption(plaintexts, context.get());
    }
  }
}

}  // namespace
}  // namespace rlwe