From c4e7e4cb4b624d7738c2016d7b7885c6774ff9c2 Mon Sep 17 00:00:00 2001 From: Vlad Semenov Date: Tue, 18 Mar 2025 22:19:44 +0300 Subject: [PATCH 1/4] anemo: add KnownPeers batch_update (#66) --- crates/anemo/src/network/connection_manager.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/anemo/src/network/connection_manager.rs b/crates/anemo/src/network/connection_manager.rs index eaec5aca..ce01fb7d 100644 --- a/crates/anemo/src/network/connection_manager.rs +++ b/crates/anemo/src/network/connection_manager.rs @@ -743,6 +743,19 @@ impl KnownPeers { self.inner_mut().insert(peer_info.peer_id, peer_info) } + pub fn batch_update<'a>( + &self, + to_remove: impl Iterator, + to_insert: impl Iterator, + ) -> (Vec>, Vec>) { + let mut inner = self.inner_mut(); + let removed = to_remove.map(|peer_id| inner.remove(peer_id)).collect(); + let inserted = to_insert + .map(|peer_info| inner.insert(peer_info.peer_id, peer_info)) + .collect(); + (removed, inserted) + } + fn inner(&self) -> std::sync::RwLockReadGuard<'_, HashMap> { self.0.read().unwrap() } From 4b5f0f1d06a31c8ef78ec2e5b446bc633e4e2f77 Mon Sep 17 00:00:00 2001 From: Andrew Schran Date: Fri, 19 Dec 2025 11:52:49 -0500 Subject: [PATCH 2/4] add PartialEq, Eq to some types (#69) --- crates/anemo/src/types/address.rs | 2 +- crates/anemo/src/types/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/anemo/src/types/address.rs b/crates/anemo/src/types/address.rs index 16964e12..3c92cc72 100644 --- a/crates/anemo/src/types/address.rs +++ b/crates/anemo/src/types/address.rs @@ -1,5 +1,5 @@ /// Representation of a network address that is dial-able in Anemo -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Address { /// A plain SocketAddr SocketAddr(std::net::SocketAddr), diff --git a/crates/anemo/src/types/mod.rs b/crates/anemo/src/types/mod.rs index ec3c21ab..ec2a1d63 100644 --- a/crates/anemo/src/types/mod.rs +++ b/crates/anemo/src/types/mod.rs @@ -43,7 +43,7 @@ pub mod header { pub const TIMEOUT: &str = "timeout"; } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum PeerAffinity { /// Always attempt to maintain a connection with this Peer. High, @@ -54,7 +54,7 @@ pub enum PeerAffinity { Never, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct PeerInfo { pub peer_id: PeerId, pub affinity: PeerAffinity, From 9248f8d3951df750360308df22618e5978ad6dc0 Mon Sep 17 00:00:00 2001 From: Andrew Schran Date: Mon, 2 Mar 2026 13:13:00 -0600 Subject: [PATCH 3/4] replace LengthDelimitedCodec with custom MessageFrameCodec (#70) --- crates/anemo/src/config.rs | 10 +- crates/anemo/src/crypto.rs | 5 +- crates/anemo/src/network/request_handler.rs | 7 +- crates/anemo/src/network/wire.rs | 363 +++++++++++++++++++- crates/anemo/src/types/mod.rs | 9 +- crates/anemo/src/types/response.rs | 9 +- 6 files changed, 367 insertions(+), 36 deletions(-) diff --git a/crates/anemo/src/config.rs b/crates/anemo/src/config.rs index 4bfceee5..2ffd4127 100644 --- a/crates/anemo/src/config.rs +++ b/crates/anemo/src/config.rs @@ -63,13 +63,13 @@ pub struct Config { /// /// This limit is applied in the following ways: /// - Inbound connections from [`KnownPeers`] with [`PeerAffinity::High`] or - /// [`PeerAffinity::Allowed`] bypass this limit. All other inbound - /// connections are only accepted if the total number of inbound and outbound - /// connections, irrespective of affinity, is less than this limit. + /// [`PeerAffinity::Allowed`] bypass this limit. All other inbound + /// connections are only accepted if the total number of inbound and outbound + /// connections, irrespective of affinity, is less than this limit. /// - Outbound connections explicitly made by the application via [`Network::connect`] or - /// [`Network::connect_with_peer_id`] bypass this limit. + /// [`Network::connect_with_peer_id`] bypass this limit. /// - Outbound connections made in the background, due to configured [`KnownPeers`], to peers with - /// [`PeerAffinity::High`] bypass this limit and are always attempted. + /// [`PeerAffinity::High`] bypass this limit and are always attempted. /// /// If unspecified, there will be no limit on the number of concurrent connections. /// diff --git a/crates/anemo/src/crypto.rs b/crates/anemo/src/crypto.rs index 50d21588..e753e4af 100644 --- a/crates/anemo/src/crypto.rs +++ b/crates/anemo/src/crypto.rs @@ -326,9 +326,12 @@ fn pki_error(error: webpki::Error) -> rustls::Error { BadDer | BadDerTime => { rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) } + #[allow(deprecated)] InvalidSignatureForPublicKey | UnsupportedSignatureAlgorithm - | UnsupportedSignatureAlgorithmForPublicKey => { + | UnsupportedSignatureAlgorithmForPublicKey + | UnsupportedSignatureAlgorithmContext(..) + | UnsupportedSignatureAlgorithmForPublicKeyContext(..) => { rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature) } e => { diff --git a/crates/anemo/src/network/request_handler.rs b/crates/anemo/src/network/request_handler.rs index ecee741c..9d5aaf87 100644 --- a/crates/anemo/src/network/request_handler.rs +++ b/crates/anemo/src/network/request_handler.rs @@ -1,3 +1,4 @@ +use super::wire::MessageFrameCodec; use super::{ wire::{network_message_frame_codec, read_request, write_response}, ActivePeers, @@ -10,7 +11,7 @@ use bytes::Bytes; use quinn::RecvStream; use std::convert::Infallible; use std::sync::Arc; -use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{FramedRead, FramedWrite}; use tower::{util::BoxCloneService, ServiceExt}; use tracing::{debug, trace}; @@ -121,8 +122,8 @@ impl InboundRequestHandler { struct BiStreamRequestHandler { connection: Connection, service: BoxCloneService, Response, Infallible>, - send_stream: FramedWrite, - recv_stream: FramedRead, + send_stream: FramedWrite, + recv_stream: FramedRead, } impl BiStreamRequestHandler { diff --git a/crates/anemo/src/network/wire.rs b/crates/anemo/src/network/wire.rs index 3a126600..03b5f5b0 100644 --- a/crates/anemo/src/network/wire.rs +++ b/crates/anemo/src/network/wire.rs @@ -9,23 +9,124 @@ use crate::{ Config, Request, Response, Result, }; use anyhow::{anyhow, bail}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; +use std::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite}; const ANEMO: &[u8; 5] = b"anemo"; -/// Returns a fully configured length-delimited codec for writing/reading -/// serialized frames to/from a socket. -pub(crate) fn network_message_frame_codec(config: &Config) -> LengthDelimitedCodec { - let mut builder = LengthDelimitedCodec::builder(); +/// Maximum number of bytes to pre-allocate when decoding a frame. +const MAX_PREALLOCATION: usize = 1 << 20; // 1 MB + +/// Default maximum frame length. +const DEFAULT_MAX_FRAME_LENGTH: usize = 8 * 1024 * 1024; // 8 MB + +/// A length-delimited codec that uses the same wire format as tokio-util's +/// `LengthDelimitedCodec` (4-byte big-endian length prefix + data). +pub(crate) struct MessageFrameCodec { + state: DecodeState, + max_frame_length: usize, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + /// Waiting for the 4-byte length prefix. + Head, + /// Accumulating body bytes; stores the total expected frame length. + Data(usize), +} + +impl MessageFrameCodec { + fn new(max_frame_length: usize) -> Self { + Self { + state: DecodeState::Head, + max_frame_length, + } + } +} + +impl Decoder for MessageFrameCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> io::Result> { + loop { + match self.state { + DecodeState::Head => { + if src.len() < 4 { + src.reserve(4 - src.len()); + return Ok(None); + } + + let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize; + src.advance(4); + + if len > self.max_frame_length { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "frame of length {} exceeds max frame length of {}", + len, self.max_frame_length, + ), + )); + } + + if len == 0 { + return Ok(Some(BytesMut::new())); + } + + // Only pre-allocate up to MAX_PREALLOCATION. + let to_reserve = len.min(MAX_PREALLOCATION); + src.reserve(to_reserve); + + self.state = DecodeState::Data(len); + } + DecodeState::Data(frame_len) => { + if src.len() < frame_len { + // Reserve only up to MAX_PREALLOCATION beyond what we already have. + let remaining = frame_len - src.len(); + let to_reserve = remaining.min(MAX_PREALLOCATION); + src.reserve(to_reserve); + return Ok(None); + } + + self.state = DecodeState::Head; + return Ok(Some(src.split_to(frame_len))); + } + } + } + } +} + +impl Encoder for MessageFrameCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> io::Result<()> { + let len = data.len(); + if len > self.max_frame_length { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "frame of length {} exceeds max frame length of {}", + len, self.max_frame_length, + ), + )); + } - if let Some(max_frame_size) = config.max_frame_size() { - builder.max_frame_length(max_frame_size); + dst.reserve(4 + len); + dst.put_u32(len as u32); + dst.extend_from_slice(&data); + Ok(()) } +} - builder.length_field_length(4).big_endian().new_codec() +/// Returns a fully configured message frame codec for writing/reading +/// serialized frames to/from a socket. +pub(crate) fn network_message_frame_codec(config: &Config) -> MessageFrameCodec { + let max_frame_length = config.max_frame_size().unwrap_or(DEFAULT_MAX_FRAME_LENGTH); + MessageFrameCodec::new(max_frame_length) } /// Anemo requires mTLS in order to ensure that both sides of the connections are authenticated by @@ -83,7 +184,7 @@ pub(crate) async fn write_version_frame( } pub(crate) async fn write_request( - send_stream: &mut FramedWrite, + send_stream: &mut FramedWrite, request: Request, ) -> Result<()> { // Write Version Frame @@ -105,7 +206,7 @@ pub(crate) async fn write_request( } pub(crate) async fn write_response( - send_stream: &mut FramedWrite, + send_stream: &mut FramedWrite, response: Response, ) -> Result<()> { // Write Version Frame @@ -129,7 +230,7 @@ pub(crate) async fn write_response( } pub(crate) async fn read_request( - recv_stream: &mut FramedRead, + recv_stream: &mut FramedRead, ) -> Result> { // Read Version Frame let version = read_version_frame(recv_stream.get_mut()).await?; @@ -154,7 +255,7 @@ pub(crate) async fn read_request( } pub(crate) async fn read_response( - recv_stream: &mut FramedRead, + recv_stream: &mut FramedRead, ) -> Result> { // Read Version Frame let version = read_version_frame(recv_stream.get_mut()).await?; @@ -225,3 +326,239 @@ mod test { assert_eq!(HEADER.as_ref(), buf); } } + +#[cfg(test)] +mod message_frame_codec_tests { + use super::{MessageFrameCodec, DEFAULT_MAX_FRAME_LENGTH, MAX_PREALLOCATION}; + use bytes::{BufMut, Bytes, BytesMut}; + use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; + + fn new_legacy_codec() -> LengthDelimitedCodec { + LengthDelimitedCodec::builder() + .length_field_length(4) + .big_endian() + .new_codec() + } + + fn legacy_encode(codec: &mut LengthDelimitedCodec, data: &[u8]) -> BytesMut { + let mut buf = BytesMut::new(); + codec + .encode(Bytes::copy_from_slice(data), &mut buf) + .unwrap(); + buf + } + + fn custom_encode(codec: &mut MessageFrameCodec, data: &[u8]) -> BytesMut { + let mut buf = BytesMut::new(); + codec + .encode(Bytes::copy_from_slice(data), &mut buf) + .unwrap(); + buf + } + + #[test] + fn empty_frame_legacy_to_custom() { + let mut enc = new_legacy_codec(); + let wire = legacy_encode(&mut enc, &[]); + + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert!(frame.is_empty()); + } + + #[test] + fn empty_frame_custom_to_legacy() { + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &[]); + + let mut dec = new_legacy_codec(); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert!(frame.is_empty()); + } + + #[test] + fn small_frame_legacy_to_custom() { + let data: Vec = (0..64).collect(); + let mut enc = new_legacy_codec(); + let wire = legacy_encode(&mut enc, &data); + + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + } + + #[test] + fn small_frame_custom_to_legacy() { + let data: Vec = (0..64).collect(); + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &data); + + let mut dec = new_legacy_codec(); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + } + + #[test] + fn medium_frame_around_preallocation_boundary() { + let size = MAX_PREALLOCATION + 1; + let data: Vec = (0..size).map(|i| (i % 256) as u8).collect(); + + // Legacy encode -> custom decode + let mut enc = new_legacy_codec(); + let wire = legacy_encode(&mut enc, &data); + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + + // Custom encode -> legacy decode + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &data); + let mut dec = new_legacy_codec(); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + } + + #[test] + fn large_frame_above_preallocation_limit() { + let size = MAX_PREALLOCATION * 3 + 1; + let data: Vec = (0..size).map(|i| (i % 256) as u8).collect(); + + // Encode with custom codec, decode with custom codec (full buffer available) + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &data); + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(frame.len(), size); + assert_eq!(&frame[..], &data); + + // Cross-codec: legacy encode -> custom decode + let mut enc = new_legacy_codec(); + let wire = legacy_encode(&mut enc, &data); + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + + // Cross-codec: custom encode -> legacy decode + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &data); + let mut dec = new_legacy_codec(); + let frame = dec.decode(&mut wire.clone()).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + } + + #[test] + fn max_frame_length_rejection_custom() { + let max = 128; + let data = vec![0u8; max + 1]; + + // Encode should fail + let mut enc = MessageFrameCodec::new(max); + let mut buf = BytesMut::new(); + assert!(enc.encode(Bytes::from(data.clone()), &mut buf).is_err()); + + // Decode should fail when length prefix exceeds max + let mut dec = MessageFrameCodec::new(max); + let mut wire = BytesMut::new(); + wire.put_u32((max + 1) as u32); + assert!(dec.decode(&mut wire).is_err()); + } + + #[test] + fn max_frame_length_rejection_matches_legacy() { + let max = 128; + let data = vec![0u8; max + 1]; + + let mut legacy = LengthDelimitedCodec::builder() + .length_field_length(4) + .big_endian() + .max_frame_length(max) + .new_codec(); + let mut buf = BytesMut::new(); + let legacy_result = legacy.encode(Bytes::from(data.clone()), &mut buf); + + let mut custom = MessageFrameCodec::new(max); + let mut buf = BytesMut::new(); + let custom_result = custom.encode(Bytes::from(data), &mut buf); + + // Both should reject oversized frames + assert!(legacy_result.is_err()); + assert!(custom_result.is_err()); + } + + #[test] + fn incremental_delivery_one_byte_at_a_time() { + let data: Vec = (0..256).map(|i| (i % 256) as u8).collect(); + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let wire = custom_encode(&mut enc, &data); + + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let mut src = BytesMut::new(); + + // Feed one byte at a time; should return None until complete + for i in 0..wire.len() - 1 { + src.extend_from_slice(&wire[i..i + 1]); + assert!( + dec.decode(&mut src).unwrap().is_none(), + "should not produce frame at byte {}", + i + ); + } + + // Feed the last byte + src.extend_from_slice(&wire[wire.len() - 1..]); + let frame = dec.decode(&mut src).unwrap().unwrap(); + assert_eq!(&frame[..], &data); + } + + #[test] + fn multiple_frames_in_sequence() { + let data1: Vec = (0..100).collect(); + let data2: Vec = (100..250).collect(); + + let mut enc = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let mut wire = BytesMut::new(); + enc.encode(Bytes::from(data1.clone()), &mut wire).unwrap(); + enc.encode(Bytes::from(data2.clone()), &mut wire).unwrap(); + + let mut dec = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let frame1 = dec.decode(&mut wire).unwrap().unwrap(); + assert_eq!(&frame1[..], &data1); + let frame2 = dec.decode(&mut wire).unwrap().unwrap(); + assert_eq!(&frame2[..], &data2); + } + + #[test] + fn preallocation_cap() { + // Create a frame with a large claimed length + let claimed_len: usize = 64 * 1024 * 1024; // 64 MB + let mut wire = BytesMut::new(); + wire.put_u32(claimed_len as u32); + // Only provide a few bytes of actual data — not the full frame + wire.extend_from_slice(&[0u8; 64]); + + let mut dec = MessageFrameCodec::new(claimed_len); + // Decode should return None (not enough data) + assert!(dec.decode(&mut wire).unwrap().is_none()); + + // The buffer capacity should be bounded to at most MAX_PREALLOCATION plus some overhead. + assert!( + wire.capacity() < MAX_PREALLOCATION * 2, + "buffer capacity {} should be bounded", + wire.capacity(), + ); + } + + #[test] + fn wire_format_compatibility() { + // Verify the wire format is identical between legacy and custom codecs + let data = b"hello world"; + + let mut legacy = new_legacy_codec(); + let legacy_wire = legacy_encode(&mut legacy, data); + + let mut custom = MessageFrameCodec::new(DEFAULT_MAX_FRAME_LENGTH); + let custom_wire = custom_encode(&mut custom, data); + + assert_eq!(legacy_wire, custom_wire); + } +} diff --git a/crates/anemo/src/types/mod.rs b/crates/anemo/src/types/mod.rs index ec2a1d63..570f27b9 100644 --- a/crates/anemo/src/types/mod.rs +++ b/crates/anemo/src/types/mod.rs @@ -9,9 +9,10 @@ pub use peer_id::{ConnectionOrigin, Direction, PeerId}; pub use http::Extensions; use quinn::ConnectionError; -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] #[repr(u16)] pub enum Version { + #[default] V1 = 1, } @@ -28,12 +29,6 @@ impl Version { } } -impl Default for Version { - fn default() -> Self { - Self::V1 - } -} - pub type HeaderMap = std::collections::HashMap; pub mod header { diff --git a/crates/anemo/src/types/response.rs b/crates/anemo/src/types/response.rs index 3be27102..7838655c 100644 --- a/crates/anemo/src/types/response.rs +++ b/crates/anemo/src/types/response.rs @@ -48,10 +48,11 @@ impl RawResponseHeader { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] #[repr(u16)] #[non_exhaustive] pub enum StatusCode { + #[default] Success = 200, BadRequest = 400, NotFound = 404, @@ -105,12 +106,6 @@ impl StatusCode { } } -impl Default for StatusCode { - fn default() -> Self { - Self::Success - } -} - impl std::fmt::Display for StatusCode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use StatusCode::*; From e42c0a9020be56fffbdf921098a68ffd66ea836c Mon Sep 17 00:00:00 2001 From: Andrew Schran Date: Tue, 28 Apr 2026 12:23:46 -0400 Subject: [PATCH 4/4] anemo: tighten Router builder ordering via typestate Move route, add_rpc_service, and merge onto Router; route_layer transitions to Router, on which only further route_layer calls are available. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/anemo/src/lib.rs | 2 +- crates/anemo/src/routing/mod.rs | 147 +++++++++++++++++++++++++------- 2 files changed, 119 insertions(+), 30 deletions(-) diff --git a/crates/anemo/src/lib.rs b/crates/anemo/src/lib.rs index a36093c4..2396fa91 100644 --- a/crates/anemo/src/lib.rs +++ b/crates/anemo/src/lib.rs @@ -12,7 +12,7 @@ pub mod types; pub use config::{Config, QuicConfig}; pub use error::{Error, Result}; pub use network::{Builder, KnownPeers, Network, NetworkRef, Peer}; -pub use routing::Router; +pub use routing::{Router, ServicesOpen, ServicesSealed}; #[doc(inline)] pub use types::{request::Request, response::Response, ConnectionOrigin, Direction, PeerId}; diff --git a/crates/anemo/src/routing/mod.rs b/crates/anemo/src/routing/mod.rs index 137f6058..3189eeeb 100644 --- a/crates/anemo/src/routing/mod.rs +++ b/crates/anemo/src/routing/mod.rs @@ -4,6 +4,7 @@ use std::{ collections::{BTreeMap, HashMap}, convert::Infallible, fmt, + marker::PhantomData, sync::Arc, }; use tower::{ @@ -32,21 +33,52 @@ impl RouteId { } } +/// Marker type for a [`Router`] that is still accepting service registrations. +/// +/// In this state the router exposes [`Router::route`], [`Router::add_rpc_service`], +/// and [`Router::merge`]. Calling [`Router::route_layer`] consumes the router and +/// returns one in the [`ServicesSealed`] state — no further services can be added +/// after that point. +pub struct ServicesOpen; + +/// Marker type for a [`Router`] that has been sealed by applying a layer. +/// +/// In this state the router exposes only [`Router::route_layer`] (to stack +/// additional layers); attempting to add services is a compile error. +pub struct ServicesSealed; + /// The router type for composing handlers and services. -#[derive(Clone)] -pub struct Router { +/// +/// `Router` carries a typestate parameter `S` that distinguishes a router which +/// is still accepting service registrations ([`ServicesOpen`], the default) +/// from one which has had a layer applied and can therefore no longer accept +/// new services ([`ServicesSealed`]). +pub struct Router { routes: BTreeMap, matcher: RouteMatcher, fallback: Route, + _state: PhantomData S>, +} + +impl Clone for Router { + fn clone(&self) -> Self { + Self { + routes: self.routes.clone(), + matcher: self.matcher.clone(), + fallback: self.fallback.clone(), + _state: PhantomData, + } + } } -impl Router { +impl Router { #[allow(clippy::new_without_default)] pub fn new() -> Self { Self { routes: Default::default(), matcher: Default::default(), fallback: Route::new(not_found::NotFound), + _state: PhantomData, } } @@ -64,7 +96,9 @@ impl Router { panic!("Paths must start with a `/`"); } - if ::downcast_ref::(&service).is_some() { + if ::downcast_ref::>(&service).is_some() + || ::downcast_ref::>(&service).is_some() + { panic!("Invalid route: `Router::route` cannot be used with `Router`s.") } @@ -95,18 +129,20 @@ impl Router { self.route(&path, service) } - /// Merge two routers into one. + /// Merge another router's routes into this one. /// /// This is useful for breaking apps into smaller pieces and combining them - /// into one. - pub fn merge(mut self, other: R) -> Self + /// into one. The other router may itself be in either typestate; any layers + /// it has already applied remain baked into the imported routes. + pub fn merge(mut self, other: R) -> Self where - R: Into, + R: Into>, { let Router { routes, matcher, fallback, + _state: _, } = other.into(); for (id, route) in routes { @@ -124,9 +160,10 @@ impl Router { self } - /// Apply a [`tower::Layer`] to the router that will only run if the request matches - /// a route. - pub fn route_layer(self, layer: L) -> Self + /// Apply a [`tower::Layer`] to all currently registered routes, returning a + /// [`Router`]. To stack additional layers, chain further + /// calls to [`Router::route_layer`] on the returned router. + pub fn route_layer(self, layer: L) -> Router where L: tower::Layer, L::Service: Service, Response = Response, Error = Infallible> @@ -139,25 +176,62 @@ impl Router { routes, matcher, fallback, + _state: _, } = self; - let routes = routes - .into_iter() - .map(|(id, route)| { - let route = Route::new(layer.layer(route)); - (id, route) - }) - .collect(); - Router { + routes: layer_routes(routes, layer), + matcher, + fallback, + _state: PhantomData, + } + } +} + +impl Router { + /// Apply an additional [`tower::Layer`] on top of the layers already applied + /// to this router's routes. + pub fn route_layer(self, layer: L) -> Self + where + L: tower::Layer, + L::Service: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + >>::Future: Send + 'static, + { + let Router { routes, matcher, fallback, + _state: _, + } = self; + + Router { + routes: layer_routes(routes, layer), + matcher, + fallback, + _state: PhantomData, } } } -impl Service> for Router { +fn layer_routes(routes: BTreeMap, layer: L) -> BTreeMap +where + L: tower::Layer, + L::Service: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + >>::Future: Send + 'static, +{ + routes + .into_iter() + .map(|(id, route)| (id, Route::new(layer.layer(route)))) + .collect() +} + +impl Service> for Router { type Response = Response; type Error = Infallible; type Future = @@ -415,19 +489,17 @@ mod test { } #[tokio::test] - async fn middleware_applies_to_routes_above() { - let pending = + async fn middleware_applies_to_all_routes_registered_before_seal() { + let pending_one = + tower::service_fn(|_request| async { std::future::pending::>().await }); + let pending_two = tower::service_fn(|_request| async { std::future::pending::>().await }); - let ready = tower::service_fn(|_request| async { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - Ok(Response::new(Bytes::from_static(b"ready!"))) - }); let router = Router::new() - .route("/one", pending) + .route("/one", pending_one) + .route("/two", pending_two) .route_layer(crate::middleware::timeout::inbound::TimeoutLayer::new( Some(std::time::Duration::new(0, 0)), - )) - .route("/two", ready); + )); let request = Request::new(Bytes::new()).with_route("/one"); let response = router.clone().oneshot(request).await.unwrap(); @@ -435,6 +507,23 @@ mod test { let request = Request::new(Bytes::new()).with_route("/two"); let response = router.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::RequestTimeout); + } + + #[tokio::test] + async fn stacked_layers_compose_on_sealed_router() { + let ready = tower::service_fn(|_request| async { + Ok(Response::new(Bytes::from_static(b"ready!"))) + }); + // The first call to `route_layer` returns Router; the + // second is the impl on the sealed state, stacking another layer. + let router = Router::new() + .route("/one", ready) + .route_layer(tower::layer::util::Identity::new()) + .route_layer(tower::layer::util::Identity::new()); + + let request = Request::new(Bytes::new()).with_route("/one"); + let response = router.oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::Success); }