chromium/third_party/shell-encryption/src/relinearization_key_test.cc

/*
 * Copyright 2018 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "relinearization_key.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "constants.h"
#include "montgomery.h"
#include "ntt_parameters.h"
#include "polynomial.h"
#include "prng/integral_prng_types.h"
#include "status_macros.h"
#include "symmetric_encryption.h"
#include "testing/status_matchers.h"
#include "testing/status_testing.h"
#include "testing/testing_prng.h"

namespace {

// Seed state consumed by rand_r() in SamplePlaintext(). It is a single global
// that starts at 1 and is advanced by every sample, so the plaintexts a test
// sees depend on how many samples were drawn earlier in the same process.
unsigned int seed = 1;

// Type aliases: all tests instantiate the rlwe templates with a Montgomery
// integer type parameterized by absl::uint128.
using uint_m = rlwe::MontgomeryInt<absl::uint128>;
using Polynomial = rlwe::Polynomial<uint_m>;
using Ciphertext = rlwe::SymmetricRlweCiphertext<uint_m>;
using Key = rlwe::SymmetricRlweKey<uint_m>;
using RelinearizationKey = rlwe::RelinearizationKey<uint_m>;
using ErrorParams = rlwe::ErrorParams<uint_m>;

// Test constants.
// kPlaintextModulus evaluates to (1 << 1) + 1 = 3.
const ssize_t kLogPlaintextModulus = 1;
const ssize_t kPlaintextModulus = (1 << kLogPlaintextModulus) + 1;
// Variance handed to Key::Sample() and ErrorParams::Create().
const ssize_t kDefaultVariance = 4;
// kCoeffs is the number of coefficients in a sampled plaintext; kLogCoeffs is
// passed to InitializeNttParameters() and Key::Sample() (presumably
// log2(kCoeffs) -- defined in constants.h, not visible here).
const ssize_t kCoeffs = rlwe::kNewhopeDegreeBound;
const ssize_t kLogCoeffs = rlwe::kNewhopeLogDegreeBound;
// Values passed as the log decomposition modulus to
// RelinearizationKey::Create().
const ssize_t kSmallLogDecompositionModulus = 2;
const ssize_t kLargeLogDecompositionModulus = 20;

// Matchers used to check the code and message of returned error statuses.
using ::rlwe::testing::StatusIs;
using ::testing::HasSubstr;

// Test fixture. Owns modulus, NTT and error parameters for two moduli
// (rlwe::kNewhopeModulus -> params14_/ntt_params_/error_params_, and
// rlwe::kModulus80 -> params80_/ntt_params80_/error_params80_) and provides
// helpers for sampling keys/plaintexts and for encrypting.
class RelinearizationKeyTest : public ::testing::Test {
 protected:
  // Builds all parameter objects. Any failing Create()/Initialize call aborts
  // the test through ASSERT_OK_AND_ASSIGN.
  void SetUp() override {
    ASSERT_OK_AND_ASSIGN(params14_,
                         uint_m::Params::Create(rlwe::kNewhopeModulus));
    ASSERT_OK_AND_ASSIGN(params80_, uint_m::Params::Create(rlwe::kModulus80));
    ASSERT_OK_AND_ASSIGN(auto ntt_params, rlwe::InitializeNttParameters<uint_m>(
                                              kLogCoeffs, params14_.get()));
    ASSERT_OK_AND_ASSIGN(
        auto ntt_params80,
        rlwe::InitializeNttParameters<uint_m>(kLogCoeffs, params80_.get()));
    ntt_params_ = absl::make_unique<const rlwe::NttParameters<uint_m>>(
        std::move(ntt_params));
    ntt_params80_ = absl::make_unique<const rlwe::NttParameters<uint_m>>(
        std::move(ntt_params80));
    ASSERT_OK_AND_ASSIGN(auto error_params,
                         rlwe::ErrorParams<uint_m>::Create(
                             kLogPlaintextModulus, kDefaultVariance,
                             params14_.get(), ntt_params_.get()));
    error_params_ = absl::make_unique<const ErrorParams>(error_params);
    ASSERT_OK_AND_ASSIGN(auto error_params80,
                         rlwe::ErrorParams<uint_m>::Create(
                             kLogPlaintextModulus, kDefaultVariance,
                             params80_.get(), ntt_params80_.get()));
    error_params80_ = absl::make_unique<const ErrorParams>(error_params80);
  }

  // Converts a vector of raw integers into Montgomery integers under
  // `params`. Returns the first ImportInt() error encountered, if any.
  rlwe::StatusOr<std::vector<uint_m>> ConvertToMontgomery(
      const std::vector<uint_m::Int>& coeffs, const uint_m::Params* params) {
    std::vector<uint_m> output(coeffs.size(), uint_m::ImportZero(params));
    // size_t index: matches the type of size() and cannot wrap before the
    // bound is reached.
    for (size_t i = 0; i < output.size(); ++i) {
      RLWE_ASSIGN_OR_RETURN(output[i], uint_m::ImportInt(coeffs[i], params));
    }
    return output;
  }

  // Samples a random key under params14_/ntt_params_ using a freshly seeded
  // PRNG. `variance` and `log_t` are forwarded to Key::Sample().
  rlwe::StatusOr<Key> SampleKey(rlwe::Uint64 variance = kDefaultVariance,
                                rlwe::Uint64 log_t = kLogPlaintextModulus) {
    RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
                          rlwe::SingleThreadPrng::GenerateSeed());
    RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
    return Key::Sample(kLogCoeffs, variance, log_t, params14_.get(),
                       ntt_params_.get(), prng.get());
  }

  // Samples a plaintext of `coeffs` coefficients, each drawn as
  // rand_r(&seed) % t. The global `seed` is advanced once per coefficient.
  // (Previously `coeffs` was ignored and kCoeffs was always used; the default
  // argument keeps existing callers unchanged.)
  std::vector<uint_m::Int> SamplePlaintext(uint_m::Int t = kPlaintextModulus,
                                           rlwe::Uint64 coeffs = kCoeffs) {
    std::vector<uint_m::Int> plaintext(coeffs);
    for (auto& coefficient : plaintext) {
      coefficient = rand_r(&seed) % t;
    }
    return plaintext;
  }

  // Encrypts `plaintext` under `key`: the coefficients are imported into
  // Montgomery form, converted to NTT form, and passed to rlwe::Encrypt()
  // together with a freshly seeded PRNG.
  rlwe::StatusOr<Ciphertext> Encrypt(
      const Key& key, const std::vector<uint_m::Int>& plaintext,
      const uint_m::Params* params,
      const rlwe::NttParameters<uint_m>* ntt_params,
      const ErrorParams* error_params) {
    RLWE_ASSIGN_OR_RETURN(auto m_plaintext,
                          ConvertToMontgomery(plaintext, params));
    auto plaintext_ntt =
        Polynomial::ConvertToNtt(m_plaintext, ntt_params, params);
    RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
                          rlwe::SingleThreadPrng::GenerateSeed());
    RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
    return rlwe::Encrypt<uint_m>(key, plaintext_ntt, error_params, prng.get());
  }

  // Parameters for rlwe::kNewhopeModulus.
  std::unique_ptr<const uint_m::Params> params14_;
  // Parameters for rlwe::kModulus80.
  std::unique_ptr<const uint_m::Params> params80_;
  std::unique_ptr<const rlwe::NttParameters<uint_m>> ntt_params_;
  std::unique_ptr<const rlwe::NttParameters<uint_m>> ntt_params80_;
  std::unique_ptr<const ErrorParams> error_params_;
  std::unique_ptr<const ErrorParams> error_params80_;
};

// The product of two freshly encrypted ciphertexts is expected to have three
// components, and two components after the relinearization key is applied.
TEST_F(RelinearizationKeyTest, RelinearizationKeyReducesSizeOfCiphertext) {
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto relinearization_key,
      rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
                                               kSmallLogDecompositionModulus));

  // Sample the plaintexts. Only ciphertext sizes are checked below, so no
  // NTT-form copies of the plaintexts are built here.
  std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);

  // Encrypt, multiply and apply the relinearization key.
  ASSERT_OK_AND_ASSIGN(auto ciphertext1,
                       Encrypt(key, plaintext1, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto ciphertext2,
                       Encrypt(key, plaintext2, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto product, ciphertext1 * ciphertext2);
  ASSERT_OK_AND_ASSIGN(auto relinearized_product,
                       relinearization_key.ApplyTo(product));

  EXPECT_EQ(product.Len(), 3);
  EXPECT_EQ(relinearized_product.Len(), 2);
}

// Under the kModulus80 parameters, a relinearized product of two ciphertexts
// (3-part key, small decomposition modulus) must decrypt to the product of
// the two plaintext polynomials.
TEST_F(RelinearizationKeyTest, RelinearizeKey3PartsDecrypts) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();
  const ErrorParams* const errors = error_params80_.get();

  // Secret key, sampled from its own freshly seeded PRNG.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));

  // Relinearization key derived from the secret key.
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/3,
                                 kSmallLogDecompositionModulus));

  // First message, also kept in NTT form to build the expected result.
  std::vector<uint_m::Int> message_a = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_a,
                       ConvertToMontgomery(message_a, modulus_params));
  Polynomial expected_ntt =
      Polynomial::ConvertToNtt(montgomery_a, ntt, modulus_params);

  // Second message.
  std::vector<uint_m::Int> message_b = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_b,
                       ConvertToMontgomery(message_b, modulus_params));
  Polynomial factor_b_ntt =
      Polynomial::ConvertToNtt(montgomery_b, ntt, modulus_params);

  // Homomorphic side: encrypt both, multiply, relinearize, decrypt.
  ASSERT_OK_AND_ASSIGN(auto encrypted_a,
                       Encrypt(key, message_a, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_b,
                       Encrypt(key, message_b, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_product, encrypted_a * encrypted_b);
  ASSERT_OK_AND_ASSIGN(auto relinearized,
                       relin_key.ApplyTo(encrypted_product));
  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> actual,
                       rlwe::Decrypt<uint_m>(key, relinearized));

  // Plaintext side: multiply in NTT form, invert, and reduce.
  ASSERT_OK(expected_ntt.MulInPlace(factor_b_ntt, modulus_params));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt, modulus_params), modulus_params->modulus,
      kPlaintextModulus, modulus_params);

  EXPECT_EQ(actual, expected);
}

// Multiplies three ciphertexts (two multiplications, no intermediate
// relinearization) and relinearizes the result once with a 4-part key; the
// decryption must equal the product of the three plaintext polynomials.
TEST_F(RelinearizationKeyTest, RelinearizeKey4PartsDecrypts) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();
  const ErrorParams* const errors = error_params80_.get();

  // Secret key, sampled from its own freshly seeded PRNG.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));

  // 4-part relinearization key with the large decomposition modulus.
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/4,
                                 kLargeLogDecompositionModulus));

  // Three messages, each also converted to NTT form for the expected value.
  std::vector<uint_m::Int> message_a = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_a,
                       ConvertToMontgomery(message_a, modulus_params));
  Polynomial expected_ntt =
      Polynomial::ConvertToNtt(montgomery_a, ntt, modulus_params);

  std::vector<uint_m::Int> message_b = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_b,
                       ConvertToMontgomery(message_b, modulus_params));
  Polynomial factor_b_ntt =
      Polynomial::ConvertToNtt(montgomery_b, ntt, modulus_params);

  std::vector<uint_m::Int> message_c = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_c,
                       ConvertToMontgomery(message_c, modulus_params));
  Polynomial factor_c_ntt =
      Polynomial::ConvertToNtt(montgomery_c, ntt, modulus_params);

  // Homomorphic side: (a * b) * c, then a single relinearization.
  ASSERT_OK_AND_ASSIGN(auto encrypted_a,
                       Encrypt(key, message_a, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_b,
                       Encrypt(key, message_b, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_c,
                       Encrypt(key, message_c, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_ab, encrypted_a * encrypted_b);
  ASSERT_OK_AND_ASSIGN(auto encrypted_abc, encrypted_ab * encrypted_c);
  ASSERT_OK_AND_ASSIGN(auto relinearized, relin_key.ApplyTo(encrypted_abc));

  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> actual,
                       rlwe::Decrypt<uint_m>(key, relinearized));

  // Plaintext side: a * b * c in NTT form, inverted and reduced.
  ASSERT_OK(expected_ntt.MulInPlace(factor_b_ntt, modulus_params));
  ASSERT_OK(expected_ntt.MulInPlace(factor_c_ntt, modulus_params));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt, modulus_params), modulus_params->modulus,
      kPlaintextModulus, modulus_params);

  EXPECT_EQ(actual, expected);
}

// Same flow as the 3-part decryption test, but the relinearization key is
// built with kLargeLogDecompositionModulus.
TEST_F(RelinearizationKeyTest, RelinearizeKeyLargeModulusDecrypts) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();
  const ErrorParams* const errors = error_params80_.get();

  // Secret key, sampled from its own freshly seeded PRNG.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));

  // 3-part relinearization key with the large decomposition modulus.
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/3,
                                 kLargeLogDecompositionModulus));

  // Two messages, each also converted to NTT form for the expected value.
  std::vector<uint_m::Int> message_a = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_a,
                       ConvertToMontgomery(message_a, modulus_params));
  Polynomial expected_ntt =
      Polynomial::ConvertToNtt(montgomery_a, ntt, modulus_params);

  std::vector<uint_m::Int> message_b = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_b,
                       ConvertToMontgomery(message_b, modulus_params));
  Polynomial factor_b_ntt =
      Polynomial::ConvertToNtt(montgomery_b, ntt, modulus_params);

  // Homomorphic side: encrypt, multiply once, relinearize once, decrypt.
  ASSERT_OK_AND_ASSIGN(auto encrypted_a,
                       Encrypt(key, message_a, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_b,
                       Encrypt(key, message_b, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_product, encrypted_a * encrypted_b);
  ASSERT_OK_AND_ASSIGN(auto relinearized,
                       relin_key.ApplyTo(encrypted_product));

  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> actual,
                       rlwe::Decrypt<uint_m>(key, relinearized));

  // Plaintext side: multiply in NTT form, invert, and reduce.
  ASSERT_OK(expected_ntt.MulInPlace(factor_b_ntt, modulus_params));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt, modulus_params), modulus_params->modulus,
      kPlaintextModulus, modulus_params);

  EXPECT_EQ(actual, expected);
}

// Relinearizes after each of two successive multiplications with the same
// 3-part key; the final decryption must equal the product of the three
// plaintext polynomials.
TEST_F(RelinearizationKeyTest, RepeatedRelinearizationDecrypts) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();
  const ErrorParams* const errors = error_params80_.get();

  // Secret key, sampled from its own freshly seeded PRNG.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));

  // 3-part relinearization key with the large decomposition modulus.
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/3,
                                 kLargeLogDecompositionModulus));

  // Three messages, each also converted to NTT form for the expected value.
  std::vector<uint_m::Int> message_a = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_a,
                       ConvertToMontgomery(message_a, modulus_params));
  Polynomial expected_ntt =
      Polynomial::ConvertToNtt(montgomery_a, ntt, modulus_params);

  std::vector<uint_m::Int> message_b = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_b,
                       ConvertToMontgomery(message_b, modulus_params));
  Polynomial factor_b_ntt =
      Polynomial::ConvertToNtt(montgomery_b, ntt, modulus_params);

  std::vector<uint_m::Int> message_c = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_c,
                       ConvertToMontgomery(message_c, modulus_params));
  Polynomial factor_c_ntt =
      Polynomial::ConvertToNtt(montgomery_c, ntt, modulus_params);

  // Homomorphic side: relin(relin(a * b) * c).
  ASSERT_OK_AND_ASSIGN(auto encrypted_a,
                       Encrypt(key, message_a, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_b,
                       Encrypt(key, message_b, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_c,
                       Encrypt(key, message_c, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_ab, encrypted_a * encrypted_b);
  ASSERT_OK_AND_ASSIGN(auto relinearized_ab, relin_key.ApplyTo(encrypted_ab));
  ASSERT_OK_AND_ASSIGN(auto encrypted_abc, relinearized_ab * encrypted_c);
  ASSERT_OK_AND_ASSIGN(auto relinearized_abc,
                       relin_key.ApplyTo(encrypted_abc));

  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> actual,
                       rlwe::Decrypt<uint_m>(key, relinearized_abc));

  // Plaintext side: a * b * c in NTT form, inverted and reduced.
  ASSERT_OK(expected_ntt.MulInPlace(factor_b_ntt, modulus_params));
  ASSERT_OK(expected_ntt.MulInPlace(factor_c_ntt, modulus_params));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt, modulus_params), modulus_params->modulus,
      kPlaintextModulus, modulus_params);

  EXPECT_EQ(actual, expected);
}

// Applying a relinearization key created with num_parts=2 to the product of
// two ciphertexts must fail with kInvalidArgument.
TEST_F(RelinearizationKeyTest, CiphertextWithTooManyComponents) {
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  // RelinearizationKey has length 2.
  ASSERT_OK_AND_ASSIGN(
      auto relinearization_key,
      rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/2,
                                               kSmallLogDecompositionModulus));

  // Sample the plaintexts. Nothing is decrypted in this test, so no NTT-form
  // copies of the plaintexts are built here.
  std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);

  // Encrypt, multiply, apply the relinearization key.
  ASSERT_OK_AND_ASSIGN(auto ciphertext1,
                       Encrypt(key, plaintext1, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto ciphertext2,
                       Encrypt(key, plaintext2, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto product, ciphertext1 * ciphertext2);
  EXPECT_THAT(relinearization_key.ApplyTo(product),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("RelinearizationKey not large enough")));
}

// Create() must reject a log decomposition modulus that exceeds the key's
// log modulus, and one that is zero.
TEST_F(RelinearizationKeyTest, LogDecompositionModulusOutOfBounds) {
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  // Upper bound: one past the key's log modulus.
  const auto max_log_modulus = key.ModulusParams()->log_modulus;
  const auto too_large_log_modulus = max_log_modulus + 1;
  EXPECT_THAT(
      RelinearizationKey::Create(
          key, prng_seed, /*num_parts=*/2,
          /*log_decomposition_modulus=*/too_large_log_modulus),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr(absl::StrCat("Log decomposition modulus, ",
                                      too_large_log_modulus, ", ",
                                      "must be at most: ", max_log_modulus))));

  // Lower bound: zero is not positive.
  const int zero_log_modulus = 0;
  EXPECT_THAT(
      RelinearizationKey::Create(key, prng_seed, /*num_parts=*/3,
                                 zero_log_modulus),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr(absl::StrCat("Log decomposition modulus, ",
                                      zero_log_modulus,
                                      ", must be positive."))));
}

// Create() must reject a negative number of parts with kInvalidArgument.
TEST_F(RelinearizationKeyTest, NumPartsMustBePositive) {
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  const auto expected_error =
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("Num parts: -1 must be positive."));
  EXPECT_THAT(RelinearizationKey::Create(key, relin_seed, /*num_parts=*/-1,
                                         kSmallLogDecompositionModulus),
              expected_error);
}

// Deserialize() must reject serialized keys whose num_parts field is too
// small or inconsistent with the number of serialized polynomials.
TEST_F(RelinearizationKeyTest, InvalidDeserialize) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();

  // Build a valid 3-part key and serialize it.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/3,
                                 kLargeLogDecompositionModulus));
  ASSERT_OK_AND_ASSIGN(rlwe::SerializedRelinearizationKey serialized,
                       relin_key.Serialize());

  // Case 1: num_parts of -1, 0 and 1 are each rejected.
  for (const int invalid_num_parts : {-1, 0, 1}) {
    serialized.set_num_parts(invalid_num_parts);
    const std::string expected_message =
        absl::StrCat("The number of parts, ", serialized.num_parts(),
                     ", must be greater than one.");
    EXPECT_THAT(
        RelinearizationKey::Deserialize(serialized, modulus_params, ntt),
        StatusIs(absl::StatusCode::kInvalidArgument,
                 HasSubstr(expected_message)));
  }

  // Case 2: (num_parts - 1) does not divide the number of polynomials.
  ASSERT_GT(serialized.c_size(), 2);
  serialized.set_num_parts(serialized.c_size() - 1);
  const std::string not_divisible_message = absl::StrCat(
      "The length of serialized, ", serialized.c_size(), ", ",
      "must be divisible by the number of parts minus one ",
      serialized.num_parts() - 1, ".");
  EXPECT_THAT(RelinearizationKey::Deserialize(serialized, modulus_params, ntt),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr(not_divisible_message)));

  // Case 3: (num_parts - 1) equals the number of polynomials, so the
  // divisibility check passes but the matrix-entry count check must fail.
  // NOTE(review): with num_parts=3 the expected size 8 is presumably
  // (num_parts - 1) * 4 rather than the plain quotient -- confirm.
  ASSERT_EQ(serialized.c_size(),
            /* log2(kModulus80) / kLargeLogDecompositionModulus = */ 8);
  serialized.set_num_parts(serialized.c_size() + 1);
  const std::string mismatch_message =
      absl::StrCat("Number of NTT Polynomials does not match expected ",
                   "number of matrix entries.");
  EXPECT_THAT(RelinearizationKey::Deserialize(serialized, modulus_params, ntt),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr(mismatch_message)));
}

// A relinearization key that went through Serialize()/Deserialize() must
// still relinearize a product so that it decrypts to the plaintext product.
TEST_F(RelinearizationKeyTest, SerializeKey) {
  const uint_m::Params* const modulus_params = params80_.get();
  const rlwe::NttParameters<uint_m>* const ntt = ntt_params80_.get();
  const ErrorParams* const errors = error_params80_.get();

  // Secret key, sampled from its own freshly seeded PRNG.
  ASSERT_OK_AND_ASSIGN(std::string secret_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(auto secret_prng,
                       rlwe::SingleThreadPrng::Create(secret_seed));
  ASSERT_OK_AND_ASSIGN(auto key,
                       Key::Sample(kLogCoeffs, kDefaultVariance,
                                   kLogPlaintextModulus, modulus_params, ntt,
                                   secret_prng.get()));

  // Original relinearization key.
  ASSERT_OK_AND_ASSIGN(std::string relin_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto original_relin_key,
      RelinearizationKey::Create(key, relin_seed, /*num_parts=*/3,
                                 kLargeLogDecompositionModulus));

  // Round-trip it through its serialized form.
  ASSERT_OK_AND_ASSIGN(rlwe::SerializedRelinearizationKey wire_form,
                       original_relin_key.Serialize());
  ASSERT_OK_AND_ASSIGN(
      auto restored_relin_key,
      RelinearizationKey::Deserialize(wire_form, modulus_params, ntt));

  // Two messages, each also converted to NTT form for the expected value.
  std::vector<uint_m::Int> message_a = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_a,
                       ConvertToMontgomery(message_a, modulus_params));
  Polynomial expected_ntt =
      Polynomial::ConvertToNtt(montgomery_a, ntt, modulus_params);

  std::vector<uint_m::Int> message_b = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto montgomery_b,
                       ConvertToMontgomery(message_b, modulus_params));
  Polynomial factor_b_ntt =
      Polynomial::ConvertToNtt(montgomery_b, ntt, modulus_params);

  // Homomorphic side: relinearize with the restored key, then decrypt.
  ASSERT_OK_AND_ASSIGN(auto encrypted_a,
                       Encrypt(key, message_a, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_b,
                       Encrypt(key, message_b, modulus_params, ntt, errors));
  ASSERT_OK_AND_ASSIGN(auto encrypted_product, encrypted_a * encrypted_b);
  ASSERT_OK_AND_ASSIGN(auto relinearized,
                       restored_relin_key.ApplyTo(encrypted_product));
  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> actual,
                       rlwe::Decrypt<uint_m>(key, relinearized));

  // Plaintext side: multiply in NTT form, invert, and reduce.
  ASSERT_OK(expected_ntt.MulInPlace(factor_b_ntt, modulus_params));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt, modulus_params), modulus_params->modulus,
      kPlaintextModulus, modulus_params);

  EXPECT_EQ(actual, expected);
}

// The error estimate reported by a relinearized product must be strictly
// larger than that of the product before relinearization.
TEST_F(RelinearizationKeyTest, RelinearizationKeyIncreasesError) {
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto relinearization_key,
      rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
                                               kSmallLogDecompositionModulus));

  // Sample the plaintexts. Only the Error() values are compared below, so no
  // NTT-form copies of the plaintexts are built here.
  std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);

  // Encrypt, multiply and apply the relinearization key.
  ASSERT_OK_AND_ASSIGN(auto ciphertext1,
                       Encrypt(key, plaintext1, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto ciphertext2,
                       Encrypt(key, plaintext2, params14_.get(),
                               ntt_params_.get(), error_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto product, ciphertext1 * ciphertext2);
  ASSERT_OK_AND_ASSIGN(auto relinearized_product,
                       relinearization_key.ApplyTo(product));

  // Expect that the error grows after relinearization.
  EXPECT_GT(relinearized_product.Error(), product.Error());
}

}  // namespace