Skip to content
Merged
124 changes: 92 additions & 32 deletions usb-host/src/backend/kmod/queue.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use alloc::sync::Arc;
use core::pin::Pin;
use core::task::Context;
use core::task::Poll;
use core::{
cell::UnsafeCell,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
hint::spin_loop,
pin::Pin,
sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
task::{Context, Poll},
};
use futures::task::AtomicWaker;

Expand Down Expand Up @@ -35,40 +35,42 @@ impl<C> Clone for Finished<C> {
}

pub struct FinishedInner<C> {
data: UnsafeCell<BTreeMap<BusAddr, Arc<FinishedData<C>>>>,
data: BTreeMap<BusAddr, Arc<FinishedData<C>>>,
}

const SLOT_EMPTY: u8 = 0;
const SLOT_WRITING: u8 = 1;
const SLOT_READY: u8 = 2;
const SLOT_READING: u8 = 3;

pub struct FinishedData<C> {
taken: AtomicBool,
finished: AtomicBool,
state: AtomicU8,
waker: AtomicWaker,
data: UnsafeCell<Option<C>>,
}

impl<C> FinishedData<C> {
fn new() -> Self {
Self {
finished: AtomicBool::new(false),
state: AtomicU8::new(SLOT_EMPTY),
taken: AtomicBool::new(false),
waker: AtomicWaker::new(),
data: UnsafeCell::new(None),
}
}
}

unsafe impl<C> Send for FinishedInner<C> {}
unsafe impl<C> Sync for FinishedInner<C> {}
unsafe impl<C> Send for FinishedData<C> {}
unsafe impl<C> Sync for FinishedData<C> {}
unsafe impl<C: Send> Send for FinishedData<C> {}
unsafe impl<C: Send> Sync for FinishedData<C> {}

unsafe impl<C: Send> Send for FinishedInner<C> {}
unsafe impl<C: Send> Sync for FinishedInner<C> {}

impl<C> FinishedInner<C> {
fn clear_finished(&self, addr: BusAddr) {
if let Some(data) = unsafe { &mut *self.data.get() }.get(&addr) {
data.finished.store(false, Ordering::Release);
data.taken.store(false, Ordering::Release);
unsafe {
(*data.data.get()).take();
}
if let Some(data) = self.data.get(&addr) {
data.clear();
}
}
}
Expand All @@ -81,9 +83,7 @@ impl<C> Finished<C> {
data.insert(addr, Arc::new(FinishedData::new()));
}
Self {
inner: Arc::new(FinishedInner {
data: UnsafeCell::new(data),
}),
inner: Arc::new(FinishedInner { data }),
}
}

Expand All @@ -92,13 +92,8 @@ impl<C> Finished<C> {
}

pub fn set_finished(&self, addr: BusAddr, value: C) {
let data = unsafe { &mut *self.inner.data.get() };
if let Some(slot) = data.get_mut(&addr) {
unsafe {
*slot.data.get() = Some(value);
}
slot.finished.store(true, Ordering::Release);
slot.waker.wake();
if let Some(slot) = self.inner.data.get(&addr) {
slot.set_finished(value);
} else if take_queue_log_budget() {
warn!(
"usb queue: completion address {:#x} is not registered",
Expand All @@ -112,8 +107,7 @@ impl<C> Finished<C> {
}

fn waiter(&self, addr: BusAddr) -> &FinishedData<C> {
let data = unsafe { &mut *self.inner.data.get() };
let slot = data.get(&addr).unwrap();
let slot = self.inner.data.get(&addr).unwrap();
if slot.taken.load(Ordering::Acquire) {
panic!("waiter called after take_waiter");
}
Expand All @@ -126,7 +120,7 @@ impl<C> Finished<C> {
}

pub fn take_waiter(&self, addr: BusAddr) -> TWaiter<C> {
let data = unsafe { &mut *self.inner.data.get() }.get(&addr).unwrap();
let data = self.inner.data.get(&addr).unwrap();
if data.taken.swap(true, Ordering::AcqRel) {
panic!("take_waiter called multiple times for the same addr");
}
Expand Down Expand Up @@ -161,10 +155,76 @@ impl<C> FinishedData<C> {
self.waker.register(waker);
}

fn clear(&self) {
loop {
match self.state.load(Ordering::Acquire) {
SLOT_EMPTY => return,
SLOT_READY => {
if self
.state
.compare_exchange(
SLOT_READY,
SLOT_READING,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
unsafe {
(*self.data.get()).take();
}
self.state.store(SLOT_EMPTY, Ordering::Release);
return;
}
}
SLOT_WRITING | SLOT_READING => spin_loop(),
_ => {
self.state.store(SLOT_EMPTY, Ordering::Release);
return;
}
}
}
}

pub fn set_finished(&self, value: C) {
if self
.state
.compare_exchange(
SLOT_EMPTY,
SLOT_WRITING,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_err()
{
if take_queue_log_budget() {
warn!("usb queue: dropping duplicate completion for busy slot");
}
return;
}

unsafe {
*self.data.get() = Some(value);
}
self.state.store(SLOT_READY, Ordering::Release);
self.waker.wake();
}

pub fn get_finished(&self) -> Option<C> {
if !self.finished.load(Ordering::Acquire) {
if self
.state
.compare_exchange(
SLOT_READY,
SLOT_READING,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_err()
{
return None;
}
unsafe { (*self.data.get()).take() }
let value = unsafe { (*self.data.get()).take() };
self.state.store(SLOT_EMPTY, Ordering::Release);
value
}
}
66 changes: 46 additions & 20 deletions usb-host/src/backend/kmod/xhci/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ enum SubmittedTdKind {
#[derive(Clone, Copy)]
struct IsoPacketTd {
trb: TransferId,
final_packet: bool,
event: Option<TransferEvent>,
actual: Option<usize>,
}
Expand Down Expand Up @@ -365,21 +366,21 @@ impl Endpoint {
&mut self,
id: RequestId,
request_id: EndpointRequestId,
packets: &[IsoPacketTd],
) -> Option<Result<TransferCompletion, TransferError>> {
for (index, packet) in packets.iter().copied().enumerate() {
let packet_count = self.iso_packet_count(request_id)?;
for index in 0..packet_count {
if self.iso_packet_done(request_id, index) {
continue;
}
let (packet_trb, requested) = self.iso_packet_info(request_id, index)?;

let Some(event) = self.ring.get_finished(packet.trb.0) else {
let Some(event) = self.ring.get_finished(packet_trb.0) else {
continue;
};
let requested = self.iso_requested_length(request_id, index)?;
let actual = match iso_packet_actual_length(requested, event) {
Ok(actual) => actual,
Err(err) => {
let cleanup_result = self.complete_request(request_id, packet.trb, event);
let cleanup_result = self.complete_request(request_id, packet_trb, event);
let result = match cleanup_result {
Ok(_) => Err(err),
Err(cleanup_err) => Err(cleanup_err),
Expand All @@ -389,17 +390,27 @@ impl Endpoint {
};

let fatal = iso_packet_is_fatal(event);
let all_completed = self.record_iso_packet(request_id, index, event, actual)?;
if fatal || all_completed {
let should_complete =
self.record_iso_packet(request_id, index, event, actual, fatal)?;
if should_complete {
return Some(
self.complete_request(request_id, packet.trb, event)
self.complete_request(request_id, packet_trb, event)
.map(|transfer| transfer_to_completion(id, transfer)),
);
}
}
None
}

fn iso_packet_count(&self, request_id: EndpointRequestId) -> Option<usize> {
self.inflight
.get(&request_id)
.and_then(|submitted| match &submitted.kind {
SubmittedTdKind::Iso { packets } => Some(packets.len()),
_ => None,
})
}

fn iso_packet_done(&self, request_id: EndpointRequestId, index: usize) -> bool {
self.inflight
.get(&request_id)
Expand All @@ -412,13 +423,23 @@ impl Endpoint {
.unwrap_or(true)
}

fn iso_requested_length(&self, request_id: EndpointRequestId, index: usize) -> Option<usize> {
self.inflight
.get(&request_id)
.and_then(|submitted| match &submitted.transfer.kind {
TransferKind::Isochronous { packet_lengths } => packet_lengths.get(index).copied(),
_ => None,
})
fn iso_packet_info(
&self,
request_id: EndpointRequestId,
index: usize,
) -> Option<(TransferId, usize)> {
let submitted = self.inflight.get(&request_id)?;
let SubmittedTdKind::Iso { packets } = &submitted.kind else {
return None;
};
let packet = packets.get(index)?;
let requested = match &submitted.transfer.kind {
TransferKind::Isochronous { packet_lengths } => {
packet_lengths.get(index).copied().unwrap_or(0)
}
_ => return None,
};
Some((packet.trb, requested))
}

fn record_iso_packet(
Expand All @@ -427,21 +448,25 @@ impl Endpoint {
index: usize,
event: TransferEvent,
actual: usize,
fatal: bool,
) -> Option<bool> {
let submitted = self.inflight.get_mut(&request_id)?;
let SubmittedTdKind::Iso { packets } = &mut submitted.kind else {
return None;
};
for packet in packets.iter_mut().take(index) {
if packet.actual.is_none() {
packet.actual = Some(0);
let final_packet = packets.get(index).is_some_and(|packet| packet.final_packet);
if final_packet || fatal {
for packet in packets.iter_mut().take(index) {
if packet.actual.is_none() {
packet.actual = Some(0);
}
}
}
if let Some(packet) = packets.get_mut(index) {
packet.event = Some(event);
packet.actual = Some(actual);
}
Some(packets.iter().all(|packet| packet.actual.is_some()))
Some(final_packet || fatal || packets.iter().all(|packet| packet.actual.is_some()))
}

fn enque_trb(&mut self, trb: transfer::Allowed) -> TransferId {
Expand Down Expand Up @@ -473,6 +498,7 @@ impl Endpoint {
);
packets.push(IsoPacketTd {
trb,
final_packet: last_packet,
event: None,
actual: None,
});
Expand Down Expand Up @@ -746,7 +772,7 @@ impl EndpointOp for Endpoint {
SubmittedTdKind::Control(control_td) => {
self.reclaim_control_request(id, request_id, control_td)
}
SubmittedTdKind::Iso { packets } => self.reclaim_iso_request(id, request_id, &packets),
SubmittedTdKind::Iso { .. } => self.reclaim_iso_request(id, request_id),
}
}

Expand Down
2 changes: 1 addition & 1 deletion usb-host/src/backend/kmod/xhci/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub struct EventRing {
unsafe impl Send for EventRing {}
unsafe impl Sync for EventRing {}

const EVENT_RING_SEGMENTS: usize = 16;
const EVENT_RING_SEGMENTS: usize = 2;

impl EventRing {
pub fn new(max_segments: usize, dma: &Kernel) -> Result<Self> {
Expand Down
Loading