diff --git a/usb-host/src/backend/kmod/queue.rs b/usb-host/src/backend/kmod/queue.rs index 6461ed5..557f713 100644 --- a/usb-host/src/backend/kmod/queue.rs +++ b/usb-host/src/backend/kmod/queue.rs @@ -1,10 +1,10 @@ use alloc::sync::Arc; -use core::pin::Pin; -use core::task::Context; -use core::task::Poll; use core::{ cell::UnsafeCell, - sync::atomic::{AtomicBool, AtomicUsize, Ordering}, + hint::spin_loop, + pin::Pin, + sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, + task::{Context, Poll}, }; use futures::task::AtomicWaker; @@ -35,12 +35,17 @@ impl Clone for Finished { } pub struct FinishedInner { - data: UnsafeCell>>>, + data: BTreeMap>>, } +const SLOT_EMPTY: u8 = 0; +const SLOT_WRITING: u8 = 1; +const SLOT_READY: u8 = 2; +const SLOT_READING: u8 = 3; + pub struct FinishedData { taken: AtomicBool, - finished: AtomicBool, + state: AtomicU8, waker: AtomicWaker, data: UnsafeCell>, } @@ -48,7 +53,7 @@ pub struct FinishedData { impl FinishedData { fn new() -> Self { Self { - finished: AtomicBool::new(false), + state: AtomicU8::new(SLOT_EMPTY), taken: AtomicBool::new(false), waker: AtomicWaker::new(), data: UnsafeCell::new(None), @@ -56,19 +61,16 @@ impl FinishedData { } } -unsafe impl Send for FinishedInner {} -unsafe impl Sync for FinishedInner {} -unsafe impl Send for FinishedData {} -unsafe impl Sync for FinishedData {} +unsafe impl Send for FinishedData {} +unsafe impl Sync for FinishedData {} + +unsafe impl Send for FinishedInner {} +unsafe impl Sync for FinishedInner {} impl FinishedInner { fn clear_finished(&self, addr: BusAddr) { - if let Some(data) = unsafe { &mut *self.data.get() }.get(&addr) { - data.finished.store(false, Ordering::Release); - data.taken.store(false, Ordering::Release); - unsafe { - (*data.data.get()).take(); - } + if let Some(data) = self.data.get(&addr) { + data.clear(); } } } @@ -81,9 +83,7 @@ impl Finished { data.insert(addr, Arc::new(FinishedData::new())); } Self { - inner: Arc::new(FinishedInner { - data: UnsafeCell::new(data), - }), + inner: Arc::new(FinishedInner { data }), } } @@ -92,13 +92,8 @@ impl Finished { } pub fn set_finished(&self, addr: BusAddr, value: C) { - let data = unsafe { &mut *self.inner.data.get() }; - if let Some(slot) = data.get_mut(&addr) { - unsafe { - *slot.data.get() = Some(value); - } - slot.finished.store(true, Ordering::Release); - slot.waker.wake(); + if let Some(slot) = self.inner.data.get(&addr) { + slot.set_finished(value); } else if take_queue_log_budget() { warn!( "usb queue: completion address {:#x} is not registered", @@ -112,8 +107,7 @@ impl Finished { } fn waiter(&self, addr: BusAddr) -> &FinishedData { - let data = unsafe { &mut *self.inner.data.get() }; - let slot = data.get(&addr).unwrap(); + let slot = self.inner.data.get(&addr).unwrap(); if slot.taken.load(Ordering::Acquire) { panic!("waiter called after take_waiter"); } @@ -126,7 +120,7 @@ impl Finished { } pub fn take_waiter(&self, addr: BusAddr) -> TWaiter { - let data = unsafe { &mut *self.inner.data.get() }.get(&addr).unwrap(); + let data = self.inner.data.get(&addr).unwrap(); if data.taken.swap(true, Ordering::AcqRel) { panic!("take_waiter called multiple times for the same addr"); } @@ -161,10 +155,76 @@ impl FinishedData { self.waker.register(waker); } + fn clear(&self) { + loop { + match self.state.load(Ordering::Acquire) { + SLOT_EMPTY => return, + SLOT_READY => { + if self + .state + .compare_exchange( + SLOT_READY, + SLOT_READING, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + unsafe { + (*self.data.get()).take(); + } + self.state.store(SLOT_EMPTY, Ordering::Release); + return; + } + } + SLOT_WRITING | SLOT_READING => spin_loop(), + _ => { + self.state.store(SLOT_EMPTY, Ordering::Release); + return; + } + } + } + } + + pub fn set_finished(&self, value: C) { + if self + .state + .compare_exchange( + SLOT_EMPTY, + SLOT_WRITING, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_err() + { + if take_queue_log_budget() { + warn!("usb queue: dropping duplicate completion for busy slot"); + } + return; + } + + unsafe { + *self.data.get() = Some(value); + } + self.state.store(SLOT_READY, Ordering::Release); + self.waker.wake(); + } + pub fn get_finished(&self) -> Option { - if !self.finished.load(Ordering::Acquire) { + if self + .state + .compare_exchange( + SLOT_READY, + SLOT_READING, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_err() + { return None; } - unsafe { (*self.data.get()).take() } + let value = unsafe { (*self.data.get()).take() }; + self.state.store(SLOT_EMPTY, Ordering::Release); + value } } diff --git a/usb-host/src/backend/kmod/xhci/endpoint.rs b/usb-host/src/backend/kmod/xhci/endpoint.rs index 9fda014..5c11453 100644 --- a/usb-host/src/backend/kmod/xhci/endpoint.rs +++ b/usb-host/src/backend/kmod/xhci/endpoint.rs @@ -85,6 +85,7 @@ enum SubmittedTdKind { #[derive(Clone, Copy)] struct IsoPacketTd { trb: TransferId, + final_packet: bool, event: Option, actual: Option, } @@ -365,21 +366,21 @@ impl Endpoint { &mut self, id: RequestId, request_id: EndpointRequestId, - packets: &[IsoPacketTd], ) -> Option> { - for (index, packet) in packets.iter().copied().enumerate() { + let packet_count = self.iso_packet_count(request_id)?; + for index in 0..packet_count { if self.iso_packet_done(request_id, index) { continue; } + let (packet_trb, requested) = self.iso_packet_info(request_id, index)?; - let Some(event) = self.ring.get_finished(packet.trb.0) else { + let Some(event) = self.ring.get_finished(packet_trb.0) else { continue; }; - let requested = self.iso_requested_length(request_id, index)?; let actual = match iso_packet_actual_length(requested, event) { Ok(actual) => actual, Err(err) => { - let cleanup_result = self.complete_request(request_id, packet.trb, event); + let cleanup_result = self.complete_request(request_id, packet_trb, event); let result = match cleanup_result { Ok(_) => Err(err), Err(cleanup_err) => Err(cleanup_err), @@ -389,10 +390,11 @@ impl Endpoint { }; let fatal = iso_packet_is_fatal(event); - let all_completed = self.record_iso_packet(request_id, index, event, actual)?; - if fatal || all_completed { + let should_complete = + self.record_iso_packet(request_id, index, event, actual, fatal)?; + if should_complete { return Some( - self.complete_request(request_id, packet.trb, event) + self.complete_request(request_id, packet_trb, event) .map(|transfer| transfer_to_completion(id, transfer)), ); } @@ -400,6 +402,15 @@ impl Endpoint { None } + fn iso_packet_count(&self, request_id: EndpointRequestId) -> Option { + self.inflight + .get(&request_id) + .and_then(|submitted| match &submitted.kind { + SubmittedTdKind::Iso { packets } => Some(packets.len()), + _ => None, + }) + } + fn iso_packet_done(&self, request_id: EndpointRequestId, index: usize) -> bool { self.inflight .get(&request_id) @@ -412,13 +423,23 @@ impl Endpoint { .unwrap_or(true) } - fn iso_requested_length(&self, request_id: EndpointRequestId, index: usize) -> Option { - self.inflight - .get(&request_id) - .and_then(|submitted| match &submitted.transfer.kind { - TransferKind::Isochronous { packet_lengths } => packet_lengths.get(index).copied(), - _ => None, - }) + fn iso_packet_info( + &self, + request_id: EndpointRequestId, + index: usize, + ) -> Option<(TransferId, usize)> { + let submitted = self.inflight.get(&request_id)?; + let SubmittedTdKind::Iso { packets } = &submitted.kind else { + return None; + }; + let packet = packets.get(index)?; + let requested = match &submitted.transfer.kind { + TransferKind::Isochronous { packet_lengths } => { + packet_lengths.get(index).copied().unwrap_or(0) + } + _ => return None, + }; + Some((packet.trb, requested)) } fn record_iso_packet( @@ -427,21 +448,25 @@ impl Endpoint { index: usize, event: TransferEvent, actual: usize, + fatal: bool, ) -> Option { let submitted = self.inflight.get_mut(&request_id)?; let SubmittedTdKind::Iso { packets } = &mut submitted.kind else { return None; }; - for packet in packets.iter_mut().take(index) { - if packet.actual.is_none() { - packet.actual = Some(0); + let final_packet = packets.get(index).is_some_and(|packet| packet.final_packet); + if final_packet || fatal { + for packet in packets.iter_mut().take(index) { + if packet.actual.is_none() { + packet.actual = Some(0); + } } } if let Some(packet) = packets.get_mut(index) { packet.event = Some(event); packet.actual = Some(actual); } - Some(packets.iter().all(|packet| packet.actual.is_some())) + Some(final_packet || fatal || packets.iter().all(|packet| packet.actual.is_some())) } fn enque_trb(&mut self, trb: transfer::Allowed) -> TransferId { @@ -473,6 +498,7 @@ impl Endpoint { ); packets.push(IsoPacketTd { trb, + final_packet: last_packet, event: None, actual: None, }); @@ -746,7 +772,7 @@ impl EndpointOp for Endpoint { SubmittedTdKind::Control(control_td) => { self.reclaim_control_request(id, request_id, control_td) } - SubmittedTdKind::Iso { packets } => self.reclaim_iso_request(id, request_id, &packets), + SubmittedTdKind::Iso { .. } => self.reclaim_iso_request(id, request_id), } } diff --git a/usb-host/src/backend/kmod/xhci/event.rs b/usb-host/src/backend/kmod/xhci/event.rs index 3c46dcd..27eeae2 100644 --- a/usb-host/src/backend/kmod/xhci/event.rs +++ b/usb-host/src/backend/kmod/xhci/event.rs @@ -25,7 +25,7 @@ pub struct EventRing { unsafe impl Send for EventRing {} unsafe impl Sync for EventRing {} -const EVENT_RING_SEGMENTS: usize = 16; +const EVENT_RING_SEGMENTS: usize = 2; impl EventRing { pub fn new(max_segments: usize, dma: &Kernel) -> Result {