Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions src/dtls12/message/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ use super::HelloVerifyRequest;
use super::ServerHello;
use super::ServerKeyExchange;
use crate::buffer::Buf;
use arrayvec::ArrayVec;
use nom::Err;
use nom::IResult;
use nom::bytes::complete::take;
use nom::error::{Error, ErrorKind};
use nom::number::complete::be_u8;
use nom::number::complete::{be_u16, be_u24};

const MAX_DEFRAGMENT_HANDSHAKES: usize = 50;

#[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
pub struct Header {
pub msg_type: MessageType,
Expand Down Expand Up @@ -154,8 +157,11 @@ impl Handshake {
let Body::Fragment(range) = &first_handshake.body else {
unreachable!("Non-Fragment body in defragment()")
};
let mut handled = ArrayVec::<&Handshake, MAX_DEFRAGMENT_HANDSHAKES>::new();
handled
.try_push(first_handshake)
.map_err(|_| crate::InternalError::too_many_records())?;
buffer.extend_from_slice(&first_buffer[range.clone()]);
first_handshake.set_handled();

for (handshake, source_buf) in iter {
if handshake.header.msg_type != first_handshake.header.msg_type {
Expand All @@ -166,7 +172,9 @@ impl Handshake {
unreachable!("Non-Fragment body in defragment()")
};

handshake.handled.store(true, Ordering::Relaxed);
handled
.try_push(handshake)
.map_err(|_| crate::InternalError::too_many_records())?;

buffer.extend_from_slice(&source_buf[range.clone()]);
}
Expand All @@ -176,7 +184,18 @@ impl Handshake {
return Err(crate::InternalError::parse_incomplete());
}

// If transcript is provided, write the handshake header + body before parsing
let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?;

if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished {
debug!("Defragmentation failed. Body::parse() did not consume the entire buffer");
return Err(crate::InternalError::parse_incomplete());
}

for handshake in handled {
handshake.set_handled();
}

// If transcript is provided, write the handshake header + body after parsing succeeds.
if let Some(transcript) = transcript {
transcript.push(first_handshake.header.msg_type.as_u8());
transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]);
Expand All @@ -187,13 +206,6 @@ impl Handshake {
transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]);
}

let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?;

if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished {
debug!("Defragmentation failed. Body::parse() did not consume the entire buffer");
return Err(crate::InternalError::parse_incomplete());
}

let handshake = Handshake {
header: Header {
msg_type: first_handshake.header.msg_type,
Expand Down Expand Up @@ -695,4 +707,33 @@ mod tests {

assert!(rest.is_empty());
}

#[test]
fn failed_defragment_parse_does_not_mark_handshake_or_write_transcript() {
let mut body = MESSAGE[12..].to_vec();
body[38] = 0;
body[39] = 3;

let handshake = Handshake::new(
MessageType::ClientHello,
body.len() as u32,
0,
0,
body.len() as u32,
Body::Fragment(0..body.len()),
);

let mut defragmented_buffer = Buf::new();
let mut transcript = Buf::new();
let result = Handshake::defragment(
std::iter::once((&handshake, body.as_slice())),
&mut defragmented_buffer,
None,
Some(&mut transcript),
);

assert!(result.is_err());
assert!(!handshake.is_handled());
assert!(transcript.is_empty());
}
}
39 changes: 27 additions & 12 deletions src/dtls13/message/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ use std::sync::atomic::{AtomicBool, Ordering};
use super::{Certificate, CertificateVerify, ClientHello, Dtls13CipherSuite};
use super::{EncryptedExtensions, Finished, ServerHello};
use crate::buffer::Buf;
use arrayvec::ArrayVec;
use nom::Err;
use nom::IResult;
use nom::bytes::complete::take;
use nom::error::{Error, ErrorKind};
use nom::number::complete::be_u8;
use nom::number::complete::{be_u16, be_u24};

const MAX_DEFRAGMENT_HANDSHAKES: usize = 50;

#[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
pub struct Header {
pub msg_type: MessageType,
Expand Down Expand Up @@ -165,8 +168,11 @@ impl Handshake {
let Body::Fragment(range) = &first_handshake.body else {
unreachable!("Non-Fragment body in defragment()")
};
let mut handled = ArrayVec::<&Handshake, MAX_DEFRAGMENT_HANDSHAKES>::new();
handled
.try_push(first_handshake)
.map_err(|_| crate::InternalError::too_many_records())?;
buffer.extend_from_slice(&first_buffer[range.clone()]);
first_handshake.set_handled();

let mut assembled_end =
first_handshake.header.fragment_offset + first_handshake.header.fragment_length;
Expand All @@ -180,7 +186,9 @@ impl Handshake {
unreachable!("Non-Fragment body in defragment()")
};

handshake.handled.store(true, Ordering::Relaxed);
handled
.try_push(handshake)
.map_err(|_| crate::InternalError::too_many_records())?;

// Handle overlapping fragment data: skip bytes already assembled
let frag_start = handshake.header.fragment_offset as usize;
Expand All @@ -200,15 +208,6 @@ impl Handshake {
return Err(crate::InternalError::parse_incomplete());
}

// If transcript is provided, write the TLS 1.3-style header + body.
// Per RFC 9147 Section 5.2, the transcript uses msg_type(1) + length(3)
// WITHOUT the DTLS-specific message_seq, fragment_offset, fragment_length.
if let Some(transcript) = transcript {
transcript.push(first_handshake.header.msg_type.as_u8());
transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]);
transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]);
}

let (rest, body) = if allow_unknown_client_hello_suites {
Body::parse_allow_unknown_client_hello_suites(
buffer,
Expand All @@ -225,6 +224,19 @@ impl Handshake {
return Err(crate::InternalError::parse_incomplete());
}

for handshake in handled {
handshake.set_handled();
}

// If transcript is provided, write the TLS 1.3-style header + body after parsing succeeds.
// Per RFC 9147 Section 5.2, the transcript uses msg_type(1) + length(3)
// WITHOUT the DTLS-specific message_seq, fragment_offset, fragment_length.
if let Some(transcript) = transcript {
transcript.push(first_handshake.header.msg_type.as_u8());
transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]);
transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]);
}

let handshake = Handshake {
header: Header {
msg_type: first_handshake.header.msg_type,
Expand Down Expand Up @@ -619,16 +631,19 @@ mod tests {
);

let mut buffer = Buf::new();
let mut transcript = Buf::new();
let result = Handshake::defragment(
std::iter::once((&handshake, source.as_slice())),
&mut buffer,
None,
None,
Some(&mut transcript),
);

assert!(
result.is_err(),
"KeyUpdate bodies with trailing bytes must be rejected"
);
assert!(!handshake.is_handled());
assert!(transcript.is_empty());
}
}
Loading