diff --git a/src/dtls12/message/handshake.rs b/src/dtls12/message/handshake.rs index ed555c3..d17f5d1 100644 --- a/src/dtls12/message/handshake.rs +++ b/src/dtls12/message/handshake.rs @@ -13,6 +13,7 @@ use super::HelloVerifyRequest; use super::ServerHello; use super::ServerKeyExchange; use crate::buffer::Buf; +use arrayvec::ArrayVec; use nom::Err; use nom::IResult; use nom::bytes::complete::take; @@ -20,6 +21,8 @@ use nom::error::{Error, ErrorKind}; use nom::number::complete::be_u8; use nom::number::complete::{be_u16, be_u24}; +const MAX_DEFRAGMENT_HANDSHAKES: usize = 50; + #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] pub struct Header { pub msg_type: MessageType, @@ -154,8 +157,11 @@ impl Handshake { let Body::Fragment(range) = &first_handshake.body else { unreachable!("Non-Fragment body in defragment()") }; + let mut handled = ArrayVec::<&Handshake, MAX_DEFRAGMENT_HANDSHAKES>::new(); + handled + .try_push(first_handshake) + .map_err(|_| crate::InternalError::too_many_records())?; buffer.extend_from_slice(&first_buffer[range.clone()]); - first_handshake.set_handled(); for (handshake, source_buf) in iter { if handshake.header.msg_type != first_handshake.header.msg_type { @@ -166,7 +172,9 @@ impl Handshake { unreachable!("Non-Fragment body in defragment()") }; - handshake.handled.store(true, Ordering::Relaxed); + handled + .try_push(handshake) + .map_err(|_| crate::InternalError::too_many_records())?; buffer.extend_from_slice(&source_buf[range.clone()]); } @@ -176,7 +184,18 @@ impl Handshake { return Err(crate::InternalError::parse_incomplete()); } - // If transcript is provided, write the handshake header + body before parsing + let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?; + + if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { + debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); + return Err(crate::InternalError::parse_incomplete()); + } + + for handshake in handled { + handshake.set_handled(); + } + + // If transcript is provided, write the handshake header + body after parsing succeeds. if let Some(transcript) = transcript { transcript.push(first_handshake.header.msg_type.as_u8()); transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]); @@ -187,13 +206,6 @@ impl Handshake { transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]); } - let (rest, body) = Body::parse(buffer, 0, first_handshake.header.msg_type, cipher_suite)?; - - if !rest.is_empty() && first_handshake.header.msg_type == MessageType::Finished { - debug!("Defragmentation failed. Body::parse() did not consume the entire buffer"); - return Err(crate::InternalError::parse_incomplete()); - } - let handshake = Handshake { header: Header { msg_type: first_handshake.header.msg_type, @@ -695,4 +707,33 @@ mod tests { assert!(rest.is_empty()); } + + #[test] + fn failed_defragment_parse_does_not_mark_handshake_or_write_transcript() { + let mut body = MESSAGE[12..].to_vec(); + body[38] = 0; + body[39] = 3; + + let handshake = Handshake::new( + MessageType::ClientHello, + body.len() as u32, + 0, + 0, + body.len() as u32, + Body::Fragment(0..body.len()), + ); + + let mut defragmented_buffer = Buf::new(); + let mut transcript = Buf::new(); + let result = Handshake::defragment( + std::iter::once((&handshake, body.as_slice())), + &mut defragmented_buffer, + None, + Some(&mut transcript), + ); + + assert!(result.is_err()); + assert!(!handshake.is_handled()); + assert!(transcript.is_empty()); + } } diff --git a/src/dtls13/message/handshake.rs b/src/dtls13/message/handshake.rs index 7896bcf..0ddaa83 100644 --- a/src/dtls13/message/handshake.rs +++ b/src/dtls13/message/handshake.rs @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use super::{Certificate, CertificateVerify, ClientHello, Dtls13CipherSuite}; use super::{EncryptedExtensions, Finished, ServerHello}; use crate::buffer::Buf; +use arrayvec::ArrayVec; use nom::Err; use nom::IResult; use nom::bytes::complete::take; @@ -12,6 +13,8 @@ use nom::error::{Error, ErrorKind}; use nom::number::complete::be_u8; use nom::number::complete::{be_u16, be_u24}; +const MAX_DEFRAGMENT_HANDSHAKES: usize = 50; + #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] pub struct Header { pub msg_type: MessageType, @@ -165,8 +168,11 @@ impl Handshake { let Body::Fragment(range) = &first_handshake.body else { unreachable!("Non-Fragment body in defragment()") }; + let mut handled = ArrayVec::<&Handshake, MAX_DEFRAGMENT_HANDSHAKES>::new(); + handled + .try_push(first_handshake) + .map_err(|_| crate::InternalError::too_many_records())?; buffer.extend_from_slice(&first_buffer[range.clone()]); - first_handshake.set_handled(); let mut assembled_end = first_handshake.header.fragment_offset + first_handshake.header.fragment_length; @@ -180,7 +186,9 @@ impl Handshake { unreachable!("Non-Fragment body in defragment()") }; - handshake.handled.store(true, Ordering::Relaxed); + handled + .try_push(handshake) + .map_err(|_| crate::InternalError::too_many_records())?; // Handle overlapping fragment data: skip bytes already assembled let frag_start = handshake.header.fragment_offset as usize; @@ -200,15 +208,6 @@ impl Handshake { return Err(crate::InternalError::parse_incomplete()); } - // If transcript is provided, write the TLS 1.3-style header + body. - // Per RFC 9147 Section 5.2, the transcript uses msg_type(1) + length(3) - // WITHOUT the DTLS-specific message_seq, fragment_offset, fragment_length. - if let Some(transcript) = transcript { - transcript.push(first_handshake.header.msg_type.as_u8()); - transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]); - transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]); - } - let (rest, body) = if allow_unknown_client_hello_suites { Body::parse_allow_unknown_client_hello_suites( buffer, @@ -225,6 +224,19 @@ impl Handshake { return Err(crate::InternalError::parse_incomplete()); } + for handshake in handled { + handshake.set_handled(); + } + + // If transcript is provided, write the TLS 1.3-style header + body after parsing succeeds. + // Per RFC 9147 Section 5.2, the transcript uses msg_type(1) + length(3) + // WITHOUT the DTLS-specific message_seq, fragment_offset, fragment_length. + if let Some(transcript) = transcript { + transcript.push(first_handshake.header.msg_type.as_u8()); + transcript.extend_from_slice(&first_handshake.header.length.to_be_bytes()[1..]); + transcript.extend_from_slice(&buffer[..first_handshake.header.length as usize]); + } + let handshake = Handshake { header: Header { msg_type: first_handshake.header.msg_type, @@ -619,16 +631,19 @@ mod tests { ); let mut buffer = Buf::new(); + let mut transcript = Buf::new(); let result = Handshake::defragment( std::iter::once((&handshake, source.as_slice())), &mut buffer, None, - None, + Some(&mut transcript), ); assert!( result.is_err(), "KeyUpdate bodies with trailing bytes must be rejected" ); + assert!(!handshake.is_handled()); + assert!(transcript.is_empty()); } }