// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/private_membership/src/private_membership_rlwe_client.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "third_party/private-join-and-compute/src/crypto/ec_commutative_cipher.h"
#include "third_party/private_membership/src/internal/crypto_utils.h"
#include "third_party/private_membership/src/private_membership.pb.h"
#include "third_party/private_membership/src/private_membership_rlwe.pb.h"
#include "third_party/private_membership/src/internal/constants.h"
#include "third_party/private_membership/src/internal/encrypted_bucket_id.h"
#include "third_party/private_membership/src/internal/hashed_bucket_id.h"
#include "third_party/private_membership/src/internal/rlwe_id_utils.h"
#include "third_party/private_membership/src/internal/rlwe_params.h"
#include "third_party/private_membership/src/internal/utils.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "third_party/shell-encryption/src/polynomial.h"
#include "third_party/shell-encryption/src/status_macros.h"
#include "third_party/shell-encryption/src/symmetric_encryption_with_prng.h"
#include "third_party/shell-encryption/src/transcription.h"
namespace private_membership {
namespace rlwe {
::rlwe::StatusOr<std::unique_ptr<PrivateMembershipRlweClient>>
PrivateMembershipRlweClient::Create(
private_membership::rlwe::RlweUseCase use_case,
const std::vector<RlwePlaintextId>& plaintext_ids) {
return CreateInternal(use_case, plaintext_ids, std::optional<std::string>(),
internal::PrngSeedGenerator::Create(),
internal::SecurePrngGenerator::Create());
}
::rlwe::StatusOr<std::unique_ptr<PrivateMembershipRlweClient>>
PrivateMembershipRlweClient::CreateForTesting(
private_membership::rlwe::RlweUseCase use_case,
const std::vector<RlwePlaintextId>& plaintext_ids,
absl::string_view ec_cipher_key, absl::string_view seed) {
RLWE_ASSIGN_OR_RETURN(auto prng_seed_generator,
internal::PrngSeedGenerator::CreateDeterministic(seed));
return CreateInternal(
use_case, plaintext_ids, std::optional<std::string>(ec_cipher_key),
std::move(prng_seed_generator), internal::SecurePrngGenerator::Create());
}
::rlwe::StatusOr<std::unique_ptr<PrivateMembershipRlweClient>>
PrivateMembershipRlweClient::CreateInternal(
private_membership::rlwe::RlweUseCase use_case,
const std::vector<RlwePlaintextId>& plaintext_ids,
std::optional<std::string> ec_cipher_key,
std::unique_ptr<internal::PrngSeedGenerator> prng_seed_generator,
std::unique_ptr<internal::PrngGenerator> prng_generator) {
if (use_case == private_membership::rlwe::RLWE_USE_CASE_UNDEFINED) {
return absl::InvalidArgumentError("Use case must be defined.");
}
if (plaintext_ids.empty()) {
return absl::InvalidArgumentError("Plaintext ids must not be empty.");
}
// Remove duplicate IDs.
absl::flat_hash_set<std::string> hashed_rlwe_plaintext_ids;
std::vector<RlwePlaintextId> unique_plaintext_ids;
for (int i = 0; i < plaintext_ids.size(); ++i) {
std::string hash = HashRlwePlaintextId(plaintext_ids[i]);
if (!hashed_rlwe_plaintext_ids.contains(hash)) {
unique_plaintext_ids.push_back(plaintext_ids[i]);
}
hashed_rlwe_plaintext_ids.insert(hash);
}
// Create the cipher with new key or from existing key depending on whether
// the key was provided.
auto ec_cipher =
ec_cipher_key.has_value()
? ::private_join_and_compute::ECCommutativeCipher::CreateFromKey(
kCurveId, ec_cipher_key.value(),
::private_join_and_compute::ECCommutativeCipher::HashType::SHA256)
: ::private_join_and_compute::ECCommutativeCipher::CreateWithNewKey(
kCurveId, ::private_join_and_compute::ECCommutativeCipher::HashType::SHA256);
if (!ec_cipher.ok()) {
return ec_cipher.status();
}
return absl::WrapUnique<PrivateMembershipRlweClient>(
new PrivateMembershipRlweClient(
use_case, unique_plaintext_ids, std::move(ec_cipher).value(),
std::move(prng_seed_generator), std::move(prng_generator)));
}
PrivateMembershipRlweClient::PrivateMembershipRlweClient(
private_membership::rlwe::RlweUseCase use_case,
const std::vector<RlwePlaintextId>& plaintext_ids,
std::unique_ptr<::private_join_and_compute::ECCommutativeCipher> ec_cipher,
std::unique_ptr<internal::PrngSeedGenerator> prng_seed_generator,
std::unique_ptr<internal::PrngGenerator> prng_generator)
: use_case_(use_case),
plaintext_ids_(plaintext_ids),
ec_cipher_(std::move(ec_cipher)),
prng_seed_generator_(std::move(prng_seed_generator)),
prng_generator_(std::move(prng_generator)) {}
::rlwe::StatusOr<private_membership::rlwe::PrivateMembershipRlweOprfRequest>
PrivateMembershipRlweClient::CreateOprfRequest() {
private_membership::rlwe::PrivateMembershipRlweOprfRequest request;
request.set_use_case(use_case_);
// Encrypt the plaintext ids with the client generated key.
for (const auto& plaintext_id : plaintext_ids_) {
std::string whole_id = HashRlwePlaintextId(plaintext_id);
auto client_encrypted_id = ec_cipher_->Encrypt(whole_id);
if (!client_encrypted_id.ok()) {
return client_encrypted_id.status();
}
*request.add_encrypted_ids() = client_encrypted_id.value();
// Populate the map of client encrypted id to plaintext id.
client_encrypted_id_to_plaintext_id_[client_encrypted_id.value()] =
plaintext_id;
}
return request;
}
absl::Status PrivateMembershipRlweClient::ValidateOprfResponse(
const private_membership::rlwe::PrivateMembershipRlweOprfResponse&
oprf_response) const {
// Check for valid bucket ID lengths.
int encrypted_bucket_id_length =
oprf_response.encrypted_buckets_parameters().encrypted_bucket_id_length();
if (encrypted_bucket_id_length < 0 ||
encrypted_bucket_id_length > kMaxEncryptedBucketIdLength) {
return absl::InvalidArgumentError(absl::StrCat(
"Encrypted bucket ID length must be non-negative and at most ",
kMaxEncryptedBucketIdLength, "."));
}
// Check number of responses.
if (oprf_response.doubly_encrypted_ids_size() <
client_encrypted_id_to_plaintext_id_.size()) {
return absl::InvalidArgumentError(
"OPRF response missing a response to a requested ID.");
} else if (oprf_response.doubly_encrypted_ids_size() >
client_encrypted_id_to_plaintext_id_.size()) {
return absl::InvalidArgumentError(
"OPRF response contains too many responses.");
}
return absl::OkStatus();
}
::rlwe::StatusOr<private_membership::rlwe::PrivateMembershipRlweQueryRequest>
PrivateMembershipRlweClient::CreateQueryRequest(
const private_membership::rlwe::PrivateMembershipRlweOprfResponse&
oprf_response) {
auto validation_result = ValidateOprfResponse(oprf_response);
if (!validation_result.ok()) {
return validation_result;
}
// Initialize PIR client.
int encrypted_bucket_id_length =
oprf_response.encrypted_buckets_parameters().encrypted_bucket_id_length();
int encrypted_buckets_count = 1 << encrypted_bucket_id_length;
RLWE_ASSIGN_OR_RETURN(
pir_client_, internal::PirClient::Create(
oprf_response.rlwe_parameters(), encrypted_buckets_count,
prng_seed_generator_.get(), prng_generator_.get()));
private_membership::rlwe::PrivateMembershipRlweQueryRequest request;
request.set_use_case(use_case_);
request.set_key_version(oprf_response.key_version());
// Keep track of seen plaintext IDs to check for duplicates.
absl::flat_hash_set<std::string> seen_encrypted_ids;
for (const auto& doubly_encrypted_id : oprf_response.doubly_encrypted_ids()) {
private_membership::rlwe::PrivateMembershipRlweQuery single_query;
single_query.set_queried_encrypted_id(
doubly_encrypted_id.queried_encrypted_id());
const std::string& encrypted_id =
doubly_encrypted_id.queried_encrypted_id();
// Check validity of returned queried ID.
if (!client_encrypted_id_to_plaintext_id_.contains(encrypted_id)) {
return absl::InvalidArgumentError(
"OPRF response contains a response to an erroneous encrypted ID.");
}
// Already processed a response for this encrypted ID. Ignore this one.
if (seen_encrypted_ids.contains(encrypted_id)) {
return absl::InvalidArgumentError(
"OPRF response contains duplicate responses for the same ID.");
}
seen_encrypted_ids.insert(encrypted_id);
// Compute the hashed bucket id if the hashed bucket parameter is set in
// the response.
if (oprf_response.hashed_buckets_parameters().hashed_bucket_id_length() >
0) {
const RlwePlaintextId& plaintext_id =
client_encrypted_id_to_plaintext_id_[encrypted_id];
RLWE_ASSIGN_OR_RETURN(
HashedBucketId hashed_bucket_id,
HashedBucketId::Create(plaintext_id,
oprf_response.hashed_buckets_parameters(),
&context_));
*single_query.mutable_hashed_bucket_id() = hashed_bucket_id.ToApiProto();
}
// Decrypt doubly encrypted id to retrieve id encrypted only by the server
// key.
auto server_encrypted_id =
ec_cipher_->Decrypt(doubly_encrypted_id.doubly_encrypted_id());
if (!server_encrypted_id.ok()) {
return server_encrypted_id.status();
}
// Truncate the hash of the server encrypted id by the first
// encrypted_bucket_id_length bits to compute the encrypted bucket id.
RLWE_ASSIGN_OR_RETURN(
EncryptedBucketId encrypted_bucket_id_obj,
EncryptedBucketId::Create(server_encrypted_id.value(),
oprf_response.encrypted_buckets_parameters(),
&context_));
RLWE_ASSIGN_OR_RETURN(int encrypted_bucket_id,
encrypted_bucket_id_obj.ToUint32());
// Create query request.
RLWE_ASSIGN_OR_RETURN(*single_query.mutable_pir_request(),
pir_client_->CreateRequest(encrypted_bucket_id));
client_encrypted_id_to_server_encrypted_id_[encrypted_id] =
std::move(server_encrypted_id).value();
*request.add_queries() = single_query;
}
hashed_bucket_params_ = oprf_response.hashed_buckets_parameters();
encrypted_bucket_params_ = oprf_response.encrypted_buckets_parameters();
return request;
}
absl::Status PrivateMembershipRlweClient::ValidateQueryResponse(
const private_membership::rlwe::PrivateMembershipRlweQueryResponse&
query_response) const {
// Check response length for missing responses.
if (query_response.pir_responses_size() <
client_encrypted_id_to_plaintext_id_.size()) {
return absl::InvalidArgumentError(
"Query response missing a response to a requested ID.");
} else if (query_response.pir_responses_size() >
client_encrypted_id_to_plaintext_id_.size()) {
return absl::InvalidArgumentError(
"Query response contains too many responses.");
}
return absl::OkStatus();
}
::rlwe::StatusOr<RlweMembershipResponses>
PrivateMembershipRlweClient::ProcessQueryResponse(
const private_membership::rlwe::PrivateMembershipRlweQueryResponse&
query_response) {
auto validation_result = ValidateQueryResponse(query_response);
if (!validation_result.ok()) {
return validation_result;
}
// Keep track of seen encrypted IDs to avoid duplicates.
absl::flat_hash_set<std::string> seen_encrypted_ids;
RlweMembershipResponses membership_responses;
for (const auto& pir_response : query_response.pir_responses()) {
const std::string& encrypted_id = pir_response.queried_encrypted_id();
if (!client_encrypted_id_to_plaintext_id_.contains(encrypted_id) ||
!client_encrypted_id_to_server_encrypted_id_.contains(encrypted_id)) {
return absl::InvalidArgumentError(
"Query response contains a response to an erroneous encrypted ID.");
}
// Already processed this encrypted ID. Ignore this one.
if (seen_encrypted_ids.contains(encrypted_id)) {
return absl::InvalidArgumentError(
"Query response contains duplicate responses for the same ID.");
}
seen_encrypted_ids.insert(encrypted_id);
RLWE_ASSIGN_OR_RETURN(
std::vector<uint8_t> serialized_encrypted_bucket_byte,
pir_client_->ProcessResponse(pir_response.pir_response()));
std::string serialized_encrypted_bucket;
if (!serialized_encrypted_bucket_byte.empty()) {
RLWE_ASSIGN_OR_RETURN(serialized_encrypted_bucket,
private_membership::Unpad(std::string(
serialized_encrypted_bucket_byte.begin(),
serialized_encrypted_bucket_byte.end())));
}
private_membership::rlwe::EncryptedBucket encrypted_bucket;
if (!serialized_encrypted_bucket.empty() &&
!encrypted_bucket.ParseFromString(serialized_encrypted_bucket)) {
return absl::InternalError("Parsing serialized encrypted bucket failed.");
}
// Plaintext id associated with the client encrypted id.
const RlwePlaintextId& plaintext_id =
client_encrypted_id_to_plaintext_id_[encrypted_id];
// Server key encrypted id associated with the client encrypted id.
const std::string& server_encrypted_id =
client_encrypted_id_to_server_encrypted_id_[encrypted_id];
RLWE_ASSIGN_OR_RETURN(auto membership, CheckMembership(server_encrypted_id,
encrypted_bucket));
auto* response = membership_responses.add_membership_responses();
*response->mutable_plaintext_id() = plaintext_id;
*response->mutable_membership_response() = membership;
}
return membership_responses;
}
::rlwe::StatusOr<private_membership::MembershipResponse>
PrivateMembershipRlweClient::CheckMembership(
absl::string_view server_encrypted_id,
const private_membership::rlwe::EncryptedBucket& encrypted_bucket) {
private_membership::MembershipResponse membership_response;
RLWE_ASSIGN_OR_RETURN(
std::string to_match_hash,
ComputeBucketStoredEncryptedId(server_encrypted_id,
encrypted_bucket_params_, &context_));
for (const auto& encrypted_id_value_pair :
encrypted_bucket.encrypted_id_value_pairs()) {
const auto& encrypted_id = encrypted_id_value_pair.encrypted_id();
// Check encrypted_id is a prefix of to_match_hash. If it is, then the id
// is a member.
if (std::equal(encrypted_id.begin(), encrypted_id.end(),
to_match_hash.begin())) {
membership_response.set_is_member(true);
if (!encrypted_id_value_pair.encrypted_value().empty()) {
RLWE_ASSIGN_OR_RETURN(
std::string decrypted_value,
private_membership::DecryptValue(
server_encrypted_id, encrypted_id_value_pair.encrypted_value(),
&context_));
membership_response.set_value(decrypted_value);
}
break;
}
}
return membership_response;
}
namespace internal {
std::unique_ptr<PrngSeedGenerator> PrngSeedGenerator::Create() {
return absl::WrapUnique<PrngSeedGenerator>(new PrngSeedGenerator());
}
::rlwe::StatusOr<std::unique_ptr<PrngSeedGenerator>>
PrngSeedGenerator::CreateDeterministic(absl::string_view seed) {
RLWE_ASSIGN_OR_RETURN(auto prng_seed_generator,
SingleThreadPrng::Create(seed));
return absl::WrapUnique<PrngSeedGenerator>(
new PrngSeedGenerator(std::move(prng_seed_generator)));
}
::rlwe::StatusOr<std::string> PrngSeedGenerator::GeneratePrngSeed() const {
if (deterministic_prng_seed_generator_.has_value()) {
std::string res(SingleThreadPrng::SeedLength(), 0);
for (int i = 0; i < res.length(); ++i) {
RLWE_ASSIGN_OR_RETURN(
res[i], deterministic_prng_seed_generator_.value()->Rand8());
}
return res;
}
return SingleThreadPrng::GenerateSeed();
}
PrngSeedGenerator::PrngSeedGenerator(
std::unique_ptr<SingleThreadPrng> prng_seed_generator)
: deterministic_prng_seed_generator_(
std::optional<std::unique_ptr<SingleThreadPrng>>(
std::move(prng_seed_generator))) {}
template <typename ModularInt>
::rlwe::StatusOr<std::unique_ptr<PirClientImpl<ModularInt>>>
PirClientImpl<ModularInt>::Create(const RlweParameters& rlwe_params,
int total_entry_count,
const PrngSeedGenerator* prng_seed_generator,
const PrngGenerator* prng_generator) {
if (rlwe_params.log_degree() < 0 ||
rlwe_params.log_degree() > kMaxLogDegree) {
return absl::InvalidArgumentError(
"Degree must be positive and at most 2^20.");
}
int levels_of_recursion = rlwe_params.levels_of_recursion();
if (levels_of_recursion <= 0 || levels_of_recursion > kMaxLevelsOfRecursion) {
return absl::InvalidArgumentError(
absl::StrCat("Levels of recursion, ", levels_of_recursion,
", must be positive and at most ", kMaxLevelsOfRecursion));
}
// Create parameters.
std::vector<std::unique_ptr<const typename ModularInt::Params>>
modulus_params;
modulus_params.reserve(rlwe_params.modulus_size());
std::vector<std::unique_ptr<const ::rlwe::NttParameters<ModularInt>>>
ntt_params;
ntt_params.reserve(rlwe_params.modulus_size());
std::vector<std::unique_ptr<const ::rlwe::ErrorParams<ModularInt>>>
error_params;
error_params.reserve(rlwe_params.modulus_size());
for (int i = 0; i < rlwe_params.modulus_size(); ++i) {
RLWE_ASSIGN_OR_RETURN(
auto temp_modulus_params,
CreateModulusParams<ModularInt>(rlwe_params.modulus(i)));
modulus_params.push_back(std::move(temp_modulus_params));
RLWE_ASSIGN_OR_RETURN(
auto temp_ntt_params,
CreateNttParams<ModularInt>(rlwe_params, modulus_params[i].get()));
ntt_params.push_back(std::move(temp_ntt_params));
RLWE_ASSIGN_OR_RETURN(
auto temp_error_params,
CreateErrorParams<ModularInt>(rlwe_params, modulus_params[i].get(),
ntt_params[i].get()));
error_params.push_back(std::move(temp_error_params));
}
RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
prng_seed_generator->GeneratePrngSeed());
RLWE_ASSIGN_OR_RETURN(auto prng, prng_generator->CreatePrng(prng_seed));
RLWE_ASSIGN_OR_RETURN(
auto key,
::rlwe::SymmetricRlweKey<ModularInt>::Sample(
rlwe_params.log_degree(), rlwe_params.variance(), rlwe_params.log_t(),
modulus_params[0].get(), ntt_params[0].get(), prng.get()));
return absl::WrapUnique<>(new PirClientImpl(
rlwe_params, std::move(modulus_params), std::move(ntt_params),
std::move(error_params), key, total_entry_count, prng_seed_generator,
prng_generator));
}
template <typename ModularInt>
PirClientImpl<ModularInt>::PirClientImpl(
const RlweParameters& rlwe_params,
std::vector<std::unique_ptr<const typename ModularInt::Params>>
modulus_params,
std::vector<std::unique_ptr<const ::rlwe::NttParameters<ModularInt>>>
ntt_params,
std::vector<std::unique_ptr<const ::rlwe::ErrorParams<ModularInt>>>
error_params,
const ::rlwe::SymmetricRlweKey<ModularInt>& key, int total_entry_count,
const PrngSeedGenerator* prng_seed_generator,
const PrngGenerator* prng_generator)
: rlwe_params_(rlwe_params),
modulus_params_(std::move(modulus_params)),
ntt_params_(std::move(ntt_params)),
error_params_(std::move(error_params)),
key_(key),
total_entry_count_(total_entry_count),
prng_seed_generator_(prng_seed_generator),
prng_generator_(prng_generator) {}
template <typename ModularInt>
::rlwe::StatusOr<PirRequest> PirClientImpl<ModularInt>::CreateRequest(
int index) {
if (index < 0 || index >= total_entry_count_) {
return absl::InvalidArgumentError("Index out of bounds.");
}
PirRequest req;
// The number of virtual entries per level of recursion = the
// (levels_of_recursion)th root of the number of items in the database.
double exact_entries_per_level =
pow(total_entry_count_, 1.0 / rlwe_params_.levels_of_recursion());
// Round this number up to the nearest whole integer.
int branching_factor = static_cast<int>(ceil(exact_entries_per_level));
// Create the ciphertexts for each level of recursion. This two-dimensional
// table is flattened when it is put into the proto.
// Determine the number of actual database items stored in each virtual
// database block at this level. This is the number of items remaining
// divided by the branching factor, rounded up.
int items_in_block =
(total_entry_count_ + branching_factor - 1) / branching_factor;
// The index of the item we want to request at the current level of recursion.
int index_remaining = index;
// Create useful zero polynomial.
std::vector<ModularInt> zeroes(
1 << rlwe_params_.log_degree(),
ModularInt::ImportZero(modulus_params_[0].get()));
::rlwe::Polynomial<ModularInt> zero_poly =
::rlwe::Polynomial<ModularInt>(zeroes);
// Create useful indicator polynomial.
std::vector<ModularInt> indicator(zeroes);
indicator[0] = ModularInt::ImportOne(modulus_params_[0].get());
const ::rlwe::Polynomial<ModularInt> indicator_poly =
::rlwe::Polynomial<ModularInt>::ConvertToNtt(indicator, *(ntt_params_[0]),
modulus_params_[0].get());
// Fill plaintext indicator vector with only zeroes at first.
if (branching_factor * rlwe_params_.levels_of_recursion() >
kMaxRequestEntries) {
return absl::InvalidArgumentError(
absl::StrCat("Number of request entries exceeds ", kMaxRequestEntries));
}
std::vector<::rlwe::Polynomial<ModularInt>> plaintexts(
branching_factor * rlwe_params_.levels_of_recursion(), zero_poly);
// Fill appropriate indicator for each level of recursion.
for (int level = 0; level < rlwe_params_.levels_of_recursion(); ++level) {
// Determine which block contains the item we wish to request.
int index_at_level = index_remaining / items_in_block;
int index_in_plaintext = (level * branching_factor) + index_at_level;
plaintexts[index_in_plaintext] = indicator_poly;
// Determine the index of the desired item within that block. This is
// the index within the items that remain after this level of recursion.
index_remaining = index_remaining % items_in_block;
// Update the block size for the next level of recursion.
items_in_block = (items_in_block + branching_factor - 1) / branching_factor;
}
RLWE_ASSIGN_OR_RETURN(auto prng_seed,
prng_seed_generator_->GeneratePrngSeed());
req.set_prng_seed(prng_seed);
RLWE_ASSIGN_OR_RETURN(auto prng, prng_generator_->CreatePrng(prng_seed));
RLWE_ASSIGN_OR_RETURN(std::string prng_encryption_seed,
prng_seed_generator_->GeneratePrngSeed());
RLWE_ASSIGN_OR_RETURN(auto prng_encryption,
prng_generator_->CreatePrng(prng_encryption_seed));
RLWE_ASSIGN_OR_RETURN(std::vector<::rlwe::Polynomial<ModularInt>> ciphertexts,
::rlwe::EncryptWithPrng(key_, plaintexts, prng.get(),
prng_encryption.get()));
for (int i = 0; i < ciphertexts.size(); ++i) {
RLWE_ASSIGN_OR_RETURN(*req.add_request(),
ciphertexts[i].Serialize(modulus_params_[0].get()));
}
return req;
}
template <typename ModularInt>
::rlwe::StatusOr<std::vector<uint8_t>>
PirClientImpl<ModularInt>::ProcessResponse(const PirResponse& response) {
if (response.plaintext_entry_size() < 0 ||
response.plaintext_entry_size() > kMaxPlaintextEntrySize) {
return absl::InvalidArgumentError(
"Invalid plaintext entry size that must be at most 10 MB in length.");
}
std::vector<uint8_t> raw_bytes;
for (int i = 0; i < response.response_size(); i++) {
const typename ModularInt::Params* decrypt_modulus_params;
const ::rlwe::NttParameters<ModularInt>* decrypt_ntt_params;
const ::rlwe::ErrorParams<ModularInt>* decrypt_error_params;
::rlwe::SymmetricRlweKey<ModularInt> decrypt_key = key_;
if (modulus_params_.size() == 2) {
decrypt_modulus_params = modulus_params_[1].get();
decrypt_ntt_params = ntt_params_[1].get();
decrypt_error_params = error_params_[1].get();
RLWE_ASSIGN_OR_RETURN(
decrypt_key,
key_.SwitchModulus(decrypt_modulus_params, decrypt_ntt_params));
} else if (modulus_params_.size() == 1) {
decrypt_modulus_params = modulus_params_[0].get();
decrypt_ntt_params = ntt_params_[0].get();
decrypt_error_params = error_params_[0].get();
} else {
return absl::InternalError("More than two moduli.");
}
RLWE_ASSIGN_OR_RETURN(
auto ciphertext,
::rlwe::SymmetricRlweCiphertext<ModularInt>::Deserialize(
response.response(i), decrypt_modulus_params,
decrypt_error_params));
RLWE_ASSIGN_OR_RETURN(std::vector<typename ModularInt::Int> plaintext,
::rlwe::Decrypt(decrypt_key, ciphertext));
RLWE_ASSIGN_OR_RETURN(
std::vector<uint8_t> column,
(::rlwe::TranscribeBits<typename ModularInt::Int, uint8_t>(
plaintext, key_.Len() * key_.BitsPerCoeff(), key_.BitsPerCoeff(),
8)));
raw_bytes.insert(raw_bytes.end(), std::make_move_iterator(column.begin()),
std::make_move_iterator(column.end()));
}
raw_bytes.resize(response.plaintext_entry_size());
return raw_bytes;
}
::rlwe::StatusOr<std::unique_ptr<internal::PirClient>>
internal::PirClient::Create(const RlweParameters& rlwe_params,
int total_entry_count,
const PrngSeedGenerator* prng_seed_generator,
const PrngGenerator* prng_generator) {
if (rlwe_params.modulus_size() <= 0) {
return absl::InvalidArgumentError("Must provide at least one modulus.");
}
if (rlwe_params.modulus(0).hi() > 0 ||
(rlwe_params.modulus(0).lo() >> 62) > 0) {
RLWE_ASSIGN_OR_RETURN(
auto client, PirClientImpl<ModularInt128>::Create(
rlwe_params, total_entry_count, prng_seed_generator,
prng_generator));
return std::unique_ptr<internal::PirClient>(std::move(client));
} else {
RLWE_ASSIGN_OR_RETURN(
auto client, PirClientImpl<ModularInt64>::Create(
rlwe_params, total_entry_count, prng_seed_generator,
prng_generator));
return std::unique_ptr<internal::PirClient>(std::move(client));
}
}
template class PirClientImpl<ModularInt64>;
template class PirClientImpl<ModularInt128>;
} // namespace internal
} // namespace rlwe
} // namespace private_membership