diff --git a/tests/common.py b/tests/common.py index 4420b17..8fbe225 100644 --- a/tests/common.py +++ b/tests/common.py @@ -84,6 +84,7 @@ def __init__(self) -> None: self.hw_address = t.EUI64.convert("11:22:33:44:55:66:77:88") self.handlers: dict[str, Callable[[Any, int], Awaitable[Any]]] = { "ping": self.on_ping, + "reset": self.on_status, "configure": self.on_configure, "get_network_info": self.on_get_network_info, "get_hw_address": self.on_get_hw_address, @@ -159,6 +160,38 @@ async def send_event_data( {"type": "event", "id": request_id, "event": event, "data": data} ) + async def send_confirm( + self, request_id: int, *, next_hop: str | None = None, reason: str | None = None + ) -> None: + if reason is not None: + data: dict[str, Any] = { + "id": request_id, + "status": "failed", + "reason": reason, + } + else: + data = {"id": request_id, "status": "confirmed", "next_hop": next_hop} + + await self.ws.send_json( + {"type": "notification", "event": "send_confirm", "data": data} + ) + + async def aps_ack_confirm( + self, request_id: int, *, reason: str | None = None + ) -> None: + if reason is not None: + data: dict[str, Any] = { + "id": request_id, + "status": "failed", + "reason": reason, + } + else: + data = {"id": request_id, "status": "confirmed"} + + await self.ws.send_json( + {"type": "notification", "event": "aps_ack_confirm", "data": data} + ) + async def send_notification(self, notification: commands.Notification) -> None: await self.ws.send_json( { @@ -208,8 +241,10 @@ async def on_get_hw_address( async def on_send_aps( self, command: commands.SendAps, request_id: int ) -> commands.Status: - await self.send_event(request_id, "transmitted") - return commands.Status(status="delivered" if command.aps_ack else "sent") + await self.send_confirm(request_id) + if command.aps_ack: + await self.aps_ack_confirm(request_id) + return commands.Status(status="accepted") async def on_energy_scan( self, command: commands.EnergyScan, request_id: int diff --git a/tests/test_api.py b/tests/test_api.py index 58be95d..ed8f85d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,6 +2,7 @@ import asyncio from collections.abc import AsyncIterator +from dataclasses import replace import pytest from zigpy.exceptions import DeliveryError @@ -62,43 +63,49 @@ async def fail(command: commands.Ping, request_id: int) -> commands.Status: await api.request(commands.Ping()) -async def test_request_transmitted( +async def test_request_confirmed(api: RecordingApi, server: SyntheticZiggurat) -> None: + """An APS-ack send resolves once the end-to-end APS ack arrives.""" + await api.request_confirmed(SEND_APS) + assert server.sent(commands.SendAps)[-1].aps_seq == 55 + + +async def test_request_confirmed_next_hop( api: RecordingApi, server: SyntheticZiggurat ) -> None: - await api.request_transmitted(SEND_APS) - assert server.sent(commands.SendAps)[-1].aps_seq == 55 + """A no-ack unicast resolves on the local handoff.""" + await api.request_confirmed(replace(SEND_APS, aps_ack=False)) -async def test_request_transmitted_failure_before_transmission( +async def test_request_confirmed_rejected( api: RecordingApi, server: SyntheticZiggurat ) -> None: + """Stage two: the stack rejects the frame, so the send raises before any confirm.""" + async def fail(command: commands.SendAps, request_id: int) -> commands.Status: raise RpcError("transmit_failed", "channel busy") server.handlers["send_aps"] = fail with pytest.raises(DeliveryError, match="transmit_failed"): - await api.request_transmitted(SEND_APS) + await api.request_confirmed(SEND_APS) -async def test_late_delivery_failure_is_logged( - api: RecordingApi, server: SyntheticZiggurat, caplog: pytest.LogCaptureFixture +async def test_request_confirmed_failure( + api: RecordingApi, server: SyntheticZiggurat ) -> None: + """The frame is handed off but the end-to-end APS ack never arrives.""" + async def ack_timeout( command: commands.SendAps, request_id: int ) -> commands.Status: - await server.send_event(request_id, "transmitted") - raise RpcError("aps_ack_timeout", "no ack") + await server.send_confirm(request_id) + await server.aps_ack_confirm(request_id, reason="APS ack timed out") + return commands.Status(status="accepted") server.handlers["send_aps"] = ack_timeout - # Resolves at the `transmitted` stage; the terminal failure arrives later and is - # logged instead of raised - await api.request_transmitted(SEND_APS) - - async with asyncio.timeout(1): - while "Delivery failed after transmission" not in caplog.text: - await asyncio.sleep(0.01) + with pytest.raises(DeliveryError, match="APS ack timed out"): + await api.request_confirmed(SEND_APS) async def test_unsolicited_messages_are_ignored( @@ -106,11 +113,11 @@ async def test_unsolicited_messages_are_ignored( ) -> None: await server.send_raw("not json") await server.send_raw('{"type": "response", "id": 9999, "result": {}}') - await server.send_raw('{"type": "event", "id": 9999, "event": "transmitted"}') + await server.send_raw('{"type": "event", "id": 9999, "event": "spurious"}') - # A `transmitted` event for a request that did not ask for one + # An unknown event for an in-flight request is ignored (only stream results match) async def eager(command: commands.Ping, request_id: int) -> commands.Status: - await server.send_event(request_id, "transmitted") + await server.send_event(request_id, "spurious") return commands.Status(status="pong") server.handlers["ping"] = eager diff --git a/zigpy_ziggurat/zigbee/application.py b/zigpy_ziggurat/zigbee/application.py index e48fa22..4111a18 100644 --- a/zigpy_ziggurat/zigbee/application.py +++ b/zigpy_ziggurat/zigbee/application.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio from collections.abc import AsyncGenerator, Callable +from datetime import datetime, timezone import json import logging import math @@ -13,6 +16,7 @@ import zigpy.device import zigpy.endpoint from zigpy.exceptions import DeliveryError, NetworkNotFormed +import zigpy.serial import zigpy.state import zigpy.types as t import zigpy.zdo.types as zdo_t @@ -33,10 +37,14 @@ LinkKeyUpdate, NetworkScan, Notification, + PacketCapture, + PacketCaptureChangeChannel, PermitJoins, Ping, ReceivedApsCommand, Request, + Reset, + ResetType, SendAps, SetChannel, SetNwkUpdateId, @@ -46,6 +54,14 @@ _LOGGER = logging.getLogger(__name__) +_RUST_LOG_LEVELS = { + "ERROR": logging.ERROR, + "WARN": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "TRACE": 5, +} + RSSI_MIN = -92 RSSI_MAX = -5 @@ -83,17 +99,12 @@ def map_rssi_to_energy(rssi: float) -> float: class PendingRequest: - """The in-flight state of one request: an optional `transmitted` stage future and - the terminal `response` future.""" + """The in-flight state of one request: the terminal `response` future and, for a + streaming request, the queue its result events land in.""" - def __init__( - self, *, want_transmitted: bool, stream_event: str | None = None - ) -> None: + def __init__(self, *, stream_event: str | None = None) -> None: loop = asyncio.get_running_loop() self.response: asyncio.Future[dict[str, Any]] = loop.create_future() - self.transmitted: asyncio.Future[None] | None = ( - loop.create_future() if want_transmitted else None - ) # For a streaming request: the event name carrying results, and a queue those # results land in. `None` enqueued by the terminal response marks the end. self.stream_event = stream_event @@ -102,61 +113,76 @@ def __init__( ) def fail(self, exc: BaseException) -> None: - if self.transmitted is not None and not self.transmitted.done(): - self.transmitted.set_exception(exc) - if not self.response.done(): self.response.set_exception(exc) -def _make_late_failure_logger( - pending: PendingRequest, -) -> Callable[[asyncio.Future[dict[str, Any]]], None]: - """Consume the terminal result of a request that already resolved at the - `transmitted` stage, so delivery failures are visible but not raised. Failures - from before transmission were already raised to the caller and are not logged.""" +class ZigguratSerialProtocol(zigpy.serial.SerialProtocol): + """The serial half of the line-delimited JSON transport.""" - def log_late_failure(fut: asyncio.Future[dict[str, Any]]) -> None: - if fut.cancelled(): - return + def __init__(self, api: ZigguratApi) -> None: + super().__init__() + self._api = api - exc = fut.exception() - if exc is None: - return + def connection_lost(self, exc: BaseException | None) -> None: + super().connection_lost(exc) + self._api.on_transport_lost(exc) - transmitted = ( - pending.transmitted is not None - and pending.transmitted.done() - and pending.transmitted.exception() is None - ) + def data_received(self, data: bytes) -> None: + super().data_received(data) - if transmitted: - _LOGGER.warning("Delivery failed after transmission: %s", exc) + while (newline := self._buffer.find(b"\n")) >= 0: + line = bytes(self._buffer[:newline]).strip() + del self._buffer[: newline + 1] - return log_late_failure + if not line: + continue + + try: + self._api.handle_message(json.loads(line)) + except Exception: + _LOGGER.exception("Failed to handle message: %r", line) + + def send_line(self, text: str) -> None: + assert self._transport is not None + self._transport.write((text + "\n").encode()) class ZigguratApi: - """The Ziggurat WebSocket API: concurrent requests correlated by id, with - lifecycle events (`accepted`, `transmitted`) preceding each terminal response.""" + """The Ziggurat JSON-RPC API.""" def __init__( self, url: str, on_notification: Callable[[Notification], None], on_disconnect: Callable[[BaseException | None], None], + *, + baudrate: int = 115200, + flow_control: str | None = None, ) -> None: self._url = url self._on_notification = on_notification self._on_disconnect = on_disconnect + self._baudrate = baudrate + self._flow_control = flow_control self._session: aiohttp.ClientSession | None = None self._websocket: aiohttp.ClientWebSocketResponse | None = None + self._serial: ZigguratSerialProtocol | None = None self._receiver_task: asyncio.Task[None] | None = None + self._closing = False self._request_id = 1 self._pending: dict[int, PendingRequest] = {} + self._pending_confirms: dict[int, asyncio.Future[dict[str, Any]]] = {} + self._awaiting_aps_ack: set[int] = set() async def connect(self) -> None: + if self._url.startswith(("ws://", "wss://", "ws+unix://")): + await self._connect_websocket() + else: + await self._connect_serial() + + async def _connect_websocket(self) -> None: if self._url.startswith("ws+unix://"): # The URL's path is the socket path; the HTTP-level host is a placeholder connector = aiohttp.UnixConnector(path=self._url.removeprefix("ws+unix://")) @@ -175,7 +201,22 @@ async def connect(self) -> None: self._receiver_task = asyncio.create_task(self._receive_loop()) + async def _connect_serial(self) -> None: + _LOGGER.debug("Connecting to ziggurat over serial: %s", self._url) + + _, protocol = await zigpy.serial.create_serial_connection( + loop=asyncio.get_running_loop(), + protocol_factory=lambda: ZigguratSerialProtocol(self), + url=self._url, + baudrate=self._baudrate, + flow_control=cast(Any, self._flow_control), + ) + self._serial = cast(ZigguratSerialProtocol, protocol) + await self._serial.wait_until_connected() + async def disconnect(self) -> None: + self._closing = True + if self._receiver_task is not None: self._receiver_task.cancel() self._receiver_task = None @@ -188,6 +229,26 @@ async def disconnect(self) -> None: await self._session.close() self._session = None + if self._serial is not None: + self._serial.close() + await self._serial.wait_until_closed() + self._serial = None + + def on_transport_lost(self, exc: BaseException | None) -> None: + """The serial transport closed.""" + self._fail_pending(ConnectionError("Connection lost")) + + if not self._closing: + self._on_disconnect(exc) + + async def _send_line(self, text: str) -> None: + if self._websocket is not None: + await self._websocket.send_str(text) + elif self._serial is not None: + self._serial.send_line(text) + else: + raise ConnectionError("Not connected") + async def _receive_loop(self) -> None: websocket = self._websocket assert websocket is not None @@ -198,7 +259,7 @@ async def _receive_loop(self) -> None: async for msg in websocket: if msg.type == aiohttp.WSMsgType.TEXT: try: - self._handle_message(json.loads(msg.data)) + self.handle_message(json.loads(msg.data)) except Exception: _LOGGER.exception("Failed to handle message: %r", msg.data) elif msg.type == aiohttp.WSMsgType.ERROR: @@ -218,28 +279,36 @@ async def _receive_loop(self) -> None: def _fail_pending(self, exc: BaseException) -> None: for pending in self._pending.values(): - pending.response.add_done_callback(_make_late_failure_logger(pending)) pending.fail(exc) self._pending.clear() - def _handle_message(self, msg: dict[str, Any]) -> None: + for confirm in self._pending_confirms.values(): + if not confirm.done(): + confirm.set_exception(exc) + + self._pending_confirms.clear() + self._awaiting_aps_ack.clear() + + def handle_message(self, msg: dict[str, Any]) -> None: _LOGGER.debug("Received: %r", msg) msg_type = msg["type"] if msg_type == "notification": - self._on_notification(NOTIFICATIONS[msg["event"]].from_dict(msg["data"])) + event = msg["event"] + if event == "log": + self._handle_log(msg["data"]) + elif event == "send_confirm": + self._handle_send_confirm(msg["data"]) + elif event == "aps_ack_confirm": + self._handle_aps_ack_confirm(msg["data"]) + else: + self._on_notification(NOTIFICATIONS[event].from_dict(msg["data"])) elif msg_type == "event": pending = self._pending.get(msg["id"]) if pending is None: pass - elif ( - msg["event"] == "transmitted" - and pending.transmitted is not None - and not pending.transmitted.done() - ): - pending.transmitted.set_result(None) elif pending.events is not None and msg["event"] == pending.stream_event: pending.events.put_nowait(msg["data"]) elif msg_type == "response": @@ -255,44 +324,94 @@ def _handle_message(self, msg: dict[str, Any]) -> None: elif not pending.response.done(): pending.response.set_result(msg["result"]) + def _handle_log(self, data: dict[str, Any]) -> None: + """Tunnel a firmware `log` notification.""" + level = _RUST_LOG_LEVELS.get(data["level"], logging.INFO) + logger = logging.getLogger("ziggurat.fw." + data["target"].replace("::", ".")) + logger.log(level, "%s", data["message"]) + + def _handle_send_confirm(self, data: dict[str, Any]) -> None: + """Resolve a no-ack send, or an APS-ack send whose handoff failed.""" + request_id = data["id"] + confirm = self._pending_confirms.get(request_id) + if confirm is None or confirm.done(): + return + # A successful handoff is not final for an APS-ack send: its aps_ack_confirm is. + if data["status"] == "confirmed" and request_id in self._awaiting_aps_ack: + return + self._awaiting_aps_ack.discard(request_id) + confirm.set_result(data) + + def _handle_aps_ack_confirm(self, data: dict[str, Any]) -> None: + """Resolve an APS-ack send with its end-to-end result.""" + request_id = data["id"] + confirm = self._pending_confirms.get(request_id) + self._awaiting_aps_ack.discard(request_id) + if confirm is not None and not confirm.done(): + confirm.set_result(data) + async def request(self, command: Request[RESPONSE_T]) -> RESPONSE_T: - result = await self._send_request(command, want_transmitted=False) + result = await self._send_request(command) assert result is not None # `response_type` is a plain ClassVar: it cannot carry the type variable return cast(RESPONSE_T, command.response_type.from_dict(result)) - async def request_transmitted(self, command: Request[Any]) -> None: - """Resolve once the frame is on the air instead of waiting for delivery.""" - await self._send_request(command, want_transmitted=True) + async def request_confirmed(self, command: SendAps) -> None: + """Enqueue a send and resolve once its terminal confirmation arrives: the + end-to-end APS ack for an ack-requested unicast, otherwise the local handoff. + Raises `DeliveryError` if the stack rejects the frame or the confirmation + reports failure.""" + async with asyncio.timeout(30): + result = await self._send_request( + command, want_confirm=True, expect_aps_ack=command.aps_ack + ) + assert result is not None + + if result["status"] == "failed": + raise DeliveryError(result["reason"]) async def _send_request( - self, command: Request[Any], *, want_transmitted: bool + self, + command: Request[Any], + *, + want_confirm: bool = False, + expect_aps_ack: bool = False, ) -> dict[str, Any] | None: request_id = self._request_id self._request_id = (self._request_id + 1) % 2**32 or 1 - pending = PendingRequest(want_transmitted=want_transmitted) + pending = PendingRequest() self._pending[request_id] = pending + # Register the confirmation before sending so it cannot race the notification. + confirm: asyncio.Future[dict[str, Any]] | None = None + if want_confirm: + confirm = asyncio.get_running_loop().create_future() + self._pending_confirms[request_id] = confirm + if expect_aps_ack: + self._awaiting_aps_ack.add(request_id) + message = { "id": request_id, "method": command.method, "params": command.to_dict(), } _LOGGER.debug("Sending: %r", message) - assert self._websocket is not None - await self._websocket.send_str(json.dumps(message)) + await self._send_line(json.dumps(message)) - if want_transmitted: - # The terminal response continues in the background; an end-to-end - # delivery failure after transmission is logged, not raised - pending.response.add_done_callback(_make_late_failure_logger(pending)) - assert pending.transmitted is not None - await pending.transmitted - return None + if not want_confirm: + return await pending.response - return await pending.response + # Stage two: the stack accepts (a success response) or rejects (an error, + # raised here). Stage three: the `send_confirm` notification keyed by this id. + assert confirm is not None + try: + await pending.response + return await confirm + finally: + self._pending_confirms.pop(request_id, None) + self._awaiting_aps_ack.discard(request_id) async def request_stream( self, command: StreamingRequest[Any, EVENT_T] @@ -303,9 +422,7 @@ async def request_stream( request_id = self._request_id self._request_id = (self._request_id + 1) % 2**32 or 1 - pending = PendingRequest( - want_transmitted=False, stream_event=command.event_name - ) + pending = PendingRequest(stream_event=command.event_name) self._pending[request_id] = pending message = { @@ -314,8 +431,7 @@ async def request_stream( "params": command.to_dict(), } _LOGGER.debug("Sending: %r", message) - assert self._websocket is not None - await self._websocket.send_str(json.dumps(message)) + await self._send_line(json.dumps(message)) assert pending.events is not None events = pending.events @@ -364,18 +480,26 @@ def __init__(self, config: dict[str, Any]) -> None: self._api: ZigguratApi | None = None async def connect(self) -> None: - # The device path is the WebSocket URL of the ziggurat server - url = self._config[zigpy.config.CONF_DEVICE][zigpy.config.CONF_DEVICE_PATH] + # The device path is either the WebSocket URL of a ziggurat server or the serial + # port of a ziggurat firmware (e.g. an ESP32-C6 over USB-Serial-JTAG) + device = self._config[zigpy.config.CONF_DEVICE] + url = device[zigpy.config.CONF_DEVICE_PATH] # zigpy types `connection_lost` as Exception-only but handles None fine api = ZigguratApi( url, self.on_notification, self.connection_lost, # type: ignore[arg-type] + baudrate=device[zigpy.config.CONF_DEVICE_BAUDRATE], + flow_control=device[zigpy.config.CONF_DEVICE_FLOW_CONTROL], ) await api.connect() self._api = api + # Clear any transient radio state left by a previous client (e.g. a packet + # capture still streaming on the firmware) so this session starts from idle. + await api.request(Reset(reset_type=ResetType.SOFT)) + async def disconnect(self) -> None: if self._api is not None: try: @@ -392,6 +516,12 @@ async def start_network(self) -> None: self._register_coordinator_device() await self.register_endpoints() + url = self._config[zigpy.config.CONF_DEVICE][zigpy.config.CONF_DEVICE_PATH] + if not url.startswith(("ws://", "wss://", "ws+unix://")): + self._concurrent_requests_semaphore.max_concurrency = 64 + else: + self._concurrent_requests_semaphore.max_concurrency = 128 + def _register_coordinator_device(self) -> None: coordinator = ZigguratCoordinator( self, self.state.node_info.ieee, self.state.node_info.nwk @@ -604,6 +734,23 @@ async def _network_scan( protocol_version=beacon.protocol_version, ) + async def _packet_capture( + self, channel: int + ) -> AsyncGenerator[t.CapturedPacket, None]: + assert self._api is not None + async for packet in self._api.request_stream(PacketCapture(channel=channel)): + yield t.CapturedPacket( + timestamp=datetime.now(timezone.utc), + rssi=packet.rssi, + lqi=packet.lqi, + channel=packet.channel, + data=bytes.fromhex(packet.data), + ) + + async def _packet_capture_change_channel(self, channel: int) -> None: + assert self._api is not None + await self._api.request(PacketCaptureChangeChannel(channel=channel)) + async def write_network_info( self, *, @@ -921,23 +1068,25 @@ async def send_packet(self, packet: t.ZigbeePacket) -> None: # The server selects the link key by EUI64 destination_eui64 = self.get_device(nwk=destination).ieee - # Resolves once the frame is on the air (EZSP `messageSent` parity); the - # APS-ack delivery result arrives later and is logged by the API layer + # Resolves once the send is confirmed: passive-ack quorum for a broadcast, + # next-hop acceptance for a no-ack unicast, or the end-to-end APS ack. A + # rejected or failed send raises `DeliveryError`. assert self._api is not None - await self._api.request_transmitted( - SendAps( - delivery_mode=delivery_mode, - destination_eui64=destination_eui64, - destination=destination, - profile_id=packet.profile_id, - cluster_id=packet.cluster_id or 0x0000, - src_ep=packet.src_ep or 0, - dst_ep=packet.dst_ep or 0, - aps_ack=t.TransmitOptions.ACK in packet.tx_options, - aps_encryption=aps_encryption, - radius=packet.radius or 30, - aps_seq=packet.tsn, - priority=packet.priority if packet.priority is not None else 0, - data=packet.data.serialize(), + async with self._limit_concurrency(priority=packet.priority): + await self._api.request_confirmed( + SendAps( + delivery_mode=delivery_mode, + destination_eui64=destination_eui64, + destination=destination, + profile_id=packet.profile_id, + cluster_id=packet.cluster_id or 0x0000, + src_ep=packet.src_ep or 0, + dst_ep=packet.dst_ep or 0, + aps_ack=t.TransmitOptions.ACK in packet.tx_options, + aps_encryption=aps_encryption, + radius=packet.radius or 30, + aps_seq=packet.tsn, + priority=packet.priority if packet.priority is not None else 0, + data=packet.data.serialize(), + ) ) - ) diff --git a/zigpy_ziggurat/zigbee/commands.py b/zigpy_ziggurat/zigbee/commands.py index 7a39bd4..5601fef 100644 --- a/zigpy_ziggurat/zigbee/commands.py +++ b/zigpy_ziggurat/zigbee/commands.py @@ -301,6 +301,22 @@ class SetProvisionalKey(Request[Status]): key: t.KeyData +class ResetType(enum.StrEnum): + # Stop transient radio activity (packet capture) and return to idle, leaving any + # configured network running. Sent by the client on connect as a session reset. + SOFT = "soft" + # Reboot/reset the radio. + HARD = "hard" + + +@dataclass +class Reset(Request[Status]): + method = "reset" + response_type = Status + + reset_type: ResetType + + @dataclass class SetChannel(Request[Status]): method = "set_channel" @@ -309,6 +325,33 @@ class SetChannel(Request[Status]): channel: int +@dataclass +class CapturedPacketEvent(Response): + channel: t.uint8_t + rssi: t.int8s + lqi: t.uint8_t + # Hex-encoded 802.15.4 MAC frame (FCS stripped) + data: str + + +@dataclass +class PacketCapture(StreamingRequest[Status, CapturedPacketEvent]): + method = "packet_capture" + response_type = Status + event_type = CapturedPacketEvent + event_name = "captured_packet" + + channel: int + + +@dataclass +class PacketCaptureChangeChannel(Request[Status]): + method = "packet_capture_change_channel" + response_type = Status + + channel: int + + @dataclass class SetNwkUpdateId(Request[Status]): method = "set_nwk_update_id"