chromium/third_party/shell-encryption/src/polynomial.h

/*
 * Copyright 2017 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef RLWE_POLYNOMIAL_H_
#define RLWE_POLYNOMIAL_H_

#include <cmath>
#include <vector>

#include "absl/strings/str_cat.h"
#include "constants.h"
#include "ntt_parameters.h"
#include "prng/prng.h"
#include "serialization.pb.h"
#include "status_macros.h"
#include "statusor.h"

namespace rlwe {

// A polynomial in NTT form. The length of the polynomial must be a power of 2.
template <typename ModularInt>
class Polynomial {
  using ModularIntParams = typename ModularInt::Params;

 public:
  // Default constructor. Produces an empty polynomial, for which IsValid()
  // returns false.
  Polynomial() = default;

  // Copy constructor and copy assignment.
  Polynomial(const Polynomial& p) = default;
  Polynomial& operator=(const Polynomial& that) = default;

  // Move constructor and move assignment. These must be declared explicitly:
  // the user-declared copy operations above suppress the implicitly-declared
  // move operations, which would otherwise force a deep copy of the
  // coefficient vector whenever a Polynomial is returned by value (e.g. when
  // wrapped into a StatusOr) or std::move'd.
  Polynomial(Polynomial&& p) = default;
  Polynomial& operator=(Polynomial&& that) = default;

  // Basic constructor. Takes ownership of `poly_coeffs`; the number of
  // coefficients is expected to be a power of 2.
  // Note: log_len_ is declared before coeffs_, so it is initialized before
  // `poly_coeffs` is moved from.
  explicit Polynomial(std::vector<ModularInt> poly_coeffs)
      : log_len_(Log2Floor(poly_coeffs.size())),
        coeffs_(std::move(poly_coeffs)) {}

  // Create a polynomial of the specified length with all coefficients equal
  // to zero. The length must be a power of 2.
  explicit Polynomial(int len, const ModularIntParams* params)
      : Polynomial(
            std::vector<ModularInt>(len, ModularInt::ImportZero(params))) {}

  // This is an implementation of the FFT from [Sei18, Sec. 2].
  // [Sei18] https://eprint.iacr.org/2018/039
  // All polynomial arithmetic performed is modulo (x^n+1) for n a power of two,
  // with the coefficients operated on modulo a prime modulus.
  //
  // Let psi be a primitive 2n-th root of unity, i.e., psi is a 2n-th root
  // of unity such that psi^n = -1. Hence it holds that
  //   x^n+1 = x^n-psi^n = (x^(n/2)-psi^(n/2))*(x^(n/2)+psi^(n/2))
  //
  // If f = f_0 + f_1*x + ... + f_{n-1}*x^(n-1) is the polynomial to transform,
  // the i-th coefficient of the polynomial mod x^(n/2)-psi^(n/2) can thus be
  // computed as
  //            f'_i = f_i + psi^(n/2)*f_(n/2+i),
  // and the i-th coefficient of the polynomial mod x^(n/2)+psi^(n/2) can thus
  // be computed as
  //            f''_i = f_i - psi^(n/2)*f_(n/2+i)
  // This operation is called the Cooley-Tukey butterfly and is done
  // iteratively during the NTT.
  //
  // The FFT can thus be performed in-place and after the k-th level, it
  // produces the vector of polynomials with pairs of coefficients
  //  f mod (x^(n/2^(k+1))-psi^brv[2^k+1]), f mod (x^(n/2^(k+1))+psi^brv[2^k+1])
  // where brv maps a log(n)-bit number to its bitreversal.
  //
  // Returns an empty (invalid) Polynomial when the number of coefficients is
  // not a positive power of 2.
  static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
                                 const NttParameters<ModularInt>* ntt_params,
                                 const ModularIntParams* modular_params) {
    // Check to ensure that the coefficient vector is of the correct length.
    int len = poly_coeffs.size();
    if (len <= 0 || (len & (len - 1)) != 0) {
      // An error value.
      return Polynomial();
    }

    Polynomial output(std::move(poly_coeffs));
    output.IterativeCooleyTukey(ntt_params->psis_bitrev, modular_params);

    return output;
  }

  // Deprecated ConvertToNtt function taking NttParameters by constant reference
  ABSL_DEPRECATED("Use ConvertToNtt function with NttParameters pointer above.")
  static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
                                 const NttParameters<ModularInt>& ntt_params,
                                 const ModularIntParams* modular_params) {
    return ConvertToNtt(std::move(poly_coeffs), &ntt_params, modular_params);
  }

  // The inverse NTT transform is computed similarly by iteratively inverting
  // the NTT representation. For instance, using the same notation as above,
  //    f'_i + f''_i = 2f_i and  psi^(-n/2)*(f'_i-f''_i) = 2f_(n/2+i).
  //
  // In particular, the butterfly operation differs from the Cooley-Tukey
  // butterfly used during the forward transform in that addition and
  // subtraction come before multiplying with a power of the root of unity.
  // This butterfly operation is called the Gentleman-Sande butterfly.
  //
  // At the end of the computation, a normalization step by the inverse of
  // n=2^log(n) (the factor 2 obtained at each level of the butterfly) is
  // required.
  //
  // Returns the coefficients of the polynomial; *this is left unchanged.
  std::vector<ModularInt> InverseNtt(
      const NttParameters<ModularInt>* ntt_params,
      const ModularIntParams* modular_params) const {
    Polynomial copy(*this);

    copy.IterativeGentlemanSande(ntt_params->psis_inv_bitrev, modular_params);

    // Normalize the result by multiplying by the inverse of n. The factor is
    // loop-invariant, so it is looked up once.
    const auto& n_inv = ntt_params->n_inv_ptr.value();
    for (auto& coeff : copy.coeffs_) {
      coeff.MulInPlace(n_inv, modular_params);
    }

    return std::move(copy.coeffs_);
  }

  // Deprecated InverseNtt function taking NttParameters by constant reference
  ABSL_DEPRECATED("Use InverseNtt function with NttParameters pointer above.")
  std::vector<ModularInt> InverseNtt(
      const NttParameters<ModularInt>& ntt_params,
      const ModularIntParams* modular_params) const {
    return InverseNtt(&ntt_params, modular_params);
  }

  // Specifies whether the Polynomial is valid, i.e. non-empty.
  bool IsValid() const { return !coeffs_.empty(); }

  // Scalar multiply. Returns a new polynomial; *this is unchanged.
  rlwe::StatusOr<Polynomial> Mul(const ModularInt& scalar,
                                 const ModularIntParams* modular_params) const {
    Polynomial output = *this;
    RLWE_RETURN_IF_ERROR(output.MulInPlace(scalar, modular_params));
    return output;
  }

  // Scalar multiply in place.
  absl::Status MulInPlace(const ModularInt& scalar,
                          const ModularIntParams* modular_params) {
    return ModularInt::BatchMulInPlace(&coeffs_, scalar, modular_params);
  }

  // Coordinate-wise multiplication. Returns a new polynomial; *this is
  // unchanged. Fails with InvalidArgument if the lengths differ.
  rlwe::StatusOr<Polynomial> Mul(const Polynomial& that,
                                 const ModularIntParams* modular_params) const {
    Polynomial output = *this;
    RLWE_RETURN_IF_ERROR(output.MulInPlace(that, modular_params));
    return output;
  }

  // Coordinate-wise multiplication in place. Fails with InvalidArgument if
  // the lengths differ.
  absl::Status MulInPlace(const Polynomial& that,
                          const ModularIntParams* modular_params) {
    // If this operation is invalid, return an invalid error.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "The polynomials do not have the same length.");
    }
    return ModularInt::BatchMulInPlace(&coeffs_, that.coeffs_, modular_params);
  }

  // Negation. Returns a new polynomial; *this is unchanged.
  Polynomial Negate(const ModularIntParams* modular_params) const {
    Polynomial output = *this;
    output.NegateInPlace(modular_params);
    return output;
  }

  // Negation in place. Returns *this to allow chaining.
  Polynomial& NegateInPlace(const ModularIntParams* modular_params) {
    for (auto& coeff : coeffs_) {
      coeff.NegateInPlace(modular_params);
    }

    return *this;
  }

  // Coordinate-wise addition. Returns a new polynomial; *this is unchanged.
  // Fails with InvalidArgument if the lengths differ.
  rlwe::StatusOr<Polynomial> Add(const Polynomial& that,
                                 const ModularIntParams* modular_params) const {
    Polynomial output = *this;
    RLWE_RETURN_IF_ERROR(output.AddInPlace(that, modular_params));
    return output;
  }

  // Coordinate-wise subtraction. Returns a new polynomial; *this is unchanged.
  // Fails with InvalidArgument if the lengths differ.
  rlwe::StatusOr<Polynomial> Sub(const Polynomial& that,
                                 const ModularIntParams* modular_params) const {
    Polynomial output = *this;
    RLWE_RETURN_IF_ERROR(output.SubInPlace(that, modular_params));
    return output;
  }

  // Coordinate-wise addition in place. Fails with InvalidArgument if the
  // lengths differ.
  absl::Status AddInPlace(const Polynomial& that,
                          const ModularIntParams* modular_params) {
    // If this operation is invalid, return an invalid error.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "The polynomials do not have the same length.");
    }

    return ModularInt::BatchAddInPlace(&coeffs_, that.coeffs_, modular_params);
  }

  // Coordinate-wise subtraction in place. Fails with InvalidArgument if the
  // lengths differ.
  absl::Status SubInPlace(const Polynomial& that,
                          const ModularIntParams* modular_params) {
    // If this operation is invalid, return an invalid error.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "The polynomials do not have the same length.");
    }

    return ModularInt::BatchSubInPlace(&coeffs_, that.coeffs_, modular_params);
  }

  // Substitute: Given a Polynomial representing p(x), returns a
  // Polynomial representing p(x^power). Power must be an odd non-negative
  // integer less than 2 * Len(). `modulus_params` is currently unused.
  rlwe::StatusOr<Polynomial> Substitute(
      const int power, const NttParameters<ModularInt>* ntt_params,
      const ModularIntParams* modulus_params) const {
    // The NTT representation consists of the evaluations of the polynomial at
    // roots psi^brv[n/2], psi^brv[n/2+1], ..., psi^brv[n/2+n/2-1],
    //       psi^(n/2+brv[n/2+1]), ...,         psi^(n/2+brv[n/2+n/2-1]).
    // Let f(x) be the original polynomial, and out(x) be the polynomial after
    // the substitution. Note that (psi^i)^power = psi^((i * power) % (2 * n)).
    if (0 > power || (power % 2) == 0 || power >= 2 * Len()) {
      return absl::InvalidArgumentError(
          absl::StrCat("Substitution power must be a non-negative odd "
                       "integer less than 2*n."));
    }

    Polynomial out = *this;

    // The update below treats the i-th evaluation (stored at bitrevs[i]) as
    // the value at psi^(2i+1). Raising that point to `power` gives the point
    // psi^(2j+1) with j = ((2i+1)*power - 1)/2 mod n, so j starts at
    // (power-1)/2 and advances by `power` (mod n) for each step of i.
    int psi_power_index = (power - 1) / 2;
    for (int i = 0; i < Len(); i++) {
      out.coeffs_[ntt_params->bitrevs[i]] =
          coeffs_[ntt_params->bitrevs[psi_power_index]];
      psi_power_index = (psi_power_index + power) % Len();
    }

    return out;
  }

  // Deprecated Substitute function taking NttParameters by constant reference
  ABSL_DEPRECATED("Use Substitute function with NttParameters pointer above.")
  rlwe::StatusOr<Polynomial> Substitute(
      const int power, const NttParameters<ModularInt>& ntt_params,
      const ModularIntParams* modulus_params) const {
    return Substitute(power, &ntt_params, modulus_params);
  }

  // Boolean comparison: equal lengths and equal coefficients.
  bool operator==(const Polynomial& that) const {
    if (Len() != that.Len()) {
      return false;
    }

    for (int i = 0; i < Len(); i++) {
      if (coeffs_[i] != that.coeffs_[i]) {
        return false;
      }
    }

    return true;
  }
  bool operator!=(const Polynomial& that) const { return !(*this == that); }

  // Number of coefficients.
  int Len() const { return coeffs_.size(); }

  // Accessor for coefficients (returns a copy).
  std::vector<ModularInt> Coeffs() const { return coeffs_; }

  // Serializes the coefficients together with their count.
  rlwe::StatusOr<SerializedNttPolynomial> Serialize(
      const ModularIntParams* modular_params) const {
    SerializedNttPolynomial output;
    RLWE_ASSIGN_OR_RETURN(*(output.mutable_coeffs()),
                          ModularInt::SerializeVector(coeffs_, modular_params));
    output.set_num_coeffs(coeffs_.size());
    return output;
  }

  // Deserializes a polynomial. Fails with InvalidArgument if the stored
  // number of coefficients is not positive or exceeds kMaxNumCoeffs, and
  // propagates any error from ModularInt::DeserializeVector.
  static rlwe::StatusOr<Polynomial> Deserialize(
      const SerializedNttPolynomial& serialized,
      const ModularIntParams* modular_params) {
    if (serialized.num_coeffs() <= 0) {
      return absl::InvalidArgumentError(
          "Number of serialized coefficients must be positive.");
    } else if (serialized.num_coeffs() > kMaxNumCoeffs) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Number of serialized coefficients, ", serialized.num_coeffs(),
          ", must be less than ", kMaxNumCoeffs, "."));
    }
    // Build the polynomial directly from the deserialized coefficients, so
    // that log_len_ is derived from the coefficients actually stored and no
    // throw-away zero-filled vector is allocated first.
    RLWE_ASSIGN_OR_RETURN(
        std::vector<ModularInt> coeffs,
        ModularInt::DeserializeVector(serialized.num_coeffs(),
                                      serialized.coeffs(), modular_params));
    return Polynomial(std::move(coeffs));
  }

 private:
  // Instance variables. log_len_ must stay declared before coeffs_ (see the
  // vector constructor's initializer list).
  size_t log_len_ = 0;
  std::vector<ModularInt> coeffs_;

  // Returns floor(log2(n)) using integer arithmetic, and 0 for n == 0.
  // For the power-of-2 lengths this class expects, this equals log2(n)
  // exactly; unlike floating-point log2, it is well defined for n == 0
  // (log2(0) is -inf, and converting that to size_t is undefined behavior).
  static size_t Log2Floor(size_t n) {
    size_t result = 0;
    while (n > 1) {
      n >>= 1;
      ++result;
    }
    return result;
  }

  // Helper function: Perform iterations of the Cooley-Tukey butterfly,
  // from the largest block size down to blocks of size 2.
  void IterativeCooleyTukey(const std::vector<ModularInt>& psis_bitrev,
                            const ModularIntParams* modular_params) {
    const size_t len = coeffs_.size();
    size_t index_psi = 1;
    // Counts level from log_len_ - 1 down to 0 without unsigned wrap-around.
    for (size_t level = log_len_; level-- > 0;) {
      const size_t half_m = size_t{1} << level;
      const size_t m = half_m << 1;
      for (size_t k = 0; k < len; k += m) {
        const ModularInt& psi = psis_bitrev[index_psi];
        for (size_t j = 0; j < half_m; j++) {
          // The Cooley-Tukey butterfly operation.
          const ModularInt t = psi.Mul(coeffs_[k + j + half_m], modular_params);
          ModularInt u = coeffs_[k + j];
          coeffs_[k + j].AddInPlace(t, modular_params);
          coeffs_[k + j + half_m] = std::move(u.SubInPlace(t, modular_params));
        }
        index_psi++;
      }
    }
  }

  // Helper function: Perform iterations of the Gentleman-Sande butterfly,
  // from blocks of size 2 up to the full length.
  void IterativeGentlemanSande(const std::vector<ModularInt>& psis_inv_bitrev,
                               const ModularIntParams* modular_params) {
    const size_t len = coeffs_.size();
    size_t index_psi_inv = 0;
    for (size_t level = 0; level < log_len_; level++) {
      const size_t half_m = size_t{1} << level;
      const size_t m = half_m << 1;
      for (size_t k = 0; k < len; k += m) {
        const ModularInt& psi_inv = psis_inv_bitrev[index_psi_inv];
        for (size_t j = 0; j < half_m; j++) {
          // The Gentleman-Sande butterfly operation. Both inputs are copied
          // before either slot is overwritten.
          const ModularInt t = coeffs_[k + j + half_m];
          ModularInt u = coeffs_[k + j];
          coeffs_[k + j].AddInPlace(t, modular_params);
          coeffs_[k + j + half_m] =
              std::move(u.SubInPlace(t, modular_params)
                            .MulInPlace(psi_inv, modular_params));
        }
        index_psi_inv++;
      }
    }
  }
};

template <typename ModularInt, typename Prng = rlwe::SecurePrng>
rlwe::StatusOr<Polynomial<ModularInt>> SamplePolynomialFromPrng(
    int num_coeffs, Prng* prng,
    const typename ModularInt::Params* modulus_params) {
  // Samples a polynomial with `num_coeffs` coefficients drawn via
  // ModularInt::ImportRandom from `prng`. Since a uniformly distributed
  // polynomial is wanted, it can be generated directly in NTT form: the NTT is
  // a bijection, so a uniform vector in NTT form corresponds to a uniform
  // polynomial in coefficient form.
  //
  // Fails with InvalidArgument if num_coeffs < 1, and propagates any error
  // returned by ModularInt::ImportRandom.
  if (num_coeffs < 1) {
    return absl::InvalidArgumentError(
        "SamplePolynomialFromPrng: number of coefficients must be a "
        "positive integer.");
  }
  std::vector<ModularInt> a_ntt_coeffs;
  a_ntt_coeffs.reserve(num_coeffs);
  for (int i = 0; i < num_coeffs; i++) {
    RLWE_ASSIGN_OR_RETURN(auto coeff,
                          ModularInt::ImportRandom(prng, modulus_params));
    a_ntt_coeffs.push_back(std::move(coeff));
  }
  // Hand the vector over instead of copying it.
  return Polynomial<ModularInt>(std::move(a_ntt_coeffs));
}

}  // namespace rlwe

#endif  // RLWE_POLYNOMIAL_H_