chromium/third_party/shell-encryption/src/ntt_parameters.h

/*
 * Copyright 2017 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef RLWE_NTT_PARAMETERS_H_
#define RLWE_NTT_PARAMETERS_H_

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <optional>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "constants.h"
#include "status_macros.h"
#include "statusor.h"
#include "third_party/shell-encryption/base/shell_encryption_export.h"

namespace rlwe {
namespace internal {

// Adds base^i (mod the modulus described by params) to (*row)[i] for every
// exponent i in {0, 1, ..., n-1}. When row starts out holding zeros, as it
// does in NttPsis, the result is the table of the first n powers of base.
//
// row must already hold at least n elements; it is never resized here.
//
// The powers come from a running product (one multiplication per entry)
// rather than a separate modular exponentiation for every exponent.
template <typename ModularInt>
void FillWithEveryPower(const ModularInt& base, unsigned int n,
                        std::vector<ModularInt>* row,
                        const typename ModularInt::Params* params) {
  // Invariant: at the top of iteration i, power == base^i.
  ModularInt power = ModularInt::ImportOne(params);
  for (unsigned int i = 0; i < n; ++i) {
    (*row)[i].AddInPlace(power, params);
    power.MulInPlace(base, params);
  }
}

// Searches for a primitive n-th root of unity modulo params->modulus, where
// n = 2^log_n.
//
// Candidates are built as t^k for t = 2, 3, ..., modulus - 1 with
// k = (modulus - 1) / n. For a prime modulus every such t^k is an n-th root of
// unity, though not necessarily a primitive one; the first candidate whose
// (n/2)-th power differs from 1 is returned.
//
// Returns an error if importing a candidate fails, or UnknownError if the
// search range is exhausted without finding a primitive root.
template <typename ModularInt>
rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
    unsigned int log_n, const typename ModularInt::Params* params) {
  using Int = typename ModularInt::Int;

  Int order = params->One() << log_n;
  Int half_order = order >> 1;

  // Raising any value to this exponent lands on an n-th root of unity when
  // the modulus is prime.
  Int exponent = (params->modulus - params->One()) / order;

  ModularInt unity = ModularInt::ImportOne(params);

  Int t = params->Two();
  while (t < params->modulus) {
    // Build the next candidate root.
    RLWE_ASSIGN_OR_RETURN(auto base, ModularInt::ImportInt(t, params));
    ModularInt root = base.ModExp(exponent, params);

    // root^n == 1 by construction; it is primitive exactly when
    // root^(n/2) != 1.
    if (root.ModExp(half_order, params) != unity) {
      return root;
    }
    t = t + params->One();
  }

  // Only reachable when the parameters were not set up properly.
  return absl::UnknownError("Loop in PrimitiveNthRootOfUnity terminated.");
}

// Builds the table (psi^0, psi^1, psi^2, ..., psi^(n-1)) for n = 2^log_n,
// where psi is a primitive 2n-th root of unity, i.e. a 2n-th root of unity
// with psi^n = -1. The NTT consumes these powers in bit-reversed order; this
// helper returns them in natural order and leaves the reordering to callers.
//
// Every entry is in modular integer representation. Any error raised while
// searching for psi is propagated.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsis(
    unsigned int log_n, const typename ModularInt::Params* params) {
  // psi has order 2n = 2^(log_n + 1).
  const unsigned int order_log = log_n + 1;
  RLWE_ASSIGN_OR_RETURN(
      ModularInt psi,
      internal::PrimitiveNthRootOfUnity<ModularInt>(order_log, params));

  // Start from an all-zero table so that accumulating the powers into it
  // yields exactly the powers themselves.
  unsigned int num_coeffs = 1 << log_n;
  std::vector<ModularInt> powers(num_coeffs, ModularInt::ImportZero(params));
  internal::FillWithEveryPower<ModularInt>(psi, num_coeffs, &powers, params);
  return powers;
}

// Creates a vector containing the indices necessary to perform the NTT bit
// reversal operation. Index i of the returned vector contains an integer with
// the rightmost log_n bits of i reversed.
//
// Only declared here; no definition appears in this header. Within this file
// the result is passed to BitrevHelper as its index table and stored in
// NttParameters::bitrevs.
// NOTE(review): the table presumably has 2^log_n entries -- confirm against
// the definition.
SHELL_ENCRYPTION_EXPORT std::vector<unsigned int> BitrevArray(unsigned int log_n);

// Helper function: permutes item_to_reverse in place according to bitrevs.
// For every index i with i < bitrevs[i], the elements at positions i and
// bitrevs[i] are swapped.
//
// bitrevs:          index table, e.g. the output of BitrevArray. It must have
//                   at least item_to_reverse->size() entries, and each entry
//                   that is read must be a valid index into item_to_reverse.
// item_to_reverse:  the vector to permute; must not be null.
template <typename ModularInt>
static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
                         std::vector<ModularInt>* item_to_reverse) {
  using std::swap;  // Let ADL pick up a ModularInt-specific swap if any.
  const std::size_t count = item_to_reverse->size();
  for (std::size_t i = 0; i < count; ++i) {
    const std::size_t partner = bitrevs[i];
    // Only swap in one direction: the pair is seen again from the partner's
    // side, and swapping there too would undo the exchange.
    if (i < partner) {
      swap((*item_to_reverse)[i], (*item_to_reverse)[partner]);
    }
  }
}

}  // namespace internal

// Returns the roots of unity consumed by the forward NTT: the powers
// psi^0, ..., psi^(n-1) of a primitive 2n-th root of unity psi (n = 2^log_n),
// rearranged into bit-reversed index order.
// Any error raised while building the power table is propagated.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
    unsigned int log_n, const typename ModularInt::Params* params) {
  // Powers of psi in natural order.
  RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> table,
                        internal::NttPsis<ModularInt>(log_n, params));
  // Permute them into bit-reversed order.
  const std::vector<unsigned int> order = internal::BitrevArray(log_n);
  internal::BitrevHelper(order, &table);
  return table;
}

// The precomputed roots of unity used during the inverse NTT. With psi the
// primitive 2n-th root of unity behind NttPsis (n = 2^log_n), the result holds
// the inverse powers psi^-1, psi^-2, ..., psi^-n, i.e. the inverses of the
// powers of psi with the exponent shifted up by one, arranged in the
// bit-reversed order given by BitrevArray.
//
// Any error raised while building the power table is propagated.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev(
    unsigned int log_n, const typename ModularInt::Params* params) {
  // Retrieve the table for the forward transformation:
  // row = (psi^0, psi^1, ..., psi^(n-1)).
  RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row,
                        internal::NttPsis<ModularInt>(log_n, params));

  // psi^(n-1) * psi = psi^n = -1, hence psi^(n-1) = -(psi^-1). Read it from
  // the back of the table *before* reversing: this is the same element that
  // sits at index 1 after the reversal, but it also exists for the
  // single-entry table produced when log_n == 0, where row[1] would be out of
  // bounds.
  ModularInt negative_psi_inv = row.back();
  ModularInt psi_inv = negative_psi_inv.Negate(params);

  // Reverse the items at indices 1 through (n - 1). Multiplying index i
  // of the reversed row by index i of the original row will yield psi^n = -1.
  // (The exception is psi^0 = 1, which is already its own inverse.)
  std::reverse(row.begin() + 1, row.end());

  // Bitreverse the vector.
  internal::BitrevHelper(internal::BitrevArray(log_n), &row);

  // Index 0 holds psi^0 and only needs the factor psi^-1. Every other entry
  // holds psi^(n-j) = -(psi^-j), so multiplying by -(psi^-1) both fixes the
  // sign and applies the extra factor psi^-1.
  row[0].MulInPlace(psi_inv, params);
  for (std::size_t i = 1; i < row.size(); ++i) {
    row[i].MulInPlace(negative_psi_inv, params);
  }

  return row;
}

// A struct that stores a package of NTT parameters, as filled in by
// InitializeNttParameters. The type is move-only.
template <typename ModularInt>
struct NttParameters {
  NttParameters() = default;
  // Disallow copy and copy-assign, allow move and move-assign.
  NttParameters(const NttParameters<ModularInt>&) = delete;
  NttParameters& operator=(const NttParameters<ModularInt>&) = delete;
  NttParameters(NttParameters<ModularInt>&&) = default;
  NttParameters& operator=(NttParameters<ModularInt>&&) = default;
  ~NttParameters() = default;

  // Number of coefficients n = 2^log_n. Explicitly zero-initialized so that a
  // default-constructed object never exposes an indeterminate value.
  int number_coeffs = 0;
  // Multiplicative inverse of n modulo the modulus; empty until it is set.
  std::optional<ModularInt> n_inv_ptr;
  // Roots of unity for the forward NTT (see NttPsisBitrev).
  std::vector<ModularInt> psis_bitrev;
  // Roots of unity for the inverse NTT (see NttPsisInvBitrev).
  std::vector<ModularInt> psis_inv_bitrev;
  // Bit-reversal index table (see BitrevArray).
  std::vector<unsigned int> bitrevs;
};

// Builds every table needed for NTTs over n = 2^log_n coefficients in a single
// call. params is only borrowed; ownership stays with the caller.
//
// Returns InvalidArgument when log_n is not positive, exceeds
// kMaxLogNumCoeffs, does not fit the underlying ModularInt::Int type, or when
// the modulus is not congruent to 1 modulo 2n. Errors from the table-building
// helpers are propagated unchanged.
template <typename ModularInt>
rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
    int log_n, const typename ModularInt::Params* params) {
  using Int = typename ModularInt::Int;

  // Range checks on log_n.
  if (log_n <= 0) {
    return absl::InvalidArgumentError("log_n must be positive");
  }
  if (static_cast<Uint64>(log_n) > kMaxLogNumCoeffs) {
    return absl::InvalidArgumentError(absl::StrCat(
        "log_n, ", log_n, ", must be less than ", kMaxLogNumCoeffs, "."));
  }
  if (!ModularInt::Params::DoesLogNFit(log_n)) {
    return absl::InvalidArgumentError(
        absl::StrCat("log_n, ", log_n,
                     ", does not fit into underlying ModularInt::Int type."));
  }

  // A primitive 2n-th root of unity requires modulus == 1 (mod 2n).
  Int double_n = params->One() << (log_n + 1);
  if (params->modulus % double_n != params->One()) {
    return absl::InvalidArgumentError(
        absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n));
  }

  NttParameters<ModularInt> result;
  result.number_coeffs = 1 << log_n;

  // n^-1, computed from n imported into modular representation.
  Int num_coeffs = params->One() << log_n;
  RLWE_ASSIGN_OR_RETURN(auto n_mod, ModularInt::ImportInt(num_coeffs, params));
  result.n_inv_ptr = n_mod.MultiplicativeInverse(params);

  // Forward and inverse root-of-unity tables, then the bit-reversal indices.
  RLWE_ASSIGN_OR_RETURN(result.psis_bitrev,
                        NttPsisBitrev<ModularInt>(log_n, params));
  RLWE_ASSIGN_OR_RETURN(result.psis_inv_bitrev,
                        NttPsisInvBitrev<ModularInt>(log_n, params));
  result.bitrevs = internal::BitrevArray(log_n);

  return result;
}

}  // namespace rlwe

#endif  // RLWE_NTT_PARAMETERS_H_