// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! `local_service` exposes the `processor` crate over a local, unencrypted
//! WebSocket connection. This allows the enclave service to be simulated on the
//! local machine. It keeps state in two files in the working directory:
//! "state.transparent" and "state.confidential".
extern crate alloc;
extern crate base64;
extern crate cbor;
extern crate crypto;
extern crate handshake;
extern crate hex;
extern crate processor;
use cbor::{cbor, Value};
use crypto::P256Scalar;
use processor::{ClientState, StateUpdate};
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
// Frame opcodes from https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
/// Opcode of the first frame of a binary message.
const BINARY: u8 = 2;
/// Opcode of every non-initial frame of a fragmented message (0x0 per RFC 6455;
/// it must differ from `BINARY` or fragmented messages are always rejected).
const CONTINUATION: u8 = 0;
/// WebSocket subprotocol that clients must offer and that the server echoes.
const WEBSOCKET_PROTOCOL: &str = "cloudauthenticator";
/// Completely fills `buf` with data from `conn` and returns true iff successful.
///
/// Returns false if the peer closes the connection before `buf` is full, or on
/// any I/O error. `ErrorKind::Interrupted` is retried rather than treated as a
/// failure (this is what `Read::read_exact` does).
fn read_all(mut conn: &TcpStream, buf: &mut [u8]) -> bool {
    Read::read_exact(&mut conn, buf).is_ok()
}
/// Write the full contents of `buf` to `conn`. Returns true iff successful.
///
/// Uses `Write::write_all`, which retries `ErrorKind::Interrupted` and fails
/// (instead of looping forever) if the socket accepts zero bytes.
fn write_all(mut conn: &TcpStream, buf: &[u8]) -> bool {
    Write::write_all(&mut conn, buf).is_ok()
}
/// Failure modes of `next_line`. A likely TLS handshake is reported as its own
/// variant so the caller can print a targeted diagnostic.
enum NextLineError {
    /// A line began with byte 0x16, so the client is probably starting a TLS
    /// handshake rather than sending a plaintext HTTP request.
    TlsHandshake,
    /// Anything else: an I/O error, EOF, or invalid UTF-8.
    OtherError,
}

/// Lets `?` convert any standard error into `NextLineError::OtherError`,
/// discarding the underlying details.
impl<E: std::error::Error> From<E> for NextLineError {
    fn from(_err: E) -> Self {
        NextLineError::OtherError
    }
}
/// Reads a "\r\n"-terminated line from `conn` and returns it without that
/// terminator. (Inefficient, one byte per read, but we don't mind in this
/// context.)
///
/// # Errors
///
/// * `NextLineError::TlsHandshake` if a line starts with byte 0x16, which is
///   how a TLS handshake record begins.
/// * `NextLineError::OtherError` on EOF, any I/O error, invalid UTF-8, or a
///   line longer than `MAX_LINE_LEN` bytes.
fn next_line(mut conn: &TcpStream) -> Result<String, NextLineError> {
    /// Upper bound on a line's length, so that a client that never sends
    /// "\r\n" cannot make this function buffer without limit.
    const MAX_LINE_LEN: usize = 8 * 1024;

    let mut ret = Vec::with_capacity(32);
    let mut seen_cr = false;
    loop {
        let mut buf = [0u8; 1];
        // `read_exact` retries `ErrorKind::Interrupted` and reports EOF as
        // `UnexpectedEof`, which the `From` impl maps to `OtherError`.
        conn.read_exact(&mut buf)?;
        let byte = buf[0];
        if ret.is_empty() && byte == 0x16 {
            return Err(NextLineError::TlsHandshake);
        }
        if seen_cr && byte == b'\n' {
            // Drop the '\r' that was pushed on the previous iteration.
            ret.pop();
            return Ok(String::from_utf8(ret)?);
        }
        // `ret` may hold the line plus a trailing '\r', hence `>` not `>=`.
        if ret.len() > MAX_LINE_LEN {
            return Err(NextLineError::OtherError);
        }
        seen_cr = byte == b'\r';
        ret.push(byte);
    }
}
/// Reads one WebSocket frame from `conn` and unmasks its payload.
///
/// Returns `(fin, opcode, payload)`, where `fin` marks the last frame of a
/// message. Returns `None` on a read failure, on an unmasked frame, or when the
/// frame uses the 64-bit length encoding. See
/// https://datatracker.ietf.org/doc/html/rfc6455#section-5.
fn read_frame(conn: &TcpStream) -> Option<(bool, u8, Vec<u8>)> {
    // Client frames are always masked, so at least six bytes follow: two fixed
    // header bytes plus either the 4-byte mask (short length) or the 16-bit
    // extended length and half of the mask.
    let mut head = [0u8; 6];
    if !read_all(conn, &mut head) {
        return None;
    }
    // Bit layout: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
    let fin = (head[0] & 0x80) != 0;
    let opcode = head[0] & 0x0f;
    if (head[1] & 0x80) == 0 {
        eprintln!("frame from client should be masked");
        return None;
    }
    let (len, mask) = match head[1] & 0x7f {
        127 => {
            // Lengths must be minimally encoded, so this implies a frame
            // larger than 64KiB, which this service never needs.
            eprintln!("unsupported 64-bit length");
            return None;
        }
        126 => {
            let mut tail = [0u8; 2];
            if !read_all(conn, &mut tail) {
                return None;
            }
            let len = usize::from(u16::from_be_bytes([head[2], head[3]]));
            (len, [head[4], head[5], tail[0], tail[1]])
        }
        short => (usize::from(short), [head[2], head[3], head[4], head[5]]),
    };
    let mut payload = vec![0u8; len];
    if !read_all(conn, &mut payload) {
        return None;
    }
    payload
        .iter_mut()
        .zip(mask.iter().cycle())
        .for_each(|(byte, key)| *byte ^= *key);
    Some((fin, opcode, payload))
}
/// Reads one complete WebSocket message from `conn`, concatenating fragments.
///
/// The first frame must carry the `BINARY` opcode and every later frame the
/// `CONTINUATION` opcode; any other opcode is logged and yields `None`. See
/// https://datatracker.ietf.org/doc/html/rfc6455#section-6.
fn read_msg(conn: &TcpStream) -> Option<Vec<u8>> {
    let mut message: Vec<u8> = Vec::new();
    let mut wanted = BINARY;
    loop {
        let (fin, opcode, payload) = read_frame(conn)?;
        if opcode != wanted {
            eprintln!("unexpected message type {}", opcode);
            return None;
        }
        message.extend_from_slice(&payload);
        if fin {
            return Some(message);
        }
        wanted = CONTINUATION;
    }
}
/// Sends `msg` to `conn` as one unmasked, final BINARY frame. Returns true iff
/// both the header and the payload were written. Payloads of 64KiB or more are
/// rejected (returns false) because this service never needs them. See
/// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
fn write_msg(conn: &TcpStream, msg: &[u8]) -> bool {
    // FIN bit set, binary opcode.
    let first = 0x80 | BINARY;
    let header_sent = match msg.len() {
        len if len < 126 => write_all(conn, &[first, len as u8]),
        len if len < 0x10000 => {
            let [hi, lo] = (len as u16).to_be_bytes();
            write_all(conn, &[first, 126, hi, lo])
        }
        _ => return false,
    };
    header_sent && write_all(conn, msg)
}
/// Computes the `Sec-WebSocket-Accept` header value for the client-supplied
/// `Sec-WebSocket-Key`. The client verifies it to confirm that the server really
/// intends to speak WebSocket. Per RFC 6455 the value is
/// base64(SHA-1(key followed by a fixed GUID)); `crypto::sha1_two_part` is
/// assumed to hash its two arguments as one concatenated input. See
/// https://datatracker.ietf.org/doc/html/rfc6455#section-1.3
fn calculate_websocket_accept(key: &[u8]) -> String {
    // Fixed GUID defined by RFC 6455.
    const MAGIC_GUID: &[u8; 36] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    base64::encode(crypto::sha1_two_part(key, MAGIC_GUID))
}
struct EnclaveServer {
identity_private_key_bytes: [u8; 32],
}
impl EnclaveServer {
fn handle_connection(&mut self, conn: TcpStream) {
eprintln!("Accepted connection from {:?}", conn.peer_addr().unwrap());
let mut seen_first_line = false;
let mut websocket_key: Option<String> = None;
let mut websocket_protocol: Option<String> = None;
let mut has_reauthentication_header = false;
loop {
let line = match next_line(&conn) {
Ok(line) => line,
Err(NextLineError::OtherError) => return,
Err(NextLineError::TlsHandshake) => panic!(
"TLS handshake received. This server only speaks plaintext. Ensure that you have specified the address with ws://, not wss://"
),
};
if line.is_empty() {
break;
}
if !seen_first_line {
seen_first_line = true;
continue;
}
let Some((key, value)) = line.split_once(':') else {
eprintln!("bad header line");
return;
};
let key = key.trim();
match key.to_lowercase().as_str() {
"sec-websocket-key" => websocket_key = Some(String::from(value.trim())),
"sec-websocket-protocol" => websocket_protocol = Some(String::from(value.trim())),
"reauthentication" => has_reauthentication_header = true,
_ => (),
}
}
let Some(websocket_key) = websocket_key else {
eprintln!("bad WebSocket request");
return;
};
match websocket_protocol {
Some(protocol) if protocol.as_str() == WEBSOCKET_PROTOCOL => (),
_ => {
eprintln!("missing expected WebSocket protocol");
return;
}
}
const RESPONSE : &[u8] = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
let accept_value = calculate_websocket_accept(websocket_key.as_bytes());
const PROTOCOL_HEADER: &[u8] = b"Sec-WebSocket-Protocol: ";
const NEWLINE: &[u8] = b"\r\n";
if !write_all(&conn, RESPONSE)
|| !write_all(&conn, accept_value.as_bytes())
|| !write_all(&conn, NEWLINE)
|| !write_all(&conn, PROTOCOL_HEADER)
|| !write_all(&conn, WEBSOCKET_PROTOCOL.as_bytes())
|| !write_all(&conn, NEWLINE)
|| !write_all(&conn, NEWLINE)
{
return;
}
let Some(handshake_request) = read_msg(&conn) else {
return;
};
let mut handshake_response =
match handshake::respond(&self.identity_private_key_bytes, &handshake_request) {
Ok(r) => r,
Err(e) => {
eprintln!("Failed to generate handshake response: {:?}", e);
return;
}
};
if !write_msg(&conn, &handshake_response.response) {
return;
}
let Some(cmd_request) = read_msg(&conn) else {
return;
};
let Ok(commands) = handshake_response.crypter.decrypt(&cmd_request) else {
eprintln!("Failed to decrypt commands");
return;
};
const CONFIDENTIAL_PATH: &str = "state.confidential";
const TRANSPARENT_PATH: &str = "state.transparent";
let client_state = std::fs::read(CONFIDENTIAL_PATH)
.and_then(|confidential| {
std::fs::read(TRANSPARENT_PATH).map(|transparent| {
ClientState::Explicit(processor::StateData { transparent, confidential })
})
})
.unwrap_or(ClientState::Initial);
let cbor_response = match processor::process_client_msg(
client_state,
processor::ExternalContext {
// This timestamp is fixed so that any XML files submitted by tests will be
// considered unexpired.
current_time_epoch_millis: 1707344402000,
client_device_identifier: Vec::new(),
is_reauthenticated: has_reauthentication_header,
},
&handshake_response.handshake_hash,
commands,
) {
Ok((result_array, state_update)) => {
let state_data = match state_update {
StateUpdate::Major(data) => Some(data),
StateUpdate::Minor(data) => Some(data),
StateUpdate::None => None,
};
if let Some(state_data) = state_data {
std::fs::write(CONFIDENTIAL_PATH, &state_data.confidential).unwrap();
std::fs::write(TRANSPARENT_PATH, &state_data.transparent).unwrap();
}
cbor!({"ok": result_array})
}
Err(err) => {
eprintln!("{:?}", err);
let err = match err {
processor::Error::UnknownClient => Value::Int(0),
processor::Error::Str(s) => Value::String(String::from(s)),
_ => Value::String(format!("{:?}", err)),
};
cbor!({"err": err})
}
};
let cmd_response = handshake_response.crypter.encrypt(&cbor_response.to_bytes()).unwrap();
write_msg(&conn, &cmd_response);
}
}
/// Starts the local service on an ephemeral 127.0.0.1 port. The port number is
/// printed to stdout (for whatever launched the process) and diagnostics go to
/// stderr. Connections are served one at a time; a failure on one connection
/// is logged and does not stop the server.
fn main() {
    // The identity private scalar is 1, so its public key is the P-256 base
    // point. The corresponding hex-encoded public key, which has to be provided
    // manually to the test client, is:
    // 046b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c2964fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5
    let mut identity_private_key_bytes = [0u8; 32];
    identity_private_key_bytes[31] = 1;
    let mut service = EnclaveServer { identity_private_key_bytes };
    let scalar: P256Scalar = (&identity_private_key_bytes).try_into().unwrap();
    eprintln!("Public key is {}", hex::encode(scalar.compute_public_key()));
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let local_addr = listener.local_addr().unwrap();
    println!("{}", local_addr.port());
    eprintln!("Listening on ws://{}", local_addr);
    for stream in listener.incoming() {
        // A failed accept (for example, a client that reset the connection
        // first) only affects that client, so keep serving.
        let stream = match stream {
            Ok(stream) => stream,
            Err(err) => {
                eprintln!("Failed to accept connection: {:?}", err);
                continue;
            }
        };
        // TCP_NODELAY only affects latency, so a failure to set it is not fatal.
        if let Err(err) = stream.set_nodelay(true) {
            eprintln!("Failed to set TCP_NODELAY: {:?}", err);
        }
        service.handle_connection(stream);
    }
}