chromium/third_party/boringssl/src/crypto/mldsa/mldsa.c

/* Copyright (c) 2024, Google LLC
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <openssl/mldsa.h>

#include <assert.h>
#include <stdlib.h>

#include <openssl/bytestring.h>
#include <openssl/mem.h>
#include <openssl/rand.h>

#include "../internal.h"
#include "../keccak/internal.h"
#include "./internal.h"

// ML-DSA parameters. The exported API in this file is the MLDSA65_* family,
// so these are presumably the ML-DSA-65 values from FIPS 204 — confirm
// against the spec when filling them in.
//
// NOTE(review): every macro below has an empty replacement list in this copy
// (the values appear to have been elided). Expressions that use them, such as
// `uint8_t out[128 * K]` or `derived_seed[RHO_BYTES + 2]`, will not compile
// until the values are restored. The only value stated elsewhere in this file
// is DEGREE, which the vectork_encode comment gives as 256.

// Number of coefficients in each scalar (polynomial).
#define DEGREE
// Number of scalars in a vectork; also the row count of a matrix.
#define K
// Number of scalars in a vectorl; also the column count of a matrix.
#define L
// Presumably FIPS 204's eta: the coefficient bound of the short secret
// vectors (cf. scalar_uniform_eta_4 / scalar_encode_signed_4_eta).
#define ETA
// Presumably FIPS 204's tau: the number of nonzero coefficients produced by
// SampleInBall.
#define TAU
// Presumably FIPS 204's beta — confirm.
#define BETA
// Presumably FIPS 204's omega: the maximum number of hint bits. Sizes the
// OMEGA + K byte hint encoding (see hint_bit_pack).
#define OMEGA

// Seed and hash lengths, in bytes unless the name says bits.
// Length of the public seed rho (input to matrix_expand).
#define RHO_BYTES
// Length of the secret seed sigma (input to vector_expand_short).
#define SIGMA_BYTES
// Presumably the length of the private signing key K from FIPS 204.
#define K_BYTES
// Presumably the length of tr, the hash of the public key.
#define TR_BYTES
// Presumably the length of the message representative mu.
#define MU_BYTES
// Length of the mask seed rho' (input to vectorl_expand_mask).
#define RHO_PRIME_BYTES
// Presumably FIPS 204's collision-strength parameter lambda, in bits.
#define LAMBDA_BITS
// Presumably LAMBDA_BITS / 8.
#define LAMBDA_BYTES

// NOTE(review): every initializer in this group is missing (`=;`), which is a
// syntax error; the values appear to have been elided from this copy. Where a
// value can be read off the comments in this file it is noted below. In C a
// `static const` object is not an integer constant expression, so when these
// are restored each initializer must be spelled with literals (e.g. 8380417)
// rather than in terms of kPrime.

// 2^23 - 2^13 + 1 = 8380417, the modulus q of the Python reference model
// further down.
static const uint32_t kPrime =;
// Inverse of -kPrime modulo 2^32 (4236238847, q_neg_inverse in the Python
// reference model); used by Montgomery reduction.
static const uint32_t kPrimeNegInverse =;
// Presumably FIPS 204's d: the number of low-order bits Power2Round drops
// from each coefficient.
static const int kDroppedBits =;
// Presumably (kPrime - 1) / 2, the boundary between residues treated as
// positive and as negative — confirm.
static const uint32_t kHalfPrime =;
// Presumably FIPS 204's gamma_1, the coefficient range of the masking vector
// (the 20-bit / 2^19 packing helpers suggest 2^19 — confirm).
static const uint32_t kGamma1 =;
// Presumably FIPS 204's gamma_2, the low-order rounding range used by
// Decompose.
static const uint32_t kGamma2 =;
// 256^-1 mod kPrime, in Montgomery form.
static const uint32_t kInverseDegreeMontgomery =;

// NOTE(review): the four lines below are bare identifiers at file scope; the
// type definitions they stand for appear to have been elided from this copy.
// As written they are not valid C99 or later (a declaration needs a type
// specifier). The descriptions below are inferred from how the rest of the
// file uses these names.

// One polynomial with DEGREE coefficients (see the vectork_encode comment).
scalar;

// A vector of K scalars.
vectork;

// A vector of L scalars.
vectorl;

// A K x L matrix of scalars; matrix_mult maps a vectorl to a vectork.
matrix;

/* Arithmetic */

// This bit of Python will be referenced in some of the following comments:
//
// q = 8380417
// # Inverse of -q modulo 2^32
// q_neg_inverse = 4236238847
// # 2^64 modulo q
// montgomery_square = 2365951
//
// def bitreverse(i):
//     ret = 0
//     for n in range(8):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret
//
// def montgomery_reduce(x):
//     a = (x * q_neg_inverse) % 2**32
//     b = x + a * q
//     assert b & 0xFFFF_FFFF == 0
//     c = b >> 32
//     assert c < q
//     return c
//
// def montgomery_transform(x):
//     return montgomery_reduce(x * montgomery_square)

// Roots of unity used by the NTT, stored in bit-reversed index order and in
// Montgomery form. The table is generated by this Python, using the helpers
// defined in the reference model above:
//
// kNTTRootsMontgomery = [
//   montgomery_transform(pow(1753, bitreverse(i), q)) for i in range(256)
// ]
//
// NOTE(review): the initializer list is missing (`=;`), which is a syntax
// error; regenerate the 256 entries from the expression above.
static const uint32_t kNTTRootsMontgomery[256] =;

// NOTE(review): every function body from here to reduce_montgomery is empty
// (`{}`) in this copy; the implementations appear to have been elided. The
// non-void functions have no return statement, so any caller that uses their
// result has undefined behavior. Restore the bodies before building. Where a
// contract below is not stated by an original comment it is inferred from the
// name and signature and marked "presumably".

// Reduces x mod kPrime in constant time, where 0 <= x < 2*kPrime.
// The result is therefore in [0, kPrime).
static uint32_t reduce_once(uint32_t x) {}

// Returns the absolute value in constant time. x is presumably a
// two's-complement signed 32-bit value carried in a uint32_t — confirm.
static uint32_t abs_signed(uint32_t x) {}

// Returns the absolute value modulo kPrime: presumably x when x <= kHalfPrime
// and kPrime - x otherwise — confirm.
static uint32_t abs_mod_prime(uint32_t x) {}

// Returns the maximum of two values in constant time.
static uint32_t maximum(uint32_t x, uint32_t y) {}

// Presumably returns (a - b) mod kPrime, for a and b already reduced below
// kPrime.
static uint32_t mod_sub(uint32_t a, uint32_t b) {}

// Coefficient-wise lhs + rhs into |out| (presumably reduced mod kPrime).
static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {}

// Coefficient-wise lhs - rhs into |out| (presumably reduced mod kPrime).
static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {}

// Montgomery reduction; presumably the C counterpart of montgomery_reduce in
// the Python reference model, i.e. it returns x * 2^-32 mod kPrime.
static uint32_t reduce_montgomery(uint64_t x) {}

// NOTE(review): the three function bodies below are empty in this copy and
// appear to have been elided; restore them before building.

// Multiply two scalars in the number theoretically transformed state, writing
// the product to |out| (presumably coefficient-wise, reduced mod kPrime).
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {}

// In place number theoretic transform of a given scalar. Presumably driven by
// the kNTTRootsMontgomery table.
//
// FIPS 204, Algorithm 41 (`NTT`).
static void scalar_ntt(scalar *s) {}

// In place inverse number theoretic transform of a given scalar. Presumably
// finishes by scaling with kInverseDegreeMontgomery (256^-1).
//
// FIPS 204, Algorithm 42 (`NTT^-1`).
static void scalar_inverse_ntt(scalar *s) {}

// Entry-wise lifts of the scalar operations to vectork / vectorl, plus the
// matrix-vector product. Descriptions are inferred from the names and
// signatures.
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided; restore them before building.

// Presumably sets every coefficient of every entry of |out| to zero.
static void vectork_zero(vectork *out) {}

// out = lhs + rhs, entry by entry.
static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {}

// out = lhs - rhs, entry by entry.
static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {}

// Multiplies each entry of |lhs| by the single scalar |rhs| into |out|
// (presumably via scalar_mult, so NTT domain).
static void vectork_mult_scalar(vectork *out, const vectork *lhs,
                                const scalar *rhs) {}

// Forward NTT of every entry of |a|, in place.
static void vectork_ntt(vectork *a) {}

// Inverse NTT of every entry of |a|, in place.
static void vectork_inverse_ntt(vectork *a) {}

// out = lhs + rhs, entry by entry.
static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {}

// Multiplies each entry of |lhs| by the single scalar |rhs| into |out|.
static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
                                const scalar *rhs) {}

// Forward NTT of every entry of |a|, in place.
static void vectorl_ntt(vectorl *a) {}

// Inverse NTT of every entry of |a|, in place.
static void vectorl_inverse_ntt(vectorl *a) {}

// Matrix-vector product out = m * a, mapping a vectorl to a vectork
// (presumably with all operands in the NTT domain).
static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {}

/* Rounding & hints */

// NOTE(review): the bodies in this group are empty in this copy and appear to
// have been elided. high_bits, low_bits, make_hint and use_hint_vartime are
// non-void and have no return statement; restore them before building.

// FIPS 204, Algorithm 35 (`Power2Round`). Splits r into a high part (*r1) and
// a low part (*r0), presumably with kDroppedBits low bits going to *r0.
static void power2_round(uint32_t *r1, uint32_t *r0, uint32_t r) {}

// Scale back previously rounded value: presumably *out = r1 shifted left by
// kDroppedBits, undoing the split in power2_round apart from the low part.
static void scale_power2_round(uint32_t *out, uint32_t r1) {}

// FIPS 204, Algorithm 37 (`HighBits`).
static uint32_t high_bits(uint32_t x) {}

// FIPS 204, Algorithm 36 (`Decompose`). Writes the high part of r to *r1 and
// the signed low part to *r0.
static void decompose(uint32_t *r1, int32_t *r0, uint32_t r) {}

// FIPS 204, Algorithm 38 (`LowBits`). Returns the signed low part.
static int32_t low_bits(uint32_t x) {}

// FIPS 204, Algorithm 39 (`MakeHint`).
//
// In the spec this takes two arguments, z and r, and is called with
//   z = -ct0
//   r = w - cs2 + ct0
//
// It then computes HighBits (algorithm 37) of r and of r+z and reports whether
// they differ. But r+z is just w - cs2, so this takes three arguments and
// saves an addition.
static int32_t make_hint(uint32_t ct0, uint32_t cs2, uint32_t w) {}

// FIPS 204, Algorithm 40 (`UseHint`). The _vartime suffix presumably means it
// is not constant time and must only be applied to public values.
static uint32_t use_hint_vartime(uint32_t h, uint32_t r) {}

// Coefficient-wise lifts of the rounding and hint helpers to whole scalars.
// Descriptions are inferred from the names and signatures.
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided; restore them before building.

// power2_round on each coefficient of |s|: high parts into |s1|, low parts
// into |s0|.
static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {}

// scale_power2_round on each coefficient.
static void scalar_scale_power2_round(scalar *out, const scalar *in) {}

// high_bits on each coefficient.
static void scalar_high_bits(scalar *out, const scalar *in) {}

// low_bits on each coefficient.
static void scalar_low_bits(scalar *out, const scalar *in) {}

// Folds the coefficients of |s| into the running maximum *max (presumably of
// abs_mod_prime of each coefficient — confirm).
static void scalar_max(uint32_t *max, const scalar *s) {}

// Like scalar_max, but presumably using abs_signed for coefficients stored as
// signed values.
static void scalar_max_signed(uint32_t *max, const scalar *s) {}

// make_hint on each coefficient.
static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
                             const scalar *w) {}

// use_hint_vartime on each coefficient (not constant time).
static void scalar_use_hint_vartime(scalar *out, const scalar *h,
                                    const scalar *r) {}

// Entry-wise lifts of the scalar rounding/hint helpers to vectors, plus the
// norm and hint-count helpers. Descriptions are inferred from the names and
// signatures.
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided. vectork_max, vectork_max_signed, vectork_count_ones and
// vectorl_max are non-void and have no return statement, so using their
// result is undefined behavior until the bodies are restored.

// scalar_power2_round on each entry of |t|.
static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {}

// scalar_scale_power2_round on each entry.
static void vectork_scale_power2_round(vectork *out, const vectork *in) {}

// scalar_high_bits on each entry.
static void vectork_high_bits(vectork *out, const vectork *in) {}

// scalar_low_bits on each entry.
static void vectork_low_bits(vectork *out, const vectork *in) {}

// Largest coefficient magnitude over all entries (presumably via scalar_max).
static uint32_t vectork_max(const vectork *a) {}

// As vectork_max, for signed coefficients (presumably via scalar_max_signed).
static uint32_t vectork_max_signed(const vectork *a) {}

// The input vector contains only zeroes and ones. Presumably returns the
// number of ones, i.e. the number of set hint bits.
static size_t vectork_count_ones(const vectork *a) {}

// scalar_make_hint on each entry.
static void vectork_make_hint(vectork *out, const vectork *ct0,
                              const vectork *cs2, const vectork *w) {}

// scalar_use_hint_vartime on each entry (not constant time).
static void vectork_use_hint_vartime(vectork *out, const vectork *h,
                                     const vectork *r) {}

// Largest coefficient magnitude over all entries of a vectorl.
static uint32_t vectorl_max(const vectorl *a) {}

/* Bit packing */

// Scalar bit-packing helpers. Each fixed-width variant packs the DEGREE (256)
// coefficients at the stated bit width into the stated number of bytes
// (e.g. 256 * 4 / 8 = 128).
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided. scalar_decode_signed_4_eta and scalar_decode_signed are
// declared to return int but have no return statement.

// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 4.
static void scalar_encode_4(uint8_t out[128], const scalar *s) {}

// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 10.
static void scalar_encode_10(uint8_t out[320], const scalar *s) {}

// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 4 and b =
// eta.
static void scalar_encode_signed_4_eta(uint8_t out[128], const scalar *s) {}

// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 13 and b =
// 2^12.
static void scalar_encode_signed_13_12(uint8_t out[416], const scalar *s) {}

// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 20 and b =
// 2^19.
static void scalar_encode_signed_20_19(uint8_t out[640], const scalar *s) {}

// FIPS 204, Algorithm 17 (`BitPack`). General form: |bits| bits per
// coefficient, with |max| presumably playing the role of b.
static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
                                 uint32_t max) {}

// FIPS 204, Algorithm 18 (`SimpleBitUnpack`). Specialized for bitlen(b) == 10.
static void scalar_decode_10(scalar *out, const uint8_t in[320]) {}

// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 4 and b =
// eta. Presumably returns 0 when a decoded coefficient is out of range and 1
// otherwise — confirm.
static int scalar_decode_signed_4_eta(scalar *out, const uint8_t in[128]) {}

// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 13 and b =
// 2^12.
static void scalar_decode_signed_13_12(scalar *out, const uint8_t in[416]) {}

// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 20 and b =
// 2^19.
static void scalar_decode_signed_20_19(scalar *out, const uint8_t in[640]) {}

// FIPS 204, Algorithm 19 (`BitUnpack`). General form of the decoders above;
// presumably returns 0 on an invalid encoding and 1 otherwise — confirm.
static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
                                uint32_t max) {}

/* Expansion functions */

// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided; restore them before building.

// FIPS 204, Algorithm 30 (`RejNTTPoly`).
//
// Rejection samples a Keccak stream to get uniformly distributed elements. This
// is used for matrix expansion and only operates on public inputs.
// |derived_seed| is RHO_BYTES of seed plus two extra bytes (presumably the
// matrix row/column indices).
static void scalar_from_keccak_vartime(
    scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {}

// FIPS 204, Algorithm 31 (`RejBoundedPoly`). |derived_seed| is SIGMA_BYTES of
// secret seed plus two extra bytes (presumably an index).
static void scalar_uniform_eta_4(scalar *out,
                                 const uint8_t derived_seed[SIGMA_BYTES + 2]) {}

// FIPS 204, Algorithm 34 (`ExpandMask`), but just a single step.
// |derived_seed| is RHO_PRIME_BYTES of seed plus two extra bytes.
static void scalar_sample_mask(
    scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {}

// FIPS 204, Algorithm 29 (`SampleInBall`). Reads |len| bytes of |seed|.
// NOTE(review): a byte length would idiomatically be size_t rather than int.
static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
                                          int len) {}

// FIPS 204, Algorithm 32 (`ExpandA`). Fills the K x L matrix |out| from the
// public seed |rho|.
static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {}

// FIPS 204, Algorithm 33 (`ExpandS`). Derives the short secret vectors |s1|
// (length L) and |s2| (length K) from |sigma|.
static void vector_expand_short(vectorl *s1, vectork *s2,
                                const uint8_t sigma[SIGMA_BYTES]) {}

// FIPS 204, Algorithm 34 (`ExpandMask`). Derives the mask vector |out| from
// |seed| and the counter |kappa|.
static void vectorl_expand_mask(vectorl *out,
                                const uint8_t seed[RHO_PRIME_BYTES],
                                size_t kappa) {}

/* Encoding */

// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided. vectork_decode_signed, vectorl_decode_signed and
// hint_bit_unpack are declared to return int but have no return statement.

// FIPS 204, Algorithm 16 (`SimpleBitPack`).
//
// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
// is divisible by 8, each vector entry fills a whole number of bytes, so
// entries never share a byte and each can be packed independently.
static void vectork_encode(uint8_t *out, const vectork *a, int bits) {}

// FIPS 204, Algorithm 18 (`SimpleBitUnpack`). Decodes K scalars of 10-bit
// coefficients (320 bytes each, as in scalar_decode_10).
static void vectork_decode_10(vectork *out, const uint8_t *in) {}

// FIPS 204, Algorithm 17 (`BitPack`) applied to each of the K entries;
// |bits| and |max| presumably as for scalar_encode_signed.
static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
                                  uint32_t max) {}

// Inverse of vectork_encode_signed. Presumably returns 0 if any entry fails
// scalar_decode_signed and 1 otherwise — confirm.
static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
                                 uint32_t max) {}

// FIPS 204, Algorithm 17 (`BitPack`).
//
// Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
// is divisible by 8, each vector entry fills a whole number of bytes, so
// entries never share a byte and each can be packed independently.
static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
                                  uint32_t max) {}

// Inverse of vectorl_encode_signed. Presumably returns 0 if any entry fails
// scalar_decode_signed and 1 otherwise — confirm.
static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
                                 uint32_t max) {}

// FIPS 204, Algorithm 28 (`w1Encode`). Writes 128 bytes per entry, i.e. 4 bits
// per coefficient.
static void w1_encode(uint8_t out[128 * K], const vectork *w1) {}

// FIPS 204, Algorithm 20 (`HintBitPack`). Writes OMEGA + K bytes.
static void hint_bit_pack(uint8_t out[OMEGA + K], const vectork *h) {}

// FIPS 204, Algorithm 21 (`HintBitUnpack`). Reads OMEGA + K bytes; presumably
// returns 0 for a malformed hint encoding and 1 otherwise — confirm.
static int hint_bit_unpack(vectork *h, const uint8_t in[OMEGA + K]) {}

// NOTE(review): the member lists of these three structs appear to have been
// elided from this copy. An empty member list is not valid ISO C (the grammar
// requires at least one struct-declaration); GNU C accepts it only as an
// extension. Restore the real definitions before building.

// Internal (decoded) form of an ML-DSA public key.
struct public_key {};

// Internal (decoded) form of an ML-DSA private key.
struct private_key {};

// Internal (decoded) form of an ML-DSA signature.
struct signature {};

// Conversion between the internal structs and their FIPS 204 byte encodings.
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided. All of these functions are non-void and have no return
// statement; restore them before building.

// FIPS 204, Algorithm 22 (`pkEncode`). Appends the encoding of |pub| to |out|;
// presumably returns 1 on success and 0 if the CBB write fails.
static int mldsa_marshal_public_key(CBB *out, const struct public_key *pub) {}

// FIPS 204, Algorithm 23 (`pkDecode`). Parses |in| into |pub|; presumably
// returns 1 on success and 0 on short or malformed input.
static int mldsa_parse_public_key(struct public_key *pub, CBS *in) {}

// FIPS 204, Algorithm 24 (`skEncode`). Appends the encoding of |priv| to
// |out|; presumably returns 1 on success and 0 on failure.
static int mldsa_marshal_private_key(CBB *out, const struct private_key *priv) {}

// FIPS 204, Algorithm 25 (`skDecode`). Parses |in| into |priv|; presumably
// returns 1 on success and 0 on short or malformed input.
static int mldsa_parse_private_key(struct private_key *priv, CBS *in) {}

// FIPS 204, Algorithm 26 (`sigEncode`). Appends the encoding of |sign| to
// |out|; presumably returns 1 on success and 0 on failure.
static int mldsa_marshal_signature(CBB *out, const struct signature *sign) {}

// FIPS 204, Algorithm 27 (`sigDecode`). Parses |in| into |sign|; presumably
// returns 1 on success and 0 on malformed input (e.g. a bad hint encoding).
static int mldsa_parse_signature(struct signature *sign, CBS *in) {}

// Maps the public opaque private-key type to the internal struct (presumably
// by reinterpreting the same storage — confirm).
// NOTE(review): takes a const pointer but returns a non-const one; callers
// must not write through the result.
static struct private_key *private_key_from_external(
    const struct MLDSA65_private_key *external) {}

// Maps the public opaque public-key type to the internal struct (presumably
// by reinterpreting the same storage — confirm).
// NOTE(review): takes a const pointer but returns a non-const one; callers
// must not write through the result.
static struct public_key *public_key_from_external(
    const struct MLDSA65_public_key *external) {}

/* API */

// NOTE(review): the bodies of these exported functions are empty in this copy
// and appear to have been elided. Each is declared to return int but has no
// return statement, so a caller checking the result reads an indeterminate
// value (undefined behavior). For a key-generation API that is a security
// hazard; restore the bodies before building.

// Calls |MLDSA65_generate_key_external_entropy| with random bytes from
// |RAND_bytes|. |out_seed| presumably receives the seed that was used, so
// the key can later be re-derived with |MLDSA65_private_key_from_seed|.
// Returns 1 on success and 0 on failure.
int MLDSA65_generate_key(
    uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
    uint8_t out_seed[MLDSA_SEED_BYTES],
    struct MLDSA65_private_key *out_private_key) {}

// Derives |out_private_key| from the |seed_len|-byte |seed| (presumably
// |seed_len| must equal MLDSA_SEED_BYTES). Presumably returns 1 on success
// and 0 on failure, like the other functions here.
int MLDSA65_private_key_from_seed(struct MLDSA65_private_key *out_private_key,
                                  const uint8_t *seed, size_t seed_len) {}

// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
// on failure.
int MLDSA65_generate_key_external_entropy(
    uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
    struct MLDSA65_private_key *out_private_key,
    const uint8_t entropy[MLDSA_SEED_BYTES]) {}

// Writes the public key that corresponds to |private_key| into
// |out_public_key|. Presumably returns 1 on success and 0 on failure.
int MLDSA65_public_from_private(struct MLDSA65_public_key *out_public_key,
                                const struct MLDSA65_private_key *private_key) {}

// NOTE(review): the bodies of these exported functions are empty in this copy
// and appear to have been elided. Each is declared to return int but has no
// return statement. For MLDSA65_verify in particular, an indeterminate return
// value could be read as "valid"; restore the bodies before building.

// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0 on
// failure. |context_prefix| and |context| are presumably prepended to |msg|
// when forming the message representative — confirm.
int MLDSA65_sign_internal(
    uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
    const struct MLDSA65_private_key *private_key, const uint8_t *msg,
    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
    const uint8_t *context, size_t context_len,
    const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {}

// ML-DSA signature in randomized mode, filling the random bytes with
// |RAND_bytes|. Returns 1 on success and 0 on failure.
int MLDSA65_sign(uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
                 const struct MLDSA65_private_key *private_key,
                 const uint8_t *msg, size_t msg_len, const uint8_t *context,
                 size_t context_len) {}

// FIPS 204, Algorithm 3 (`ML-DSA.Verify`). Presumably returns 1 if |signature|
// (of |signature_len| bytes) is a valid signature of |msg| under |context| and
// 0 otherwise — confirm.
int MLDSA65_verify(const struct MLDSA65_public_key *public_key,
                   const uint8_t *signature, size_t signature_len,
                   const uint8_t *msg, size_t msg_len, const uint8_t *context,
                   size_t context_len) {}

// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`). Presumably returns 1 for a
// valid signature and 0 otherwise — confirm.
int MLDSA65_verify_internal(
    const struct MLDSA65_public_key *public_key,
    const uint8_t encoded_signature[MLDSA65_SIGNATURE_BYTES],
    const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
    size_t context_prefix_len, const uint8_t *context, size_t context_len) {}

/* Serialization of keys. */

// Public (de)serialization entry points for ML-DSA-65 keys.
//
// NOTE(review): every body in this group is empty in this copy and appears to
// have been elided. Each is declared to return int but has no return
// statement; restore the bodies before building.

// Appends the encoded |public_key| to |out|. Presumably returns 1 on success
// and 0 on failure.
int MLDSA65_marshal_public_key(CBB *out,
                               const struct MLDSA65_public_key *public_key) {}

// Parses an encoded public key from |in| into |public_key|. Presumably returns
// 1 on success and 0 on malformed input; whether trailing bytes in |in| are
// rejected cannot be determined from this copy.
int MLDSA65_parse_public_key(struct MLDSA65_public_key *public_key, CBS *in) {}

// Appends the encoded |private_key| to |out|. Presumably returns 1 on success
// and 0 on failure.
int MLDSA65_marshal_private_key(CBB *out,
                                const struct MLDSA65_private_key *private_key) {}

// Parses an encoded private key from |in| into |private_key|. Presumably
// returns 1 on success and 0 on malformed input.
int MLDSA65_parse_private_key(struct MLDSA65_private_key *private_key,
                              CBS *in) {}