chromium/third_party/shell-encryption/src/testing/coefficient_polynomial.h

/*
 * Copyright 2017 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// The class defined in this file should only be used for testing purposes.

#ifndef RLWE_TESTING_COEFFICIENT_POLYNOMIAL_H_
#define RLWE_TESTING_COEFFICIENT_POLYNOMIAL_H_

#include <cmath>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <glog/logging.h>
#include "absl/strings/str_cat.h"
#include "status_macros.h"
#include "statusor.h"
#include "testing/coefficient_polynomial.pb.h"

namespace rlwe {
namespace testing {

// A polynomial with ModularInt coefficients that is automatically reduced
// modulo <x^n + 1>, where n is the number of coefficients provided in the
// constructor.
// Should only be used for testing.
//
// ModularInt must provide a nested `Params` type plus the ImportZero,
// ImportOne, ExportInt, Add/Sub/Mul/Negate(InPlace), SerializeVector and
// DeserializeVector operations used below. The modulus parameters are held by
// raw pointer and passed to every coefficient operation; this class never
// frees them, so they must outlive the polynomial.
template <typename ModularInt>
class CoefficientPolynomial {
  using ModularIntParams = typename ModularInt::Params;

 public:
  // Copy and move operations. All four are declared explicitly: declaring
  // only the copy constructor would suppress the implicit move operations, so
  // returning a polynomial by value (e.g. through StatusOr) would copy its
  // coefficient vector instead of moving it.
  CoefficientPolynomial(const CoefficientPolynomial& that) = default;
  CoefficientPolynomial(CoefficientPolynomial&& that) = default;
  CoefficientPolynomial& operator=(const CoefficientPolynomial& that) = default;
  CoefficientPolynomial& operator=(CoefficientPolynomial&& that) = default;

  // Constructor. The polynomial is initialized to the values of a vector; the
  // vector's length n fixes the ring modulus x^n + 1.
  CoefficientPolynomial(std::vector<ModularInt> coeffs,
                        const ModularIntParams* modulus_params)
      : coeffs_(std::move(coeffs)), modulus_params_(modulus_params) {}

  // Constructs the zero polynomial with `len` coefficients.
  explicit CoefficientPolynomial(int len,
                                 const ModularIntParams* modulus_params)
      : CoefficientPolynomial(std::vector<ModularInt>(
                                  len, ModularInt::ImportZero(modulus_params)),
                              modulus_params) {}

  // Accessor for the number of coefficients n.
  int Len() const { return coeffs_.size(); }

  // Accessor for coefficients (returns a copy).
  std::vector<ModularInt> Coeffs() const { return coeffs_; }

  // Accessor for Modulus Params.
  const ModularIntParams* ModulusParams() const { return modulus_params_; }

  // Returns the index of the highest coefficient whose exported value is
  // non-zero, or 0 when there is none (so the zero polynomial, an empty
  // polynomial and a non-zero constant all report degree 0).
  int Degree() const {
    for (int i = Len() - 1; i >= 0; i--) {
      if (coeffs_[i].ExportInt(modulus_params_) != 0) {
        return i;
      }
    }

    return 0;
  }

  // Equality. Two polynomials are equal when their coefficients agree up to
  // the (common) degree; trailing zero coefficients, and hence the lengths,
  // are ignored. The modulus parameters are not compared.
  bool operator==(const CoefficientPolynomial& that) const {
    // Computed once: re-evaluating Degree() in the loop condition made the
    // comparison quadratic in the length.
    const int degree = Degree();
    if (degree != that.Degree()) {
      return false;
    }

    // Degree() reports 0 for an empty polynomial, which has no coeffs_[0] to
    // read. Treat an empty polynomial as the zero polynomial instead.
    if (coeffs_.empty() || that.coeffs_.empty()) {
      const CoefficientPolynomial& other = coeffs_.empty() ? that : *this;
      return other.coeffs_.empty() ||
             other.coeffs_[0].ExportInt(other.modulus_params_) == 0;
    }

    for (int i = 0; i <= degree; i++) {
      if (coeffs_[i] != that.coeffs_[i]) {
        return false;
      }
    }

    return true;
  }

  bool operator!=(const CoefficientPolynomial& that) const {
    return !(*this == that);
  }

  // Addition. Returns InvalidArgumentError if the lengths differ.
  rlwe::StatusOr<CoefficientPolynomial> operator+(
      const CoefficientPolynomial& that) const {
    // Ensure the polynomials' dimensions are equal.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "CoefficientPolynomial dimensions mismatched.");
    }

    // Add polynomials point-wise.
    CoefficientPolynomial out(*this);
    for (int i = 0; i < Len(); i++) {
      out.coeffs_[i].AddInPlace(that.coeffs_[i], modulus_params_);
    }

    return out;
  }

  // Subtraction. Returns InvalidArgumentError if the lengths differ.
  rlwe::StatusOr<CoefficientPolynomial> operator-(
      const CoefficientPolynomial& that) const {
    // Ensure the polynomials' dimensions are equal.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "CoefficientPolynomial dimensions mismatched.");
    }

    // Subtract polynomials point-wise.
    CoefficientPolynomial out(*this);
    for (int i = 0; i < Len(); i++) {
      out.coeffs_[i].SubInPlace(that.coeffs_[i], modulus_params_);
    }

    return out;
  }

  // Scalar multiplication: multiplies every coefficient by `c`.
  CoefficientPolynomial operator*(ModularInt c) const {
    CoefficientPolynomial out(*this);
    for (auto& coeff : out.coeffs_) {
      coeff.MulInPlace(c, modulus_params_);
    }
    return out;
  }

  // Schoolbook multiplication modulo x^n + 1 (O(n^2) coefficient products).
  // Returns InvalidArgumentError if the lengths differ.
  rlwe::StatusOr<CoefficientPolynomial> operator*(
      const CoefficientPolynomial& that) const {
    // Ensure the polynomials' dimensions are equal.
    if (Len() != that.Len()) {
      return absl::InvalidArgumentError(
          "CoefficientPolynomial dimensions mismatched.");
    }

    const int n = Len();
    // Create a zero polynomial of the correct dimension.
    CoefficientPolynomial out(n, modulus_params_);

    for (int i = 0; i < n; i++) {
      for (int j = 0; j < n; j++) {
        const ModularInt product =
            coeffs_[i].Mul(that.coeffs_[j], modulus_params_);
        const int k = i + j;
        if (k >= n) {
          // Since multiplication is mod (x^n + 1), a term of degree k >= n
          // equals -x^{k - n}, so it is subtracted from coefficient k - n.
          out.coeffs_[k - n].SubInPlace(product, modulus_params_);
        } else {
          // Otherwise, it contributes to the k'th coefficient as normal.
          out.coeffs_[k].AddInPlace(product, modulus_params_);
        }
      }
    }
    return out;
  }

  // A more efficient multiplication by a monomial x^power, where
  // 0 <= power < 2*dimension. Returns InvalidArgumentError otherwise.
  rlwe::StatusOr<CoefficientPolynomial> MonomialMultiplication(
      int power) const {
    // Check that the monomial is in range.
    if (0 > power || power >= 2 * Len()) {
      return absl::InvalidArgumentError(
          "Monomial to absorb must have non-negative degree less than 2n.");
    }

    CoefficientPolynomial out(*this);

    // Monomial multiplication by x^{k} where n <= k < 2*n is monomial
    // multiplication by -x^{k - n}.
    ModularInt multiplier = ModularInt::ImportOne(modulus_params_);
    if (power >= Len()) {
      multiplier.NegateInPlace(modulus_params_);
      power = power - Len();
    }
    ModularInt negative_multiplier = multiplier.Negate(modulus_params_);

    // Coefficients shifted past degree n-1 wrap around with an extra sign
    // flip (x^n = -1).
    for (int i = 0; i < power; i++) {
      out.coeffs_[i] =
          negative_multiplier.Mul(coeffs_[i - power + Len()], modulus_params_);
    }
    for (int i = power; i < Len(); i++) {
      out.coeffs_[i] = multiplier.Mul(coeffs_[i - power], modulus_params_);
    }

    return out;
  }

  // Given a polynomial p(x), returns the polynomial p(x^power) mod x^n + 1.
  // Expects an odd power with 0 <= power < 2n; returns InvalidArgumentError
  // otherwise.
  rlwe::StatusOr<CoefficientPolynomial> Substitute(const int power) const {
    // The power must be relatively prime to 2*n for the substitution to be an
    // automorphism. Only oddness is checked, which is equivalent when n is a
    // power of two (the dimensions this library uses).
    if (0 > power || (power % 2) == 0 || power >= 2 * Len()) {
      return absl::InvalidArgumentError(
          absl::StrCat("Substitution power must be a non-negative odd ",
                       "integer less than 2*n."));
    }
    // Start from zero and accumulate, so the result is still p(x^power) even
    // if two source coefficients land on the same index (possible only when
    // n is not a power of two).
    CoefficientPolynomial out(Len(), modulus_params_);

    // The i-th coefficient of p(x) is sent to coefficient (i * power) mod n.
    // Because x^n = -1 in this ring, it is also multiplied by
    // (-1)^{floor(i * power / n)}. current_index tracks (i * power) mod n and
    // multiplier tracks the matching power of -1.
    int current_index = 0;
    ModularInt multiplier = ModularInt::ImportOne(modulus_params_);
    for (int i = 0; i < Len(); i++) {
      out.coeffs_[current_index].AddInPlace(
          coeffs_[i].Mul(multiplier, modulus_params_), modulus_params_);
      current_index += power;

      // Must be >=: an index equal to n is x^n = -x^0, i.e. coefficient 0
      // with a sign flip. Stopping at > would leave current_index == n and
      // index one past the end of out.coeffs_.
      while (current_index >= Len()) {
        multiplier.NegateInPlace(modulus_params_);
        current_index -= Len();
      }
    }

    return out;
  }

  // Serializes the coefficients (via ModularInt::SerializeVector) together
  // with their count.
  rlwe::StatusOr<SerializedCoefficientPolynomial> Serialize() const {
    SerializedCoefficientPolynomial output;
    RLWE_ASSIGN_OR_RETURN(
        *(output.mutable_coeffs()),
        ModularInt::SerializeVector(coeffs_, modulus_params_));
    output.set_num_coeffs(coeffs_.size());

    return output;
  }

  // Inverse of Serialize(). Propagates any error from
  // ModularInt::DeserializeVector.
  static rlwe::StatusOr<CoefficientPolynomial> Deserialize(
      const SerializedCoefficientPolynomial& serialized,
      const ModularIntParams* modulus_params) {
    // Build the polynomial directly from the deserialized vector rather than
    // first allocating and zero-filling num_coeffs() entries that would be
    // overwritten; this also means no vector is sized from num_coeffs()
    // before DeserializeVector has had a chance to reject it.
    RLWE_ASSIGN_OR_RETURN(
        std::vector<ModularInt> coeffs,
        ModularInt::DeserializeVector(serialized.num_coeffs(),
                                      serialized.coeffs(), modulus_params));

    return CoefficientPolynomial(std::move(coeffs), modulus_params);
  }

 private:
  std::vector<ModularInt> coeffs_;
  const ModularIntParams* modulus_params_;
};

}  // namespace testing
}  // namespace rlwe

#endif  // RLWE_TESTING_COEFFICIENT_POLYNOMIAL_H_