chromium/third_party/private_membership/src/internal/rlwe_params.cc

// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "third_party/private_membership/src/internal/rlwe_params.h"

#include <memory>
#include <utility>
#include <vector>

#include "third_party/private_membership/src/private_membership_rlwe.pb.h"
#include "third_party/private_membership/src/internal/constants.h"
#include "third_party/shell-encryption/src/montgomery.h"
#include "third_party/shell-encryption/src/status_macros.h"

namespace private_membership {
namespace rlwe {

// ModularInt64 specialization: builds one RlweContext for each entry of
// `rlwe_params.modulus()`, preserving the order of the entries.
//
// All contexts share the log_degree, log_t and variance values read from
// `rlwe_params`; only the modulus differs between them.
//
// Returns InvalidArgumentError as soon as a modulus with hi() > 0 is seen.
// RLWE_ASSIGN_OR_RETURN returns early with the error if
// RlweContext::Create fails for any modulus.
template <>
::rlwe::StatusOr<
    std::vector<std::unique_ptr<const ::rlwe::RlweContext<ModularInt64>>>>
CreateContexts<ModularInt64>(const RlweParameters& rlwe_params) {
  using Context = ::rlwe::RlweContext<ModularInt64>;

  // These three values are identical for every modulus, so read them once.
  const size_t log_n = static_cast<size_t>(rlwe_params.log_degree());
  const size_t log_t = static_cast<size_t>(rlwe_params.log_t());
  const size_t variance = static_cast<size_t>(rlwe_params.variance());

  std::vector<std::unique_ptr<const Context>> result;
  result.reserve(rlwe_params.modulus().size());
  for (const auto& modulus_proto : rlwe_params.modulus()) {
    // Only the low word is handed to the 64-bit context below, so a modulus
    // that populates the high word is rejected.
    if (modulus_proto.hi() > 0) {
      return absl::InvalidArgumentError("Modulus cannot fit into 64-bits.");
    }
    RLWE_ASSIGN_OR_RETURN(auto ctx,
                          Context::Create({/*.modulus =*/modulus_proto.lo(),
                                           /*.log_n =*/log_n,
                                           /*.log_t =*/log_t,
                                           /*.variance =*/variance}));
    result.push_back(std::move(ctx));
  }
  return result;
}

// ModularInt128 specialization: builds one RlweContext for each entry of
// `rlwe_params.modulus()`, preserving the order of the entries.
//
// Each modulus is assembled into an absl::uint128 from its hi() and lo()
// parts. All contexts share the log_degree, log_t and variance values read
// from `rlwe_params`.
//
// RLWE_ASSIGN_OR_RETURN returns early with the error if
// RlweContext::Create fails for any modulus.
template <>
::rlwe::StatusOr<
    std::vector<std::unique_ptr<const ::rlwe::RlweContext<ModularInt128>>>>
CreateContexts<ModularInt128>(const RlweParameters& rlwe_params) {
  using Context = ::rlwe::RlweContext<ModularInt128>;

  // These three values are identical for every modulus, so read them once.
  const size_t log_n = static_cast<size_t>(rlwe_params.log_degree());
  const size_t log_t = static_cast<size_t>(rlwe_params.log_t());
  const size_t variance = static_cast<size_t>(rlwe_params.variance());

  std::vector<std::unique_ptr<const Context>> result;
  result.reserve(rlwe_params.modulus().size());
  for (const auto& modulus_proto : rlwe_params.modulus()) {
    const absl::uint128 full_modulus =
        absl::MakeUint128(modulus_proto.hi(), modulus_proto.lo());
    RLWE_ASSIGN_OR_RETURN(auto ctx,
                          Context::Create({/*.modulus =*/full_modulus,
                                           /*.log_n =*/log_n,
                                           /*.log_t =*/log_t,
                                           /*.variance =*/variance}));
    result.push_back(std::move(ctx));
  }
  return result;
}

// ModularInt64 specialization: creates modulus parameters from the low word
// of `modulus`.
//
// Returns InvalidArgumentError when modulus.hi() > 0; otherwise returns
// whatever ModularInt64::Params::Create yields for modulus.lo().
template <>
::rlwe::StatusOr<std::unique_ptr<const ModularInt64::Params>>
CreateModulusParams<ModularInt64>(const Uint128& modulus) {
  const bool high_word_set = modulus.hi() > 0;
  if (high_word_set) {
    return absl::InvalidArgumentError("Modulus cannot fit into 64-bits.");
  }

  return ModularInt64::Params::Create(modulus.lo());
}

// ModularInt128 specialization: combines the hi() and lo() parts of `modulus`
// into an absl::uint128 and returns the result of
// ModularInt128::Params::Create for that value.
template <>
::rlwe::StatusOr<std::unique_ptr<const ModularInt128::Params>>
CreateModulusParams<ModularInt128>(const Uint128& modulus) {
  return ModularInt128::Params::Create(
      absl::MakeUint128(modulus.hi(), modulus.lo()));
}

// Initializes NTT parameters for `rlwe_params.log_degree()` and
// `modulus_params`, and returns them in a heap-allocated, const-qualified
// object owned by the returned unique_ptr.
//
// RLWE_ASSIGN_OR_RETURN returns early with the error if
// ::rlwe::InitializeNttParameters fails.
template <typename ModularInt>
::rlwe::StatusOr<std::unique_ptr<const ::rlwe::NttParameters<ModularInt>>>
CreateNttParams(const RlweParameters& rlwe_params,
                const typename ModularInt::Params* modulus_params) {
  using NttParams = ::rlwe::NttParameters<ModularInt>;

  RLWE_ASSIGN_OR_RETURN(NttParams initialized,
                        ::rlwe::InitializeNttParameters<ModularInt>(
                            rlwe_params.log_degree(), modulus_params));
  // The local is not needed afterwards, so it is moved into the allocation.
  return std::make_unique<const NttParams>(std::move(initialized));
}

// Explicit instantiations of CreateNttParams for ModularInt64 and
// ModularInt128. The template body is defined in this .cc file, so these
// make both versions available to translation units that only see the
// declaration (presumably in rlwe_params.h).
template ::rlwe::StatusOr<
    std::unique_ptr<const ::rlwe::NttParameters<ModularInt64>>>
CreateNttParams<ModularInt64>(const RlweParameters&,
                              const ModularInt64::Params*);
template ::rlwe::StatusOr<
    std::unique_ptr<const ::rlwe::NttParameters<ModularInt128>>>
CreateNttParams<ModularInt128>(const RlweParameters&,
                               const ModularInt128::Params*);

// Creates error parameters from `rlwe_params.log_t()`,
// `rlwe_params.variance()`, `modulus_params` and `ntt_params`, and returns
// them in a heap-allocated, const-qualified object owned by the returned
// unique_ptr.
//
// `modulus_params` and `ntt_params` are passed through unchanged to
// ::rlwe::ErrorParams<ModularInt>::Create. RLWE_ASSIGN_OR_RETURN returns early
// with the error if that call fails.
template <typename ModularInt>
::rlwe::StatusOr<std::unique_ptr<const ::rlwe::ErrorParams<ModularInt>>>
CreateErrorParams(const RlweParameters& rlwe_params,
                  const typename ModularInt::Params* modulus_params,
                  const ::rlwe::NttParameters<ModularInt>* ntt_params) {
  RLWE_ASSIGN_OR_RETURN(auto error_params,
                        ::rlwe::ErrorParams<ModularInt>::Create(
                            rlwe_params.log_t(), rlwe_params.variance(),
                            modulus_params, ntt_params));
  // `error_params` is not used after this point, so move it into the
  // allocation instead of copying it.
  return std::make_unique<const ::rlwe::ErrorParams<ModularInt>>(
      std::move(error_params));
}

// Explicit instantiations of CreateErrorParams for ModularInt64 and
// ModularInt128. The template body is defined in this .cc file, so these
// make both versions available to translation units that only see the
// declaration (presumably in rlwe_params.h).
template ::rlwe::StatusOr<
    std::unique_ptr<const ::rlwe::ErrorParams<ModularInt64>>>
CreateErrorParams<ModularInt64>(const RlweParameters&,
                                const ModularInt64::Params*,
                                const ::rlwe::NttParameters<ModularInt64>*);
template ::rlwe::StatusOr<
    std::unique_ptr<const ::rlwe::ErrorParams<ModularInt128>>>
CreateErrorParams<ModularInt128>(const RlweParameters&,
                                 const ModularInt128::Params*,
                                 const ::rlwe::NttParameters<ModularInt128>*);

}  // namespace rlwe
}  // namespace private_membership