chromium/third_party/cloud_authenticator/handshake/src/lib.rs

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![no_std]

#[allow(unused_imports)] // Macros only used in tests.
#[macro_use]
extern crate alloc;
extern crate crypto;
extern crate static_assertions;

mod error;
mod noise;

use crate::error::Error;
use crate::noise::{HandshakeType, Noise};
use alloc::vec::Vec;
use crypto::{P256Scalar, NONCE_LEN, P256_X962_LENGTH};

/// Largest message counter (2^24) that `Crypter` will turn into a nonce, in
/// either direction. This is assumed to be vastly larger than any connection
/// will ever reach.
const MAX_SEQUENCE: u32 = 0x0100_0000;

/// Traffic-encryption state produced by a completed Noise handshake.
///
/// Holds one 32-byte key and one message counter per direction; the counter
/// is encoded into the AES-GCM nonce for each message. Because it holds
/// secret key material, its `Debug` output deliberately omits the keys.
pub struct Crypter {
    read_key: [u8; 32],
    write_key: [u8; 32],
    read_nonce: u32,
    write_nonce: u32,
}

impl core::fmt::Debug for Crypter {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Never print `read_key`/`write_key`: `Debug` output routinely ends up
        // in logs and panic messages, and these are live traffic secrets.
        f.debug_struct("Crypter")
            .field("read_key", &"<redacted>")
            .field("write_key", &"<redacted>")
            .field("read_nonce", &self.read_nonce)
            .field("write_nonce", &self.write_nonce)
            .finish()
    }
}

/// Utility for encrypting and decrypting traffic between the Noise endpoints.
/// It is created by |respond| and configured with a key for each traffic
/// direction.
impl Crypter {
    fn new(read_key: &[u8; 32], write_key: &[u8; 32]) -> Self {
        Self { read_key: *read_key, write_key: *write_key, read_nonce: 0, write_nonce: 0 }
    }

    fn next_nonce(nonce: &mut u32) -> Result<[u8; NONCE_LEN], Error> {
        if *nonce > MAX_SEQUENCE {
            return Err(Error::DecryptFailed);
        }
        let mut ret = [0u8; NONCE_LEN];
        ret[NONCE_LEN - 4..].copy_from_slice(nonce.to_be_bytes().as_slice());
        *nonce += 1;
        Ok(ret)
    }

    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, Error> {
        const PADDING_GRANULARITY: usize = 32;
        static_assertions::const_assert!(PADDING_GRANULARITY < 256);
        static_assertions::const_assert!((PADDING_GRANULARITY & (PADDING_GRANULARITY - 1)) == 0);

        let mut padded_size: usize = plaintext.len();
        if padded_size > (1usize << 28) {
            return Err(Error::DataTooLarge(padded_size));
        }
        padded_size += 1; // padding-length byte
        padded_size = (padded_size + PADDING_GRANULARITY - 1) & !(PADDING_GRANULARITY - 1);

        let mut padded_encrypt_data = Vec::with_capacity(padded_size);
        padded_encrypt_data.extend_from_slice(plaintext);
        padded_encrypt_data.resize(padded_size, 0u8);
        let num_zeros = padded_size - plaintext.len() - 1;
        padded_encrypt_data[padded_size - 1] = num_zeros as u8;

        crypto::aes_256_gcm_seal_in_place(
            &self.write_key,
            &Self::next_nonce(&mut self.write_nonce)?,
            &[],
            &mut padded_encrypt_data,
        );
        Ok(padded_encrypt_data)
    }

    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, Error> {
        let plaintext = crypto::aes_256_gcm_open_in_place(
            &self.read_key,
            &Self::next_nonce(&mut self.read_nonce)?,
            &[],
            Vec::from(ciphertext),
        )
        .map_err(|_| Error::DecryptFailed)?;

        // Plaintext must have a padding byte, and the unpadded length must be
        // at least one.
        if plaintext.is_empty() || (plaintext[plaintext.len() - 1] as usize) >= plaintext.len() {
            return Err(Error::DecryptionPaddingError);
        }
        let unpadded_length = plaintext.len() - (plaintext[plaintext.len() - 1] as usize);
        Ok(Vec::from(&plaintext[0..unpadded_length - 1]))
    }
}

/// Result of a successful Responder handshake produced by [`respond`].
pub struct Response {
    /// Traffic state for the Responder: reads what the Initiator writes and
    /// vice versa.
    pub crypter: Crypter,
    /// Final Noise handshake hash; both parties end up with the same value.
    pub handshake_hash: [u8; 32],
    /// Handshake reply to send to the Initiator: the Responder's ephemeral
    /// public key followed by the encrypted (empty) handshake payload.
    pub response: Vec<u8>,
}

/// Performs the Responder side of the Noise protocol with the NK pattern.
/// |identity_private_key_bytes| contains the private key scalar for the
/// service's provisioned identity. |in_data| is provided by the Initiator and
/// contains its ephemeral public key and encrypted payload.
///
/// The identity public key is computed from the private key, but could
/// alternatively be stored separately to reduce computation if needed to
/// reduce per-transaction computation.
/// See https://noiseexplorer.com/patterns/NK/
pub fn respond(identity_private_key_bytes: &[u8], in_data: &[u8]) -> Result<Response, Error> {
    if in_data.len() < P256_X962_LENGTH {
        return Err(Error::InvalidHandshake);
    }

    let mut noise = Noise::new(HandshakeType::Nk);
    noise.mix_hash(&[0; 1]); // Prologue

    let identity_scalar: P256Scalar =
        identity_private_key_bytes.try_into().map_err(|_| Error::InvalidPrivateKey)?;
    let identity_pub = identity_scalar.compute_public_key();

    noise.mix_hash_point(identity_pub.as_slice());

    // unwrap: we know that `in_data` is `P256_X962_LENGTH` bytes long.
    let peer_pub: [u8; P256_X962_LENGTH] = (&in_data[..P256_X962_LENGTH]).try_into().unwrap();
    noise.mix_hash(peer_pub.as_slice());
    noise.mix_key(peer_pub.as_slice());

    let es_ecdh_bytes = crypto::p256_scalar_mult(&identity_scalar, &peer_pub)
        .map_err(|_| Error::InvalidHandshake)?;
    noise.mix_key(es_ecdh_bytes.as_slice());

    let plaintext = noise.decrypt_and_hash(&in_data[P256_X962_LENGTH..])?;
    if !plaintext.is_empty() {
        return Err(Error::InvalidHandshake);
    }

    // Generate ephemeral key pair
    let ephemeral_priv = P256Scalar::generate();
    let ephemeral_pub_key_bytes = ephemeral_priv.compute_public_key();
    noise.mix_hash(ephemeral_pub_key_bytes.as_slice());
    noise.mix_key(ephemeral_pub_key_bytes.as_slice());
    let ee_ecdh_bytes = crypto::p256_scalar_mult(&ephemeral_priv, &peer_pub)
        .map_err(|_| Error::InvalidHandshake)?;
    noise.mix_key(ee_ecdh_bytes.as_slice());

    let response_ciphertext = noise.encrypt_and_hash(&[]);

    let keys = noise.traffic_keys();
    Ok(Response {
        crypter: Crypter::new(&keys.0, &keys.1),
        handshake_hash: noise.handshake_hash(),
        response: [ephemeral_pub_key_bytes.as_slice(), &response_ciphertext].concat(),
    })
}

/// Minimal Initiator side of the Noise NK handshake answered by [`respond`].
///
/// This module is public (not `#[cfg(test)]`), so it can drive a handshake end
/// to end outside this crate's own tests. All failures are handled with
/// `unwrap`/`assert`, so malformed input panics.
pub mod test_client {
    use super::*;

    /// Initiator state for a single handshake attempt.
    pub struct HandshakeInitiator {
        noise: Noise,
        identity_pub_key: [u8; P256_X962_LENGTH],
        ephemeral_priv_key: P256Scalar,
    }

    impl HandshakeInitiator {
        /// Creates an Initiator for the Responder whose identity public key is
        /// `peer_public_key`, using a freshly generated ephemeral key.
        pub fn new(peer_public_key: &[u8; P256_X962_LENGTH]) -> Self {
            let noise = Noise::new(HandshakeType::Nk);
            let ephemeral_priv_key = P256Scalar::generate();
            Self { noise, identity_pub_key: *peer_public_key, ephemeral_priv_key }
        }

        /// Builds the first handshake message: the Initiator's ephemeral public
        /// key followed by an encrypted empty payload.
        pub fn build_initial_message(&mut self) -> Vec<u8> {
            // Prologue and the Responder's static key, mirroring `respond`.
            self.noise.mix_hash(&[0; 1]);
            self.noise.mix_hash_point(self.identity_pub_key.as_slice());

            let e_pub = self.ephemeral_priv_key.compute_public_key();
            self.noise.mix_hash(e_pub.as_slice());
            self.noise.mix_key(e_pub.as_slice());

            let es = crypto::p256_scalar_mult(&self.ephemeral_priv_key, &self.identity_pub_key)
                .unwrap();
            self.noise.mix_key(&es);

            let payload = self.noise.encrypt_and_hash(&[]);
            [e_pub.as_slice(), &payload].concat()
        }

        /// Consumes the Responder's reply (its ephemeral public key followed by
        /// an encrypted empty payload) and returns the handshake hash together
        /// with a [`Crypter`] oriented for the Initiator.
        ///
        /// # Panics
        /// Panics if `handshake_response` is shorter than `P256_X962_LENGTH`,
        /// the ECDH with the Responder's key fails, the payload fails to
        /// decrypt, or the payload is not empty.
        pub fn process_response(&mut self, handshake_response: &[u8]) -> ([u8; 32], Crypter) {
            let responder_e_pub = &handshake_response[..P256_X962_LENGTH];
            let payload_ciphertext = &handshake_response[P256_X962_LENGTH..];

            let ee = crypto::p256_scalar_mult(
                &self.ephemeral_priv_key,
                responder_e_pub.try_into().unwrap(),
            )
            .unwrap();
            self.noise.mix_hash(responder_e_pub);
            self.noise.mix_key(responder_e_pub);
            self.noise.mix_key(&ee);

            let payload = self.noise.decrypt_and_hash(payload_ciphertext).unwrap();
            assert_eq!(payload.len(), 0);

            // The Initiator writes with the key the Responder reads with, and
            // reads with the key the Responder writes with.
            let (write_key, read_key) = self.noise.traffic_keys();
            (self.noise.handshake_hash(), Crypter::new(&read_key, &write_key))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::test_client::HandshakeInitiator;
    use super::*;

    /// Runs a full handshake against a freshly generated identity key, checks
    /// both sides agree on the handshake hash, and returns
    /// `(client_crypter, enclave_crypter)`.
    fn handshake() -> (Crypter, Crypter) {
        let identity_priv = P256Scalar::generate();
        let identity_pub_bytes = identity_priv.compute_public_key();
        let mut initiator = HandshakeInitiator::new(&identity_pub_bytes);
        let message = initiator.build_initial_message();
        let handshake_response = respond(identity_priv.bytes().as_slice(), &message).unwrap();

        let (client_hash, client_crypter) =
            initiator.process_response(&handshake_response.response);
        assert_eq!(&client_hash, &handshake_response.handshake_hash);
        (client_crypter, handshake_response.crypter)
    }

    #[test]
    fn process_handshake() {
        let test_messages = vec![vec![1u8, 2u8, 3u8, 4u8], vec![4u8, 3u8, 2u8, 1u8], vec![]];
        let (mut client_crypter, mut enclave_crypter) = handshake();

        // Client -> Enclave encrypt+decrypt
        for message in &test_messages {
            let ciphertext = client_crypter.encrypt(message).unwrap();
            let plaintext = enclave_crypter.decrypt(&ciphertext).unwrap();
            assert_eq!(message, &plaintext);
        }

        // Enclave -> Client encrypt+decrypt
        for message in &test_messages {
            let ciphertext = enclave_crypter.encrypt(message).unwrap();
            let plaintext = client_crypter.decrypt(&ciphertext).unwrap();
            assert_eq!(message, &plaintext);
        }
    }

    #[test]
    fn padding_boundaries_round_trip() {
        let (mut client_crypter, mut enclave_crypter) = handshake();
        // Lengths around the 32-byte padding granularity, accounting for the
        // trailing padding-length byte.
        for &len in &[0usize, 1, 30, 31, 32, 33, 63, 64, 65] {
            let message: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let ciphertext = client_crypter.encrypt(&message).unwrap();
            let plaintext = enclave_crypter.decrypt(&ciphertext).unwrap();
            assert_eq!(message, plaintext);
        }
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let (mut client_crypter, mut enclave_crypter) = handshake();
        let mut ciphertext = client_crypter.encrypt(b"hello").unwrap();
        ciphertext[0] ^= 1;
        assert!(enclave_crypter.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn replayed_ciphertext_is_rejected() {
        let (mut client_crypter, mut enclave_crypter) = handshake();
        let ciphertext = client_crypter.encrypt(b"hello").unwrap();
        assert_eq!(enclave_crypter.decrypt(&ciphertext).unwrap(), b"hello".to_vec());
        // The read nonce has moved on, so the same ciphertext must not open again.
        assert!(enclave_crypter.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn respond_rejects_short_input() {
        let identity_priv = P256Scalar::generate();
        assert!(respond(identity_priv.bytes().as_slice(), &[0u8; 4]).is_err());
    }
}