chromium/third_party/shell-encryption/src/galois_key_test.cc

/*
 * Copyright 2018 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "galois_key.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <google/protobuf/util/message_differencer.h>
#include "constants.h"
#include "montgomery.h"
#include "ntt_parameters.h"
#include "polynomial.h"
#include "prng/integral_prng_types.h"
#include "status_macros.h"
#include "symmetric_encryption.h"
#include "testing/protobuf_matchers.h"
#include "testing/status_matchers.h"
#include "testing/status_testing.h"
#include "testing/testing_prng.h"
#include "testing/testing_utils.h"

namespace {

using Uint64 = rlwe::Uint64;

unsigned int seed = 0;

// Set constants.
const Uint64 kLogPlaintextModulus = 1;
const Uint64 kPlaintextModulus = (1 << kLogPlaintextModulus) + 1;
const Uint64 kLogDecompositionModulus = 2;
const Uint64 kLargeLogDecompositionModulus = 31;

// Useful typedefs.
using uint_m = rlwe::MontgomeryInt<Uint64>;
using Polynomial = rlwe::Polynomial<uint_m>;
using Ciphertext = rlwe::SymmetricRlweCiphertext<uint_m>;
using Key = rlwe::SymmetricRlweKey<uint_m>;

using ::rlwe::testing::EqualsProto;
using ::rlwe::testing::StatusIs;
using ::testing::HasSubstr;

// Test fixture. Every test gets freshly built 59-bit Montgomery parameters, NTT
// parameters for rlwe::testing::kLogCoeffs, and error parameters, plus helpers
// to sample keys and plaintexts and to encrypt.
class GaloisKeyTest : public ::testing::Test {
 protected:
  // Builds params59_, ntt_params_ and error_params_. Any failure aborts the
  // test through ASSERT_OK_AND_ASSIGN.
  void SetUp() override {
    ASSERT_OK_AND_ASSIGN(params59_, uint_m::Params::Create(rlwe::kModulus59));
    ASSERT_OK_AND_ASSIGN(auto ntt_params,
                         rlwe::InitializeNttParameters<uint_m>(
                             rlwe::testing::kLogCoeffs, params59_.get()));
    ntt_params_ = absl::make_unique<const rlwe::NttParameters<uint_m>>(
        std::move(ntt_params));
    ASSERT_OK_AND_ASSIGN(
        auto error_params,
        rlwe::ErrorParams<uint_m>::Create(rlwe::testing::kDefaultLogT,
                                          rlwe::testing::kDefaultVariance,
                                          params59_.get(), ntt_params_.get()));
    // The local is not used after this point, so move it instead of copying.
    error_params_ = absl::make_unique<const rlwe::ErrorParams<uint_m>>(
        std::move(error_params));
  }

  // Samples a random symmetric key with the given noise variance and log_t,
  // using a PRNG created from a newly generated seed.
  rlwe::StatusOr<Key> SampleKey(
      Uint64 variance = rlwe::testing::kDefaultVariance,
      Uint64 log_t = kLogPlaintextModulus) {
    RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
                          rlwe::SingleThreadPrng::GenerateSeed());
    RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
    return Key::Sample(rlwe::testing::kLogCoeffs, variance, log_t,
                       params59_.get(), ntt_params_.get(), prng.get());
  }

  // Imports every integer of `coeffs` into Montgomery form under `params`.
  // Returns the first import error encountered, if any.
  rlwe::StatusOr<std::vector<uint_m>> ConvertToMontgomery(
      const std::vector<uint_m::Int>& coeffs, const uint_m::Params* params) {
    std::vector<uint_m> output(coeffs.size(), uint_m::ImportZero(params));
    // size_t index matches vector::size() and cannot be narrowed.
    for (size_t i = 0; i < output.size(); ++i) {
      RLWE_ASSIGN_OR_RETURN(output[i], uint_m::ImportInt(coeffs[i], params));
    }
    return output;
  }

  // Returns `coeffs` pseudo-random coefficients, each in [0, t), drawn with
  // rand_r from the file-level `seed`.
  std::vector<uint_m::Int> SamplePlaintext(
      uint_m::Int t = kPlaintextModulus,
      Uint64 coeffs = rlwe::testing::kCoeffs) {
    std::vector<uint_m::Int> plaintext(coeffs);
    // Range-for avoids comparing an unsigned int index against the Uint64
    // count (which would never terminate for counts above UINT_MAX); the
    // coefficients are still filled in index order.
    for (auto& coefficient : plaintext) {
      coefficient = rand_r(&seed) % t;
    }
    return plaintext;
  }

  // Encrypts `plaintext` under `key`: converts it to Montgomery form, then to
  // NTT form, and encrypts with error_params_ and a freshly seeded PRNG.
  rlwe::StatusOr<Ciphertext> Encrypt(
      const Key& key, const std::vector<uint_m::Int>& plaintext) {
    RLWE_ASSIGN_OR_RETURN(auto mp,
                          ConvertToMontgomery(plaintext, params59_.get()));
    auto plaintext_ntt =
        Polynomial::ConvertToNtt(mp, ntt_params_.get(), params59_.get());
    RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
                          rlwe::SingleThreadPrng::GenerateSeed());
    RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
    return rlwe::Encrypt<uint_m>(key, plaintext_ntt, error_params_.get(),
                                 prng.get());
  }

  // Parameters shared by all helpers; populated in SetUp().
  std::unique_ptr<const uint_m::Params> params59_;
  std::unique_ptr<const rlwe::NttParameters<uint_m>> ntt_params_;
  std::unique_ptr<const rlwe::ErrorParams<uint_m>> error_params_;
};

// ApplyTo must reject a ciphertext whose PowerOfS differs from the GaloisKey's
// substitution power, reporting both powers in an InvalidArgument status.
TEST_F(GaloisKeyTest, GaloisKeyPowerOfSDoesNotMatchSubPower) {
  int key_power = 3;
  // Substitution power applied to the ciphertext; deliberately not key_power.
  int ciphertext_power = key_power + 2;

  ASSERT_OK_AND_ASSIGN(auto secret_key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string key_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto galois_key,
      rlwe::GaloisKey<uint_m>::Create(secret_key, key_seed, key_power,
                                      kLargeLogDecompositionModulus));
  auto message = SamplePlaintext(kPlaintextModulus);

  ASSERT_OK_AND_ASSIGN(auto encrypted, Encrypt(secret_key, message));
  ASSERT_OK_AND_ASSIGN(
      auto mismatched,
      encrypted.Substitute(ciphertext_power, ntt_params_.get()));

  std::string expected_error = absl::StrCat(
      "Ciphertext PowerOfS: ", mismatched.PowerOfS(),
      " doesn't match the key substitution power: ", key_power);
  EXPECT_THAT(galois_key.ApplyTo(mismatched),
              StatusIs(::absl::StatusCode::kInvalidArgument,
                       HasSubstr(expected_error)));
}

// Substitution sets a ciphertext's PowerOfS to the substitution power; applying
// the matching GaloisKey must set it back to 1.
TEST_F(GaloisKeyTest, GaloisKeyUpdatesPowerOfS) {
  int substitution_power = 3;
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(auto galois_key, rlwe::GaloisKey<uint_m>::Create(
                                            key, prng_seed, substitution_power,
                                            kLargeLogDecompositionModulus));
  auto plaintext = SamplePlaintext(kPlaintextModulus);

  // The substituted ciphertext has PowerOfS equal to substitution_power.
  ASSERT_OK_AND_ASSIGN(auto ciphertext, Encrypt(key, plaintext));
  ASSERT_OK_AND_ASSIGN(
      auto subbed_ciphertext,
      ciphertext.Substitute(substitution_power, ntt_params_.get()));
  EXPECT_EQ(subbed_ciphertext.PowerOfS(), substitution_power);

  // Applying the GaloisKey transforms PowerOfS back to 1.
  ASSERT_OK_AND_ASSIGN(auto transformed_ciphertext,
                       galois_key.ApplyTo(subbed_ciphertext));
  EXPECT_EQ(transformed_ciphertext.PowerOfS(), 1);
}

// Substituting a ciphertext and then applying the matching GaloisKey (built with
// the small kLogDecompositionModulus) must yield a ciphertext that decrypts,
// under the original key, to the substituted plaintext.
TEST_F(GaloisKeyTest, KeySwitchedCiphertextDecrypts) {
  int substitution_power = 3;
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(auto galois_key, rlwe::GaloisKey<uint_m>::Create(
                                            key, prng_seed, substitution_power,
                                            kLogDecompositionModulus));

  // Sample the initial plaintext.
  std::vector<uint_m::Int> plaintext = SamplePlaintext(kPlaintextModulus);

  // Create the expected output: substitute the plaintext polynomial directly
  // (in NTT form), convert back, and pass it through RemoveError with the
  // plaintext modulus.
  ASSERT_OK_AND_ASSIGN(auto mp1,
                       ConvertToMontgomery(plaintext, params59_.get()));
  Polynomial plaintext_ntt =
      Polynomial::ConvertToNtt(mp1, ntt_params_.get(), params59_.get());
  ASSERT_OK_AND_ASSIGN(
      Polynomial expected_ntt,
      plaintext_ntt.Substitute(substitution_power, ntt_params_.get(),
                               params59_.get()));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt_params_.get(), params59_.get()),
      params59_->modulus, kPlaintextModulus, params59_.get());

  // Encrypt and substitute the ciphertext, key-switch it with the GaloisKey,
  // then decrypt with the original (unsubstituted) key.
  ASSERT_OK_AND_ASSIGN(auto intermediate, Encrypt(key, plaintext));
  ASSERT_OK_AND_ASSIGN(
      auto ciphertext,
      intermediate.Substitute(substitution_power, ntt_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto transformed_ciphertext,
                       galois_key.ApplyTo(ciphertext));
  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
                       rlwe::Decrypt<uint_m>(key, transformed_ciphertext));

  EXPECT_EQ(decrypted, expected);
}

// A substitution power with no GaloisKey of its own can be reached by composing
// steps that do have one: two rounds of "substitute by 3, then apply the
// power-3 GaloisKey" must match a direct power-9 substitution of the plaintext.
TEST_F(GaloisKeyTest, ComposingSubstitutions) {
  // Overall power reached after two steps (3 * 3).
  int target_power = 9;
  // The only power for which a GaloisKey is generated.
  int step_power = 3;

  ASSERT_OK_AND_ASSIGN(auto secret_key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string key_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());
  ASSERT_OK_AND_ASSIGN(
      auto step_key,
      rlwe::GaloisKey<uint_m>::Create(secret_key, key_seed, step_power,
                                      kLogDecompositionModulus));
  auto message = SamplePlaintext(kPlaintextModulus);

  // Reference result: substitute the plaintext by target_power directly in NTT
  // form, convert back, and run RemoveError with the plaintext modulus.
  ASSERT_OK_AND_ASSIGN(auto message_montgomery,
                       ConvertToMontgomery(message, params59_.get()));
  Polynomial message_ntt = Polynomial::ConvertToNtt(
      message_montgomery, ntt_params_.get(), params59_.get());
  ASSERT_OK_AND_ASSIGN(Polynomial reference_ntt,
                       message_ntt.Substitute(target_power, ntt_params_.get(),
                                              params59_.get()));
  std::vector<uint_m::Int> reference = rlwe::RemoveError<uint_m>(
      reference_ntt.InverseNtt(ntt_params_.get(), params59_.get()),
      params59_->modulus, kPlaintextModulus, params59_.get());

  // Step 1: substitute the fresh ciphertext, then key-switch it.
  ASSERT_OK_AND_ASSIGN(auto encrypted, Encrypt(secret_key, message));
  ASSERT_OK_AND_ASSIGN(auto step1_substituted,
                       encrypted.Substitute(step_power, ntt_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto step1_switched,
                       step_key.ApplyTo(step1_substituted));

  // Step 2: repeat on the key-switched result of step 1.
  ASSERT_OK_AND_ASSIGN(
      auto step2_substituted,
      step1_switched.Substitute(step_power, ntt_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto step2_switched,
                       step_key.ApplyTo(step2_substituted));

  EXPECT_EQ(step2_switched.PowerOfS(), 1);
  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
                       rlwe::Decrypt<uint_m>(secret_key, step2_switched));
  EXPECT_EQ(decrypted, reference);
}

// Same round trip as KeySwitchedCiphertextDecrypts, but with a GaloisKey built
// using kLargeLogDecompositionModulus: the substituted, key-switched ciphertext
// must still decrypt under the original key to the substituted plaintext.
TEST_F(GaloisKeyTest, LargeDecompositionModulus) {
  int substitution_power = 3;

  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(auto galois_key, rlwe::GaloisKey<uint_m>::Create(
                                            key, prng_seed, substitution_power,
                                            kLargeLogDecompositionModulus));
  auto plaintext = SamplePlaintext(kPlaintextModulus);

  // Create the expected output: substitute the plaintext polynomial directly
  // (in NTT form), convert back, and pass it through RemoveError with the
  // plaintext modulus.
  ASSERT_OK_AND_ASSIGN(auto mp1,
                       ConvertToMontgomery(plaintext, params59_.get()));
  Polynomial plaintext_ntt =
      Polynomial::ConvertToNtt(mp1, ntt_params_.get(), params59_.get());
  ASSERT_OK_AND_ASSIGN(
      Polynomial expected_ntt,
      plaintext_ntt.Substitute(substitution_power, ntt_params_.get(),
                               params59_.get()));
  std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
      expected_ntt.InverseNtt(ntt_params_.get(), params59_.get()),
      params59_->modulus, kPlaintextModulus, params59_.get());

  // Encrypt and substitute the ciphertext, key-switch it with the GaloisKey,
  // then decrypt with the original (unsubstituted) key.
  ASSERT_OK_AND_ASSIGN(auto intermediate, Encrypt(key, plaintext));
  ASSERT_OK_AND_ASSIGN(
      auto ciphertext,
      intermediate.Substitute(substitution_power, ntt_params_.get()));
  ASSERT_OK_AND_ASSIGN(auto transformed_ciphertext,
                       galois_key.ApplyTo(ciphertext));
  ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
                       rlwe::Decrypt<uint_m>(key, transformed_ciphertext));

  EXPECT_EQ(decrypted, expected);
}

// ApplyTo must refuse a ciphertext obtained by multiplying a ciphertext with
// itself (which, per the test name, has too many components for the key).
TEST_F(GaloisKeyTest, CiphertextWithTooManyComponents) {
  int power = 3;
  ASSERT_OK_AND_ASSIGN(auto secret_key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string key_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto galois_key,
      rlwe::GaloisKey<uint_m>::Create(secret_key, key_seed, power,
                                      kLargeLogDecompositionModulus));
  auto message = SamplePlaintext(kPlaintextModulus);

  ASSERT_OK_AND_ASSIGN(auto fresh, Encrypt(secret_key, message));
  ASSERT_OK_AND_ASSIGN(auto substituted,
                       fresh.Substitute(power, ntt_params_.get()));

  // Square the substituted ciphertext and expect the key to reject it.
  ASSERT_OK_AND_ASSIGN(auto squared, substituted * substituted);
  EXPECT_THAT(galois_key.ApplyTo(squared),
              StatusIs(::absl::StatusCode::kInvalidArgument,
                       HasSubstr("RelinearizationKey not large enough")));
}

// A GaloisKey restored via Serialize/Deserialize must key-switch a substituted
// ciphertext to the same decryption as the original key, and that decryption
// must equal the directly substituted plaintext.
TEST_F(GaloisKeyTest, DeserializedKeySwitches) {
  int power = 3;
  auto message = SamplePlaintext(kPlaintextModulus);
  ASSERT_OK_AND_ASSIGN(auto secret_key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string key_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto original_key,
      rlwe::GaloisKey<uint_m>::Create(secret_key, key_seed, power,
                                      kLargeLogDecompositionModulus));

  // Round-trip the key through its serialized form.
  ASSERT_OK_AND_ASSIGN(auto serialized_key, original_key.Serialize());
  ASSERT_OK_AND_ASSIGN(auto restored_key,
                       rlwe::GaloisKey<uint_m>::Deserialize(
                           serialized_key, params59_.get(), ntt_params_.get()));

  // Reference result: substitute the plaintext directly in NTT form, convert
  // back, and run RemoveError with the plaintext modulus.
  ASSERT_OK_AND_ASSIGN(auto message_montgomery,
                       ConvertToMontgomery(message, params59_.get()));
  Polynomial message_ntt = Polynomial::ConvertToNtt(
      message_montgomery, ntt_params_.get(), params59_.get());
  ASSERT_OK_AND_ASSIGN(
      Polynomial reference_ntt,
      message_ntt.Substitute(power, ntt_params_.get(), params59_.get()));
  std::vector<uint_m::Int> reference = rlwe::RemoveError<uint_m>(
      reference_ntt.InverseNtt(ntt_params_.get(), params59_.get()),
      params59_->modulus, kPlaintextModulus, params59_.get());

  // Encrypt, then substitute.
  ASSERT_OK_AND_ASSIGN(auto fresh, Encrypt(secret_key, message));
  ASSERT_OK_AND_ASSIGN(auto substituted,
                       fresh.Substitute(power, ntt_params_.get()));

  // Key-switch and decrypt with the original GaloisKey...
  ASSERT_OK_AND_ASSIGN(auto switched_by_original,
                       original_key.ApplyTo(substituted));
  ASSERT_OK_AND_ASSIGN(
      std::vector<uint_m::Int> original_result,
      rlwe::Decrypt<uint_m>(secret_key, switched_by_original));

  // ...and with the restored one.
  ASSERT_OK_AND_ASSIGN(auto switched_by_restored,
                       restored_key.ApplyTo(substituted));
  ASSERT_OK_AND_ASSIGN(
      std::vector<uint_m::Int> restored_result,
      rlwe::Decrypt<uint_m>(secret_key, switched_by_restored));

  EXPECT_EQ(restored_result, reference);
  EXPECT_EQ(restored_result, original_result);
}

// Deserializing a key created with kLargeLogDecompositionModulus against 29-bit
// parameters must fail with InvalidArgument, because the log decomposition
// modulus exceeds the smaller modulus' log_modulus.
TEST_F(GaloisKeyTest, DeserializationFailsWithIncorrectModulus) {
  int power = 3;
  ASSERT_OK_AND_ASSIGN(auto secret_key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string key_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(
      auto galois_key,
      rlwe::GaloisKey<uint_m>::Create(secret_key, key_seed, power,
                                      kLargeLogDecompositionModulus));

  ASSERT_OK_AND_ASSIGN(auto small_params,
                       uint_m::Params::Create(rlwe::kModulus29));
  // Serialize, then try to deserialize under the 29-bit parameters.
  ASSERT_OK_AND_ASSIGN(auto serialized_key, galois_key.Serialize());

  std::string expected_error = absl::StrCat(
      "Log decomposition modulus, ", kLargeLogDecompositionModulus,
      ", must be at most: ", small_params->log_modulus, ".");
  EXPECT_THAT(rlwe::GaloisKey<uint_m>::Deserialize(
                  serialized_key, small_params.get(), ntt_params_.get()),
              StatusIs(::absl::StatusCode::kInvalidArgument,
                       HasSubstr(expected_error)));
}

// Serializing a GaloisKey and a copy of it must produce identical bytes.
// NOTE: "Idential" in the test name is a typo for "Identical"; the name is kept
// so existing --gtest_filter patterns keep matching.
TEST_F(GaloisKeyTest, SerializationsOfIdentialKeysEqual) {
  int substitution_power = 3;
  ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  ASSERT_OK_AND_ASSIGN(std::string prng_seed,
                       rlwe::SingleThreadPrng::GenerateSeed());

  ASSERT_OK_AND_ASSIGN(auto galois_key, rlwe::GaloisKey<uint_m>::Create(
                                            key, prng_seed, substitution_power,
                                            kLargeLogDecompositionModulus));
  auto galois_key_copy = galois_key;

  // Serialize both keys.
  ASSERT_OK_AND_ASSIGN(auto serialized, galois_key.Serialize());
  ASSERT_OK_AND_ASSIGN(auto serialized_copy, galois_key_copy.Serialize());

  // The two serializations of the same key must be byte-for-byte equal.
  EXPECT_EQ(serialized_copy.SerializeAsString(), serialized.SerializeAsString());
}

}  //  namespace