diff --git a/ohttp/src/err.rs b/ohttp/src/err.rs index ec7741e..212b234 100644 --- a/ohttp/src/err.rs +++ b/ohttp/src/err.rs @@ -41,6 +41,9 @@ pub enum Error { #[cfg(feature = "stream")] #[error("writes are not supported after closing")] WriteAfterClose, + #[cfg(feature = "stream")] + #[error("read a zero-length chunk")] + ZeroLengthRead, } impl From for Error { diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs index 0c402ac..6b2b92c 100644 --- a/ohttp/src/stream.rs +++ b/ohttp/src/stream.rs @@ -422,7 +422,11 @@ impl ChunkReader { }; output[..pt.len()].copy_from_slice(&pt); *self = Self::length(); - Some(Poll::Ready(Ok(pt.len()))) + if pt.is_empty() { + Some(ioerror(Error::ZeroLengthRead)) + } else { + Some(Poll::Ready(Ok(pt.len()))) + } } else { buf.reserve_exact(*length); buf.extend_from_slice(&output[..r]); @@ -540,12 +544,10 @@ impl ChunkReader { if last { *self = Self::Done; } else { - *self = Self::length(); if pt.is_empty() { - // We can't return zero length, as that means "end of stream". - // So read the next chunk if this one was empty. - continue; + return ioerror(Error::ZeroLengthRead); } + *self = Self::length(); } pt.len() }; @@ -814,18 +816,21 @@ impl AsyncRead for ClientResponse { mod test { use std::{ io::Result as IoResult, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use log::trace; use pin_project::pin_project; - use sync_async::{Dribble, Pipe, SplitAt, Stutter, SyncRead, SyncResolve, Unadapt}; + use sync_async::{ + noop_context, Dribble, Pipe, SplitAt, Stutter, SyncRead, SyncResolve, Unadapt, + }; use crate::{ + stream::ChunkWriter, test::{init, make_config, REQUEST, RESPONSE}, - ClientRequest, Server, + ClientRequest, Error, Server, }; #[test] @@ -1106,4 +1111,36 @@ mod test { let mut server_request = server.decapsulate_stream(&enc_request[..]); assert_eq!(server_request.sync_read_to_end(), LONG_REQUEST); } + + /// Check that zero-length chunks are treated as invalid. + #[test] + fn dos_zero_length() { + init(); + let server_config = make_config(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (request_read, request_write) = Pipe::new(); + let mut client_request = client.encapsulate_stream(request_write).unwrap(); + + let pin = Pin::new(&mut client_request.writer); + let mut projection = pin.project(); + let mut cx = noop_context(); + + // Write out a zero-length chunk before finalizing the request. + let f = ChunkWriter::flush(&mut projection, &mut cx); + assert!(f.is_ready()); + assert!(projection.buf.is_empty()); + ChunkWriter::write_chunk(&mut projection, &mut cx, &[], false).unwrap(); + client_request.write_all(REQUEST).sync_resolve().unwrap(); + + let mut buf = Vec::new(); + let mut server_request = server.decapsulate_stream(request_read); + let fut = server_request.read_to_end(&mut buf); + let err = pin!(fut).sync_resolve().unwrap_err(); + assert!(matches!( + err.get_ref().unwrap().downcast_ref().unwrap(), + Error::ZeroLengthRead + )); + } } diff --git a/sync-async/src/lib.rs b/sync-async/src/lib.rs index f4a9826..efd3a0a 100644 --- a/sync-async/src/lib.rs +++ b/sync-async/src/lib.rs @@ -12,7 +12,8 @@ use futures::{ }; use pin_project::pin_project; -fn noop_context() -> Context<'static> { +#[must_use] +pub fn noop_context() -> Context<'static> { use std::{ ptr::null, task::{RawWaker, RawWakerVTable, Waker},