Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/agent_core/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod origin_lookup;
pub mod proxy_protocol;
pub mod tcp;
pub mod udp;
pub mod upload_qos;
23 changes: 22 additions & 1 deletion packages/agent_core/src/network/tcp/tcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio_util::sync::CancellationToken;
use crate::stats::AgentStats;

use super::tcp_pipe::{PipeDirection, TcpPipe};
use crate::network::upload_qos::UploadFairness;

pub struct TcpClient {
tunn_to_origin: TcpPipe,
Expand All @@ -20,11 +21,30 @@ impl TcpClient {
tunn: TcpStream,
origin: TcpStream,
stats: Option<AgentStats>,
) -> Self {
Self::create_with_stats_and_upload_flow(tunn, origin, stats, None).await
}

pub(super) async fn create_with_stats_and_upload_qos(
tunn: TcpStream,
origin: TcpStream,
stats: Option<AgentStats>,
upload_fairness: UploadFairness,
) -> Self {
Self::create_with_stats_and_upload_flow(tunn, origin, stats, Some(upload_fairness)).await
}

async fn create_with_stats_and_upload_flow(
tunn: TcpStream,
origin: TcpStream,
stats: Option<AgentStats>,
upload_fairness: Option<UploadFairness>,
) -> Self {
let (tunn_read, tunn_write) = tunn.into_split();
let (origin_read, origin_write) = origin.into_split();

let cancel = CancellationToken::new();
let upload_flow = upload_fairness.map(|fairness| fairness.register());

TcpClient {
tunn_to_origin: TcpPipe::new_with_stats(
Expand All @@ -34,12 +54,13 @@ impl TcpClient {
stats.clone(),
PipeDirection::TunnelToOrigin,
),
origin_to_tunn: TcpPipe::new_with_stats(
origin_to_tunn: TcpPipe::new_with_stats_and_upload_flow(
cancel,
origin_read,
tunn_write,
stats,
PipeDirection::OriginToTunnel,
upload_flow,
),
}
}
Expand Down
18 changes: 14 additions & 4 deletions packages/agent_core/src/network/tcp/tcp_clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tokio_util::sync::CancellationToken;
use crate::{
network::{
lan_address::LanAddress, origin_lookup::OriginLookup, proxy_protocol::ProxyProtocolHeader,
upload_qos::UploadFairness,
},
stats::AgentStats,
utils::now_milli,
Expand Down Expand Up @@ -52,6 +53,7 @@ struct Worker {
cancel: CancellationToken,
settings: TcpSettings,
stats: AgentStats,
upload_fairness: UploadFairness,

clients: Vec<Client>,
next_client_id: u64,
Expand Down Expand Up @@ -110,19 +112,22 @@ impl TcpClients {
lookup: Arc<OriginLookup>,
stats: AgentStats,
cancel: CancellationToken,
upload_fairness: UploadFairness,
) -> Self {
let quota = build_quota(&settings);
let (events_tx, events_rx) = channel(1024);
let worker_cancel = cancel.child_token();

tokio::spawn(
Worker {
next_client_id: 1,
lookup,
events: events_rx,
events_tx: events_tx.clone(),
cancel: cancel.child_token(),
cancel: worker_cancel,
settings,
stats,
upload_fairness,
clients: Vec::with_capacity(32),
}
.start(),
Expand Down Expand Up @@ -232,6 +237,7 @@ impl Worker {

let event_tx = self.events_tx.clone();
let stats = self.stats.clone();
let upload_fairness = self.upload_fairness.clone();
let cancel = self.cancel.child_token();
tokio::spawn(async move {
let Some(origin_addr) = found.resolve_local(details.port_offset).await
Expand Down Expand Up @@ -395,9 +401,13 @@ impl Worker {
}
}

let tcp_client =
TcpClient::create_with_stats(tunn_stream, origin_stream, Some(stats))
.await;
let tcp_client = TcpClient::create_with_stats_and_upload_qos(
tunn_stream,
origin_stream,
Some(stats),
upload_fairness,
)
.await;
let event = Event::ConnectedClient(Client {
id: client_id,
added_at: now_milli(),
Expand Down
103 changes: 101 additions & 2 deletions packages/agent_core/src/network/tcp/tcp_pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio_util::sync::CancellationToken;
use crate::stats::AgentStats;
use crate::utils::now_milli;

use crate::network::upload_qos::{UPLOAD_QOS_SLICE_SIZE, UploadFlow};

const TCP_PIPE_BUFFER_SIZE: usize = 16 * 1024;

/// Direction of data flow for stats tracking
Expand Down Expand Up @@ -58,6 +60,20 @@ impl TcpPipe {
to: W,
stats: Option<AgentStats>,
direction: PipeDirection,
) -> Self {
Self::new_with_stats_and_upload_flow(cancel, from, to, stats, direction, None)
}

pub(super) fn new_with_stats_and_upload_flow<
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
>(
cancel: CancellationToken,
from: R,
to: W,
stats: Option<AgentStats>,
direction: PipeDirection,
upload_flow: Option<UploadFlow>,
) -> Self {
let shared = Arc::new(Shared {
last_activity: AtomicU64::new(now_milli()),
Expand All @@ -74,6 +90,7 @@ impl TcpPipe {
to,
stats,
direction,
upload_flow,
}
.start(),
);
Expand Down Expand Up @@ -113,11 +130,17 @@ struct Worker<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> {
to: W,
stats: Option<AgentStats>,
direction: PipeDirection,
upload_flow: Option<UploadFlow>,
}

impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> Worker<R, W> {
pub async fn start(mut self) {
let mut buffer = vec![0u8; TCP_PIPE_BUFFER_SIZE];
let buffer_size = if self.upload_flow.is_some() {
UPLOAD_QOS_SLICE_SIZE
} else {
TCP_PIPE_BUFFER_SIZE
};
let mut buffer = vec![0u8; buffer_size];

loop {
// Keep the pipe cooperative when both sockets stay continuously ready.
Expand Down Expand Up @@ -151,7 +174,23 @@ impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> Worker<R, W> {
break;
}

if let Err(error) = self.to.write_all(&buffer[..byte_count]).await {
if let Some(upload_flow) = &self.upload_flow
&& !upload_flow.acquire(byte_count, &self.cancel).await
{
tracing::info!("TcpPipe upload QoS acquire failed");
break;
}

let Some(write_res) = self
.cancel
.run_until_cancelled(self.to.write_all(&buffer[..byte_count]))
.await
else {
tracing::info!("TcpPipe cancelled");
break;
};

if let Err(error) = write_res {
tracing::error!(?error, "failed to write data");
break;
}
Expand All @@ -176,3 +215,63 @@ impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> Worker<R, W> {
self.shared.last_activity.store(u64::MAX, Ordering::Release);
}
}

#[cfg(test)]
mod tests {
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time::{Duration, timeout},
};

use super::*;
use crate::network::upload_qos::{UPLOAD_QOS_SLICE_SIZE, UploadFairness};

async fn wait_for_bytes(pipe: &TcpPipe, expected: u64) {
timeout(Duration::from_secs(1), async {
while pipe.bytes_written() != expected {
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await
.expect("pipe did not reach expected byte count");
}

#[tokio::test]
async fn upload_qos_pipe_reads_in_qos_sized_slices() {
let cancel = CancellationToken::new();
let fairness = UploadFairness::new(cancel.child_token());
let upload_flow = fairness.register();
let (mut source_write, source_read) = tokio::io::duplex(UPLOAD_QOS_SLICE_SIZE * 4);
let (sink_write, mut sink_read) = tokio::io::duplex(UPLOAD_QOS_SLICE_SIZE);

let pipe = TcpPipe::new_with_stats_and_upload_flow(
cancel.clone(),
source_read,
sink_write,
None,
PipeDirection::OriginToTunnel,
Some(upload_flow),
);

let payload = vec![7u8; UPLOAD_QOS_SLICE_SIZE * 2];
let write_task = tokio::spawn(async move {
source_write
.write_all(&payload)
.await
.expect("source write should succeed");
});

wait_for_bytes(&pipe, UPLOAD_QOS_SLICE_SIZE as u64).await;

let mut read_buffer = vec![0u8; UPLOAD_QOS_SLICE_SIZE];
sink_read
.read_exact(&mut read_buffer)
.await
.expect("sink read should succeed");

wait_for_bytes(&pipe, (UPLOAD_QOS_SLICE_SIZE * 2) as u64).await;
write_task.await.expect("source writer should finish");

cancel.cancel();
}
}
28 changes: 23 additions & 5 deletions packages/agent_core/src/network/udp/udp_clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ use tokio::{
net::UdpSocket,
sync::mpsc::{Receiver, channel},
};
use tokio_util::sync::CancellationToken;

use crate::network::{
lan_address::LanAddress, origin_lookup::OriginLookup, proxy_protocol::ProxyProtocolHeader,
lan_address::LanAddress,
origin_lookup::OriginLookup,
proxy_protocol::ProxyProtocolHeader,
upload_qos::{UploadFairness, UploadFlow},
};
use crate::stats::AgentStats;
use playit_agent_proto::udp_proto::UdpFlow;
Expand All @@ -40,6 +44,8 @@ pub struct UdpClients {

new_client_limiter: DefaultDirectRateLimiter,
stats: AgentStats,
cancel: CancellationToken,
upload_fairness: UploadFairness,
}

struct Client {
Expand All @@ -49,6 +55,7 @@ struct Client {
target_addr: SocketAddr,
port_offset: u16,
flow: UdpFlow,
upload_flow: UploadFlow,

/* when dropped, rx task get killed */
receiver: UdpReceiver,
Expand Down Expand Up @@ -100,6 +107,8 @@ impl UdpClients {
lookup: Arc<OriginLookup>,
packets: Packets,
stats: AgentStats,
cancel: CancellationToken,
upload_fairness: UploadFairness,
) -> Self {
let (origin_tx, origin_rx) = channel(2048);

Expand All @@ -115,6 +124,8 @@ impl UdpClients {
rx: origin_rx,
new_client_limiter: RateLimiter::direct(build_quota(&settings)),
stats,
cancel,
upload_fairness,
}
}

Expand Down Expand Up @@ -186,10 +197,6 @@ impl UdpClients {

client.from_origin_ts = now_ms;

// Track bytes going out (from origin to tunnel)
let packet_len = packet.packet.len() as u64;
self.stats.add_bytes_out(packet_len);

let mut flow = client.flow;
match &mut flow {
UdpFlow::V4 {
Expand All @@ -211,6 +218,16 @@ impl UdpClients {
_ => unreachable!(),
}

let packet_len = packet.packet.len();
let upload_len = packet_len + flow.footer_len();
if !client.upload_flow.acquire(upload_len, &self.cancel).await {
tracing::info!("UDP upload QoS acquire failed");
return None;
}

// Track bytes going out (from origin to tunnel)
self.stats.add_bytes_out(packet_len as u64);

Some((flow, packet.packet))
}

Expand Down Expand Up @@ -324,6 +341,7 @@ impl UdpClients {
port_offset: extension.port_offset,
receiver,
flow: client_flow,
upload_flow: self.upload_fairness.register(),
from_tunnel_ts: now_ms,
from_origin_ts: now_ms,
};
Expand Down
Loading