chromium/third_party/shell-encryption/src/montgomery.cc

// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "montgomery.h"

#include "third_party/shell-encryption/base/shell_encryption_export.h"
#include "third_party/shell-encryption/base/shell_encryption_export_template.h"
#include "transcription.h"

namespace rlwe {

template <typename T>
rlwe::StatusOr<std::unique_ptr<const MontgomeryIntParams<T>>>
MontgomeryIntParams<T>::Create(Int modulus) {
  // Check that the modulus is smaller than max(Int) / 4.
  Int most_significant_bit = modulus >> (bitsize_int - 2);
  if (most_significant_bit != 0) {
    return absl::InvalidArgumentError(absl::StrCat(
        "The modulus should be less than 2^", (bitsize_int - 2), "."));
  }
  if ((modulus % 2) == 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("The modulus should be odd."));
  }
  return absl::WrapUnique<const MontgomeryIntParams>(
      new MontgomeryIntParams(modulus));
}

// From Hacker's Delight.
template <typename T>
std::tuple<T, T> MontgomeryIntParams<T>::Inverses(BigInt modulus_bigint,
                                                  BigInt r) {
  // Invariants
  //   1) sum = x * 2^w - y * modulus.
  //   2) sum is always a power of 2.
  //   3) modulus is odd.
  //   4) y is always even.
  // sum will decrease from 2^w to 2^0 = 1
  BigInt x = 1;
  BigInt y = 0;
  for (int i = bitsize_int; i > 0; i--) {
    // Ensure that x is even.
    if ((x & 1) == 1) {
      // If x is odd, make x even by adding modulus to x and changing the
      // value of y accordingly (y remains even).
      //
      //     sum = x * 2^w - y * modulus
      //     sum = (x + modulus) * 2^w - (y + 2^w) * modulus
      //
      // We can then divide the new values of x and y by 2 safely.
      x += modulus_bigint;
      y += r;
    }
    // Divide x and y by 2
    x >>= 1;
    y >>= 1;
  }
  // Return the inverses
  return std::make_tuple(static_cast<Int>(x), static_cast<Int>(y));
}

template <typename T>
rlwe::StatusOr<MontgomeryInt<T>> MontgomeryInt<T>::ImportInt(
    Int n, const Params* params) {
  BigInt product = static_cast<BigInt>(params->r_mod_modulus_barrett) * n;
  Int result = static_cast<Int>(product >> Params::bitsize_int);
  result = n * params->r_mod_modulus - result * params->modulus;
  // The steps above produce an integer that is in the range [0, 2N).
  // We now reduce to the range [0, N).
  result -= (result >= params->modulus) ? params->modulus : 0;
  return MontgomeryInt(result);
}

template <typename T>
MontgomeryInt<T> MontgomeryInt<T>::ImportZero(const Params* params) {
  return MontgomeryInt(params->Zero());
}

template <typename T>
MontgomeryInt<T> MontgomeryInt<T>::ImportOne(const Params* params) {
  // 1 should be multiplied by r_mod_modulus; we load directly r_mod_modulus.
  return MontgomeryInt(static_cast<Int>(params->r_mod_modulus));
}

template <typename T>
typename internal::BigInt<T>::value_type MontgomeryInt<T>::DivAndTruncate(
    BigInt dividend, BigInt divisor) {
  return dividend / divisor;
}

template <typename T>
rlwe::StatusOr<std::string> MontgomeryInt<T>::Serialize(
    const Params* params) const {
  // Use transcription to transform all the LogModulus() bits of input into a
  // vector of unsigned char.
  RLWE_ASSIGN_OR_RETURN(
      auto v, (TranscribeBits<Int, Uint8>({this->n_}, params->log_modulus,
                                          params->log_modulus, 8)));
  // Return a string
  return std::string(std::make_move_iterator(v.begin()),
                     std::make_move_iterator(v.end()));
}

template <typename T>
rlwe::StatusOr<std::string> MontgomeryInt<T>::SerializeVector(
    const std::vector<MontgomeryInt>& coeffs, const Params* params) {
  if (coeffs.size() > kMaxNumCoeffs) {
    return absl::InvalidArgumentError(
        absl::StrCat("Number of coefficients, ", coeffs.size(),
                     ", cannot be larger than ", kMaxNumCoeffs, "."));
  } else if (coeffs.empty()) {
    return absl::InvalidArgumentError("Cannot serialize an empty vector.");
  }
  // Bits required to represent modulus.
  int bit_size = params->log_modulus;
  // Extract the values
  std::vector<Int> coeffs_values;
  coeffs_values.reserve(coeffs.size());
  for (const auto& c : coeffs) {
    coeffs_values.push_back(c.n_);
  }
  // Use transcription to transform all the bit_size bits of input into a
  // vector of unsigned char.
  RLWE_ASSIGN_OR_RETURN(
      auto v,
      (TranscribeBits<Int, Uint8>(
          coeffs_values, coeffs_values.size() * bit_size, bit_size, 8)));
  // Return a string
  return std::string(std::make_move_iterator(v.begin()),
                     std::make_move_iterator(v.end()));
}

template <typename T>
rlwe::StatusOr<MontgomeryInt<T>> MontgomeryInt<T>::Deserialize(
    absl::string_view payload, const Params* params) {
  // Parse the string as unsigned char
  std::vector<Uint8> input(payload.begin(), payload.end());
  // Bits required to represent modulus.
  int bit_size = params->log_modulus;
  // Recover the coefficients from the input stream.
  RLWE_ASSIGN_OR_RETURN(auto coeffs_values, (TranscribeBits<Uint8, Int>(
                                                input, bit_size, 8, bit_size)));
  // There will be at least one coefficient in coeff_values because bit_size
  // is always expected to be positive.
  return MontgomeryInt(coeffs_values[0]);
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
MontgomeryInt<T>::DeserializeVector(int num_coeffs,
                                    absl::string_view serialized,
                                    const Params* params) {
  if (num_coeffs < 0) {
    return absl::InvalidArgumentError(
        "Number of coefficients must be non-negative.");
  }
  if (num_coeffs > kMaxNumCoeffs) {
    return absl::InvalidArgumentError(
        absl::StrCat("Number of coefficients, ", num_coeffs, ", cannot be ",
                     "larger than ", kMaxNumCoeffs, "."));
  }
  // Parse the string as unsigned char
  std::vector<Uint8> input(serialized.begin(), serialized.end());
  // Bits required to represent modulus.
  int bit_size = params->log_modulus;
  // Recover the coefficients from the input stream.
  RLWE_ASSIGN_OR_RETURN(
      auto coeffs_values,
      (TranscribeBits<Uint8, Int>(input, bit_size * num_coeffs, 8, bit_size)));
  // Check that the number of coefficients recovered is at least what is
  // expected.
  if (coeffs_values.size() < num_coeffs) {
    return absl::InvalidArgumentError("Given serialization is invalid.");
  }
  // Create a vector of Montgomery Int from the values.
  std::vector<MontgomeryInt> coeffs;
  coeffs.reserve(num_coeffs);
  for (int i = 0; i < num_coeffs; i++) {
    coeffs.push_back(MontgomeryInt(coeffs_values[i]));
  }
  return coeffs;
}

template <typename T>
std::tuple<T, T> MontgomeryInt<T>::GetConstant(const Params* params) const {
  Int constant = ExportInt(params);
  Int constant_barrett = static_cast<Int>(
      (static_cast<BigInt>(constant) << params->bitsize_int) / params->modulus);
  return std::make_tuple(constant, constant_barrett);
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchAdd(
    const std::vector<MontgomeryInt>& in1,
    const std::vector<MontgomeryInt>& in2, const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchAddInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchAddInPlace(
    std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
    const Params* params) {
  // If the input vectors' sizes don't match, return an error.
  if (in1->size() != in2.size()) {
    return absl::InvalidArgumentError("Input vectors are not of same size");
  }
  int i = 0;
  // The remaining elements, if any, are added in place sequentially.
  for (; i < in1->size(); i++) {
    (*in1)[i].AddInPlace(in2[i], params);
  }
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchAdd(
    const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
    const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchAddInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchAddInPlace(std::vector<MontgomeryInt>* in1,
                                               const MontgomeryInt& in2,
                                               const Params* params) {
  int i = 0;
  std::for_each(in1->begin() + i, in1->end(),
                [&in2 = in2, params](MontgomeryInt& coeff) {
                  coeff.AddInPlace(in2, params);
                });
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchSub(
    const std::vector<MontgomeryInt>& in1,
    const std::vector<MontgomeryInt>& in2, const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchSubInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchSubInPlace(
    std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
    const Params* params) {
  // If the input vectors' sizes don't match, return an error.
  if (in1->size() != in2.size()) {
    return absl::InvalidArgumentError("Input vectors are not of same size");
  }
  int i = 0;
  for (; i < in1->size(); i++) {
    (*in1)[i].SubInPlace(in2[i], params);
  }
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchSub(
    const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
    const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchSubInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchSubInPlace(std::vector<MontgomeryInt>* in1,
                                               const MontgomeryInt& in2,
                                               const Params* params) {
  int i = 0;
  std::for_each(in1->begin() + i, in1->end(),
                [&in2 = in2, params](MontgomeryInt& coeff) {
                  coeff.SubInPlace(in2, params);
                });
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
MontgomeryInt<T>::BatchMulConstant(const std::vector<MontgomeryInt>& in1,
                                   const std::vector<Int>& constant,
                                   const std::vector<Int>& constant_barrett,
                                   const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(
      BatchMulConstantInPlace(&out, constant, constant_barrett, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchMulConstantInPlace(
    std::vector<MontgomeryInt>* in1, const std::vector<Int>& constant,
    const std::vector<Int>& constant_barrett, const Params* params) {
  // If the input vectors' sizes don't match, return an error.
  if (in1->size() != constant.size() ||
      constant.size() != constant_barrett.size()) {
    return absl::InvalidArgumentError("Input vectors are not of same size");
  }
  int i = 0;
  for (; i < in1->size(); i++) {
    (*in1)[i].MulConstantInPlace(constant[i], constant_barrett[i], params);
  }
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
MontgomeryInt<T>::BatchMulConstant(const std::vector<MontgomeryInt>& in1,
                                   const Int& constant,
                                   const Int& constant_barrett,
                                   const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(
      BatchMulConstantInPlace(&out, constant, constant_barrett, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchMulConstantInPlace(
    std::vector<MontgomeryInt>* in1, const Int& constant,
    const Int& constant_barrett, const Params* params) {
  int i = 0;
  for (; i < in1->size(); i++) {
    (*in1)[i].MulConstantInPlace(constant, constant_barrett, params);
  }
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchMul(
    const std::vector<MontgomeryInt>& in1,
    const std::vector<MontgomeryInt>& in2, const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchMulInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchMulInPlace(
    std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
    const Params* params) {
  // If the input vectors' sizes don't match, return an error.
  if (in1->size() != in2.size()) {
    return absl::InvalidArgumentError("Input vectors are not of same size");
  }
  int i = 0;
  for (; i < in1->size(); i++) {
    (*in1)[i].MulInPlace(in2[i], params);
  }
  return absl::OkStatus();
}

template <typename T>
rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchMul(
    const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
    const Params* params) {
  std::vector<MontgomeryInt> out = in1;
  RLWE_RETURN_IF_ERROR(BatchMulInPlace(&out, in2, params));
  return out;
}

template <typename T>
absl::Status MontgomeryInt<T>::BatchMulInPlace(std::vector<MontgomeryInt>* in1,
                                               const MontgomeryInt& in2,
                                               const Params* params) {
  int i = 0;
  std::for_each(in1->begin() + i, in1->end(),
                [&in2 = in2, params](MontgomeryInt& coeff) {
                  coeff.MulInPlace(in2, params);
                });
  return absl::OkStatus();
}

template <typename T>
MontgomeryInt<T> MontgomeryInt<T>::ModExp(Int exponent,
                                          const Params* params) const {
  MontgomeryInt result = MontgomeryInt::ImportOne(params);
  MontgomeryInt base = *this;

  // Uses the bits of the exponent to gradually compute the result.
  // When bit k of the exponent is 1, the result is multiplied by
  // base^{2^k}.
  while (exponent > 0) {
    // If the current bit (bit k) is 1, multiply base^{2^k} into the result.
    if (exponent % 2 == 1) {
      result.MulInPlace(base, params);
    }

    // Update base from base^{2^k} to base^{2^{k+1}}.
    base.MulInPlace(base, params);
    exponent >>= 1;
  }

  return result;
}

template <typename T>
MontgomeryInt<T> MontgomeryInt<T>::MultiplicativeInverse(
    const Params* params) const {
  return (*this).ModExp(static_cast<Int>(params->modulus - 2), params);
}

// Instantiations of MontgomeryInt and MontgomeryIntParams with specific
// integral types.
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint16>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint32>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint64>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<absl::uint128>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint16>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint32>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint64>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<absl::uint128>;
}  // namespace rlwe