/*
* Copyright 2018 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "relinearization_key.h"
#include "absl/numeric/int128.h"
#include "bits_util.h"
#include "montgomery.h"
#include "prng/integral_prng_types.h"
#include "status_macros.h"
#include "statusor.h"
#include "symmetric_encryption_with_prng.h"
#include "third_party/shell-encryption/base/shell_encryption_export.h"
#include "third_party/shell-encryption/base/shell_encryption_export_template.h"
namespace rlwe {
namespace {
// Returns the number of base-T digits needed for an integer that has as many
// bits as `modulus`, where T = 2^log_decomposition_modulus. Equivalently, this
// is ceil(BitLength(modulus) / log_decomposition_modulus). The modulus is
// accepted as absl::uint128 so that every Uint* modulus type can be passed.
// `log_decomposition_modulus` is used as a divisor and must be non-zero.
inline int ComputeDimension(Uint64 log_decomposition_modulus,
                            absl::uint128 modulus) {
  const Uint64 bit_count = static_cast<Uint64>(internal::BitLength(modulus));
  // Whole digits, plus one more digit when some bits are left over.
  const Uint64 whole_digits = bit_count / log_decomposition_modulus;
  const Uint64 partial_digit =
      (bit_count % log_decomposition_modulus != 0) ? 1 : 0;
  return whole_digits + partial_digit;
}
// Samples a two-component vector (-s*r, r) that is orthogonal to (1, s), where
// s is the secret key polynomial held by `key`. The second component r is drawn
// from `prng`; the first component is then derived from it as the negated
// product with the key. Returns an error if sampling or multiplication fails.
template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> SampleOrthogonalFromPrng(
    const SymmetricRlweKey<ModularInt>& key, SecurePrng* prng) {
  const auto* modulus_params = key.ModulusParams();

  // Bottom entry: a polynomial r sampled from the PRNG.
  RLWE_ASSIGN_OR_RETURN(auto bottom, SamplePolynomialFromPrng<ModularInt>(
                                         key.Len(), prng, modulus_params));

  // Top entry: -s*r, which makes (top, bottom) perpendicular to (1, s).
  RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> top,
                        bottom.Mul(key.Key(), modulus_params));
  top.NegateInPlace(modulus_params);

  std::vector<Polynomial<ModularInt>> orthogonal;
  orthogonal.reserve(2);
  orthogonal.push_back(std::move(top));
  orthogonal.push_back(std::move(bottom));
  return orthogonal;
}
// Builds the vector (key_power, T*key_power, ..., T^(dimension-1)*key_power),
// where T is `decomposition_modulus`; entry i holds T^i * key_power. The
// multiplications are carried out with the modulus parameters of `key`.
// Returns an error if any in-place multiplication fails.
template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT(
    const Polynomial<ModularInt>& key_power,
    const SymmetricRlweKey<ModularInt>& key,
    const ModularInt& decomposition_modulus, int dimension) {
  std::vector<Polynomial<ModularInt>> powers;
  powers.reserve(dimension);

  // Running value T^i * key_power, starting at i = 0.
  Polynomial<ModularInt> scaled_key = key_power;
  for (int remaining = dimension; remaining > 0; --remaining) {
    powers.push_back(scaled_key);
    // Advance to the next power of T unless this was the final entry.
    if (remaining > 1) {
      RLWE_RETURN_IF_ERROR(
          scaled_key.MulInPlace(decomposition_modulus, key.ModulusParams()));
    }
  }
  return powers;
}
// Decomposes every coefficient into `dimension` digits in base
// T = 2^log_decomposition_modulus. result[i][j] is the i-th (least significant
// first) base-T digit of coefficients[j], imported back as a ModularInt.
//
// Parameters:
//   coefficients: coefficients to decompose; each is exported to its integer
//     representation with ExportInt before the digits are extracted.
//   modulus_params: modulus parameters used for ExportInt / ImportInt.
//   log_decomposition_modulus: log2 of the base T. It is used as a shift
//     amount on ModularInt::Int, so it must be smaller than the bit width of
//     that type.
//   dimension: number of digits to extract per coefficient.
//
// Returns an error if importing any digit fails.
template <typename ModularInt>
rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose(
    const std::vector<ModularInt>& coefficients,
    const typename ModularInt::Params* modulus_params,
    const Uint64 log_decomposition_modulus, int dimension) {
  using Int = typename ModularInt::Int;

  std::vector<Int> ciphertext_coeffs(coefficients.size(), 0);
  std::transform(
      coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(),
      [modulus_params](ModularInt x) { return x.ExportInt(modulus_params); });

  // Build T in the modulus' own integer type. A `long` literal would make the
  // shift undefined as soon as log_decomposition_modulus reaches the width of
  // `long`, which is possible for wide (e.g. 128-bit) moduli and on platforms
  // where `long` is only 32 bits.
  const Int decomposition_modulus =
      static_cast<Int>(static_cast<Int>(1) << log_decomposition_modulus);

  std::vector<std::vector<ModularInt>> result(dimension);
  for (auto& digits : result) {
    digits.reserve(ciphertext_coeffs.size());
    for (Int& remaining : ciphertext_coeffs) {
      // The lowest base-T digit of what is left of this coefficient.
      RLWE_ASSIGN_OR_RETURN(
          auto coefficient_part,
          ModularInt::ImportInt(
              static_cast<Int>(remaining % decomposition_modulus),
              modulus_params));
      digits.push_back(std::move(coefficient_part));
      // Drop the digit that was just extracted.
      remaining = remaining >> log_decomposition_modulus;
    }
  }
  return result;
}
// Multiplies the 2 x dimension key matrix by the decomposed coefficients.
//
// For every column i, decomposed_coefficients[i] is converted to NTT form and
// multiplied by matrix[0][i] and matrix[1][i]; the products are accumulated
// into result[0] and result[1] respectively.
//
// Preconditions (not checked here): `matrix` has at least two rows, its first
// row is non-empty, the second row is at least as long as the first, and
// `decomposed_coefficients` has at least matrix[0].size() entries.
//
// `decomposed_coefficients` is taken by value because its entries are moved
// from while converting to NTT form.
//
// Returns the two accumulated polynomials, or the first error reported by a
// polynomial operation.
template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
    std::vector<std::vector<ModularInt>> decomposed_coefficients,
    const std::vector<std::vector<Polynomial<ModularInt>>>& matrix,
    const typename ModularInt::Params* modulus_params,
    const NttParameters<ModularInt>* ntt_params) {
  // NOTE(review): Polynomial(len, params) is presumably the all-zero
  // polynomial, which is what makes it usable as the initial accumulator value
  // below -- confirm against the Polynomial class.
  Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params);
  Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params);
  std::vector<Polynomial<ModularInt>> result(2, temp);

  const int columns = static_cast<int>(matrix[0].size());
  for (int i = 0; i < columns; ++i) {
    ntt_part = Polynomial<ModularInt>::ConvertToNtt(
        std::move(decomposed_coefficients[i]), ntt_params, modulus_params);
    // First row: the product goes through `temp` because `ntt_part` is still
    // needed for the second row.
    RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params));
    RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params));
    // Second row: `ntt_part` is no longer needed afterwards, so multiply it in
    // place.
    RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params));
    RLWE_RETURN_IF_ERROR(result[1].AddInPlace(ntt_part, modulus_params));
  }
  return result;
}
} // namespace
// Creates the 2 x dimension matrix that makes up one RelinearizationKeyPart.
//
// For column i, a fresh vector (r0, r1) orthogonal to (1, s) is sampled from
// `prng` and an error vector is sampled from `prng_encryption`. The first-row
// entry is r0 + PlaintextModulus * error + T^i * key_power and the second-row
// entry is r1, where T is `decomposition_modulus`.
//
// Returns an error if sampling or any polynomial operation fails.
template <typename ModularInt>
rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create(
    const Polynomial<ModularInt>& key_power,
    const SymmetricRlweKey<ModularInt>& key,
    const Uint64 log_decomposition_modulus,
    const ModularInt& decomposition_modulus, int dimension, SecurePrng* prng,
    SecurePrng* prng_encryption) {
  const auto* modulus_params = key.ModulusParams();

  std::vector<Polynomial<ModularInt>> first_row;
  std::vector<Polynomial<ModularInt>> second_row;
  first_row.reserve(dimension);
  second_row.reserve(dimension);

  // scaled_keys[i] == T^i * key_power.
  RLWE_ASSIGN_OR_RETURN(
      auto scaled_keys,
      PowersOfT(key_power, key, decomposition_modulus, dimension));

  for (int column = 0; column < dimension; ++column) {
    // Random component orthogonal to (1, s).
    RLWE_ASSIGN_OR_RETURN(auto orthogonal, SampleOrthogonalFromPrng(key, prng));

    // Fresh error coefficients, converted to an NTT polynomial.
    RLWE_ASSIGN_OR_RETURN(
        auto noise, SampleFromErrorDistribution<ModularInt>(
                        key_power.Len(), key.Variance(), prng_encryption,
                        modulus_params));
    auto entry = Polynomial<ModularInt>::ConvertToNtt(
        std::move(noise), key.NttParams(), modulus_params);

    // entry = PlaintextModulus * error + r0 + T^column * key_power.
    RLWE_RETURN_IF_ERROR(
        entry.MulInPlace(key.PlaintextModulus(), modulus_params));
    RLWE_RETURN_IF_ERROR(entry.AddInPlace(orthogonal[0], modulus_params));
    RLWE_RETURN_IF_ERROR(
        entry.AddInPlace(scaled_keys[column], modulus_params));

    first_row.push_back(std::move(entry));
    second_row.push_back(std::move(orthogonal[1]));
  }

  std::vector<std::vector<Polynomial<ModularInt>>> matrix;
  matrix.reserve(2);
  matrix.push_back(std::move(first_row));
  matrix.push_back(std::move(second_row));
  return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
}
// Applies this key part to a single ciphertext component.
//
// The component is taken out of NTT form, each coefficient is split into
// base-T digits (one digit per column of the stored matrix), and the digits
// are multiplied against the matrix. Returns the two resulting polynomials, or
// an error from the decomposition or the multiplication.
template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo(
    const Polynomial<ModularInt>& ciphertext_part,
    const typename ModularInt::Params* modulus_params,
    const NttParameters<ModularInt>* ntt_params) const {
  // One digit per column of the key matrix.
  const int digit_count = matrix_[0].size();

  // Coefficient (non-NTT) representation of the ciphertext component.
  const std::vector<ModularInt> coefficients =
      ciphertext_part.InverseNtt(ntt_params, modulus_params);

  RLWE_ASSIGN_OR_RETURN(
      std::vector<std::vector<ModularInt>> digits,
      BitDecompose<ModularInt>(coefficients, modulus_params,
                               log_decomposition_modulus_, digit_count));

  return MatrixMultiply<ModularInt>(std::move(digits), matrix_, modulus_params,
                                    ntt_params);
}
// Rebuilds a RelinearizationKeyPart from serialized polynomials.
//
// Every entry of `polynomials` is deserialized into the first row of the
// matrix. The second row is not read from the input: for each column a
// polynomial of the same length is sampled from `prng` instead, so the number
// of columns equals polynomials.size(). `ntt_params` is accepted but not used
// by this function.
//
// Returns an error if a polynomial cannot be deserialized or sampling fails.
template <typename ModularInt>
rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize(
    const std::vector<SerializedNttPolynomial>& polynomials,
    Uint64 log_decomposition_modulus, SecurePrng* prng,
    const ModularIntParams* modulus_params,
    const NttParameters<ModularInt>* ntt_params) {
  std::vector<Polynomial<ModularInt>> stored_row;
  std::vector<Polynomial<ModularInt>> sampled_row;
  stored_row.reserve(polynomials.size());
  sampled_row.reserve(polynomials.size());

  for (const SerializedNttPolynomial& serialized_entry : polynomials) {
    RLWE_ASSIGN_OR_RETURN(
        auto stored,
        Polynomial<ModularInt>::Deserialize(serialized_entry, modulus_params));
    // The matching second-row entry comes from the PRNG and has the same
    // length as the entry that was just read.
    RLWE_ASSIGN_OR_RETURN(auto sampled,
                          SamplePolynomialFromPrng<ModularInt>(
                              stored.Len(), prng, modulus_params));
    stored_row.push_back(std::move(stored));
    sampled_row.push_back(std::move(sampled));
  }

  std::vector<std::vector<Polynomial<ModularInt>>> matrix;
  matrix.reserve(2);
  matrix.push_back(std::move(stored_row));
  matrix.push_back(std::move(sampled_row));
  return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
}
// Constructs a RelinearizationKey from already-built key parts.
//
// dimension_ is derived from `log_decomposition_modulus` and the modulus of
// `key` via ComputeDimension. The modulus and NTT parameter pointers are taken
// from `key`. The remaining members are initialized directly from the
// corresponding arguments; `relinearization_key` is moved into place.
template <typename ModularInt>
RelinearizationKey<ModularInt>::RelinearizationKey(
const SymmetricRlweKey<ModularInt>& key, absl::string_view prng_seed,
ssize_t num_parts, Uint64 log_decomposition_modulus,
Uint64 substitution_power, ModularInt decomposition_modulus,
std::vector<RelinearizationKeyPart> relinearization_key)
: dimension_(ComputeDimension(log_decomposition_modulus,
key.ModulusParams()->modulus)),
num_parts_(num_parts),
log_decomposition_modulus_(log_decomposition_modulus),
decomposition_modulus_(decomposition_modulus),
substitution_power_(substitution_power),
modulus_params_(key.ModulusParams()),
ntt_params_(key.NttParams()),
relinearization_key_(std::move(relinearization_key)),
prng_seed_(prng_seed) {}
// Creates a RelinearizationKey for `key`.
//
// One RelinearizationKeyPart is generated for each of the num_parts - 1
// non-constant powers of the (substituted) secret key. The random matrix
// components are drawn from a PRNG seeded with `prng_seed`, while the error
// terms come from a second PRNG with a freshly generated seed.
//
// Returns InvalidArgumentError when `num_parts` is not positive or when
// `log_decomposition_modulus` is zero or exceeds the key's log_modulus, and
// forwards any error reported while building the parts.
template <typename ModularInt>
rlwe::StatusOr<RelinearizationKey<ModularInt>>
RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key,
                                       absl::string_view prng_seed,
                                       ssize_t num_parts,
                                       Uint64 log_decomposition_modulus,
                                       Uint64 substitution_power) {
  // Argument validation.
  if (num_parts <= 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Num parts: ", num_parts, " must be positive."));
  }
  if (log_decomposition_modulus <= 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Log decomposition modulus, ", log_decomposition_modulus,
                     ", must be positive."));
  }
  const auto* modulus_params = key.ModulusParams();
  if (log_decomposition_modulus > modulus_params->log_modulus) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Log decomposition modulus, ", log_decomposition_modulus,
        ", must be at most: ", modulus_params->log_modulus, "."));
  }

  // T = 2^log_decomposition_modulus as a ModularInt.
  RLWE_ASSIGN_OR_RETURN(
      auto decomposition_modulus,
      ModularInt::ImportInt(modulus_params->One() << log_decomposition_modulus,
                            modulus_params));

  // The substituted key s is the first power; higher powers are built from it.
  RLWE_ASSIGN_OR_RETURN(auto key_base, key.Substitute(substitution_power));
  auto key_power = key_base.Key();

  // PRNG for the random matrix components (seeded by the caller) and a
  // separate, freshly seeded PRNG for the error terms.
  RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(prng_seed));
  RLWE_ASSIGN_OR_RETURN(auto prng_encryption_seed,
                        SingleThreadPrng::GenerateSeed());
  RLWE_ASSIGN_OR_RETURN(auto prng_encryption,
                        SingleThreadPrng::Create(prng_encryption_seed));

  const int dimension =
      ComputeDimension(log_decomposition_modulus, modulus_params->modulus);

  std::vector<RelinearizationKeyPart> relinearization_key;
  relinearization_key.reserve(num_parts);

  // One key part per power of s, for powers 1 .. num_parts - 1.
  for (int power = 1; power < num_parts; ++power) {
    if (power > 1) {
      // Move from s^(power-1) to s^power.
      RLWE_RETURN_IF_ERROR(
          key_power.MulInPlace(key_base.Key(), modulus_params));
    }
    RLWE_ASSIGN_OR_RETURN(
        auto part,
        RelinearizationKeyPart::Create(
            key_power, key, log_decomposition_modulus, decomposition_modulus,
            dimension, prng.get(), prng_encryption.get()));
    relinearization_key.push_back(std::move(part));
  }

  return RelinearizationKey<ModularInt>(
      key, prng_seed, num_parts, log_decomposition_modulus, substitution_power,
      decomposition_modulus, std::move(relinearization_key));
}
// Applies the relinearization key to `ciphertext` and returns a ciphertext with
// two components.
//
// Component 0 of the input is carried over unchanged as the starting value of
// the first output component. Then, for every stored key part i, the part is
// applied to input component i + 1 and the two resulting polynomials are added
// to the two output components. The reported error of the result is the input
// error plus B_relinearize(log_decomposition_modulus_).
//
// Returns InvalidArgumentError if the ciphertext has more components than
// num_parts_, and forwards errors from component access and polynomial
// arithmetic.
template <typename ModularInt>
rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>>
RelinearizationKey<ModularInt>::ApplyTo(
    const SymmetricRlweCiphertext<ModularInt>& ciphertext) const {
  // The key must cover every component of the ciphertext.
  if (num_parts_ < ciphertext.Len()) {
    return absl::InvalidArgumentError(
        "RelinearizationKey not large enough for ciphertext.");
  }

  // The component paired with the "1" entry of the secret key needs no key
  // part; it seeds the first output component as-is.
  RLWE_ASSIGN_OR_RETURN(auto constant_component, ciphertext.Component(0));
  Polynomial<ModularInt> second_component(constant_component.Len(),
                                          modulus_params_);
  std::vector<Polynomial<ModularInt>> accumulated;
  accumulated.reserve(2);
  accumulated.push_back(std::move(constant_component));
  accumulated.push_back(std::move(second_component));

  // Key part i corresponds to ciphertext component i + 1.
  int component_index = 1;
  for (const auto& key_part : relinearization_key_) {
    RLWE_ASSIGN_OR_RETURN(auto component,
                          ciphertext.Component(component_index));
    RLWE_ASSIGN_OR_RETURN(
        auto contribution,
        key_part.ApplyPartTo(component, modulus_params_, ntt_params_));
    RLWE_RETURN_IF_ERROR(
        accumulated[0].AddInPlace(contribution[0], modulus_params_));
    RLWE_RETURN_IF_ERROR(
        accumulated[1].AddInPlace(contribution[1], modulus_params_));
    ++component_index;
  }

  const auto* error_params = ciphertext.ErrorParams();
  return SymmetricRlweCiphertext<ModularInt>(
      std::move(accumulated), 1,
      ciphertext.Error() +
          error_params->B_relinearize(log_decomposition_modulus_),
      modulus_params_, error_params);
}
// Serializes this key into a SerializedRelinearizationKey.
//
// The scalar fields (log decomposition modulus, number of parts, PRNG seed and
// substitution power) are copied over, and every polynomial exposed by each
// key part's Matrix() accessor is appended to the repeated `c` field, part by
// part. Returns an error if a polynomial fails to serialize.
template <typename ModularInt>
rlwe::StatusOr<SerializedRelinearizationKey>
RelinearizationKey<ModularInt>::Serialize() const {
  SerializedRelinearizationKey serialized;
  serialized.set_log_decomposition_modulus(log_decomposition_modulus_);
  serialized.set_num_parts(num_parts_);
  serialized.set_prng_seed(prng_seed_);
  serialized.set_power_of_s(substitution_power_);

  for (const RelinearizationKeyPart& key_part : relinearization_key_) {
    // Only the first row of each matrix is meant to be written out.
    // NOTE(review): this relies on Matrix() yielding exactly the first-row
    // polynomials -- confirm against its declaration in the header.
    for (const Polynomial<ModularInt>& polynomial : key_part.Matrix()) {
      RLWE_ASSIGN_OR_RETURN(*serialized.add_c(),
                            polynomial.Serialize(modulus_params_));
    }
  }
  return serialized;
}
// Rebuilds a RelinearizationKey from its serialized form.
//
// A key with num_parts parts stores num_parts - 1 key parts (none is stored
// for the "1" component), so the serialized polynomials are split into
// num_parts - 1 equally sized chunks, one per key part. The PRNG created from
// the serialized seed is shared across all chunks, in order.
//
// Returns InvalidArgumentError when:
//   - num_parts is not greater than one;
//   - the number of polynomials is not divisible by num_parts - 1;
//   - the log decomposition modulus is not positive or exceeds log_modulus;
//   - the chunk size differs from the dimension expected for the modulus.
// Errors from creating the PRNG or deserializing a key part are forwarded.
template <typename ModularInt>
rlwe::StatusOr<RelinearizationKey<ModularInt>>
RelinearizationKey<ModularInt>::Deserialize(
    const SerializedRelinearizationKey& serialized,
    const typename ModularInt::Params* modulus_params,
    const NttParameters<ModularInt>* ntt_params) {
  // Validate the part count and that the polynomials split evenly.
  const auto num_parts = serialized.num_parts();
  if (num_parts <= 1) {
    return absl::InvalidArgumentError(absl::StrCat(
        "The number of parts, ", num_parts, ", must be greater than one."));
  }
  const auto num_key_parts = num_parts - 1;
  const auto polynomial_count = serialized.c_size();
  if (polynomial_count % num_key_parts != 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("The length of serialized, ", polynomial_count, ", ",
                     "must be divisible by the number of parts minus one ",
                     num_key_parts, "."));
  }

  // Validate the range of the log decomposition modulus.
  const auto log_decomposition_modulus = serialized.log_decomposition_modulus();
  if (log_decomposition_modulus <= 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Log decomposition modulus, ", log_decomposition_modulus,
                     ", must be positive."));
  }
  if (log_decomposition_modulus > modulus_params->log_modulus) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Log decomposition modulus, ", log_decomposition_modulus,
        ", must be at most: ", modulus_params->log_modulus, "."));
  }

  // Each key part must hold exactly one polynomial per base-T digit.
  const int dimension = polynomial_count / num_key_parts;
  if (dimension !=
      ComputeDimension(log_decomposition_modulus, modulus_params->modulus)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Number of NTT Polynomials does not match expected ",
                     "number of matrix entries."));
  }

  RLWE_ASSIGN_OR_RETURN(
      auto decomposition_modulus,
      ModularInt::ImportInt(static_cast<typename ModularInt::Int>(1)
                                << log_decomposition_modulus,
                            modulus_params));

  RelinearizationKey output(log_decomposition_modulus, decomposition_modulus,
                            modulus_params, ntt_params);
  output.dimension_ = dimension;
  output.num_parts_ = num_parts;
  output.prng_seed_ = serialized.prng_seed();
  output.substitution_power_ = serialized.power_of_s();

  // PRNG seeded from the serialized seed.
  RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(output.prng_seed_));

  // Walk the flat polynomial list one chunk (= one key part) at a time.
  output.relinearization_key_.reserve(num_key_parts);
  auto chunk_begin = serialized.c().begin();
  for (int part = 0; part < num_key_parts; ++part) {
    const auto chunk_end = chunk_begin + dimension;
    std::vector<SerializedNttPolynomial> chunk(chunk_begin, chunk_end);
    RLWE_ASSIGN_OR_RETURN(
        auto key_part,
        RelinearizationKeyPart::Deserialize(chunk, log_decomposition_modulus,
                                            prng.get(), modulus_params,
                                            ntt_params));
    output.relinearization_key_.push_back(std::move(key_part));
    chunk_begin = chunk_end;
  }
  return output;
}
// Explicit instantiations of RelinearizationKey for the supported
// MontgomeryInt integer widths (16, 32, 64 and 128 bits). The member functions
// above are defined in this translation unit, so these instantiations are what
// emit their code for each supported type.
// If any new types are added, montgomery.h should be updated accordingly (such
// as ensuring BigInt is correctly specialized, etc.).
// NOTE(review): EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) presumably
// applies the library's symbol-export attribute to each instantiation --
// confirm against shell_encryption_export_template.h.
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint16>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint32>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint64>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<absl::uint128>>;
} // namespace rlwe