diff --git a/tests/test_driver.py b/tests/test_driver.py index b9cf20a..db8022c 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -2,6 +2,7 @@ import threading import time +from importlib import metadata from unittest.mock import MagicMock, patch import pytest @@ -17,9 +18,24 @@ from wherobots.db.region import Region from wherobots.db.runtime import Runtime +# The version DBAPI stamps into its X-Wherobots-Client hop, resolved the same +# way as gen_user_agent_header(). +DBAPI_VERSION = metadata.version("wherobots-python-dbapi") + def _run_connect(mock_post, mock_get, **connect_kwargs): """Drive a successful connect() and return the kwargs passed to requests.post.""" + kwargs, _ = _run_connect_full(mock_post, mock_get, **connect_kwargs) + return kwargs + + +def _run_connect_full(mock_post, mock_get, **connect_kwargs): + """Drive a successful connect(). + + Returns a tuple of (kwargs passed to requests.post, kwargs passed to the + patched connect_direct) so tests can assert on both the HTTP POST and the + WebSocket upgrade path. + """ post_resp = MagicMock() post_resp.status_code = 200 post_resp.url = "https://api.example.com/sql/session/test-id" @@ -38,9 +54,10 @@ def _run_connect(mock_post, mock_get, **connect_kwargs): with patch("wherobots.db.driver.connect_direct") as mock_cd: mock_cd.return_value = MagicMock() connect(api_key="test-key", **connect_kwargs) + _, cd_kwargs = mock_cd.call_args - _, kwargs = mock_post.call_args - return kwargs + _, post_kwargs = mock_post.call_args + return post_kwargs, cd_kwargs class TestConnectRegionRuntime: @@ -204,3 +221,154 @@ def test_cancel_before_ws_connect(self, mock_ws): ) mock_ws.assert_not_called() + + +class TestWherobotsClientHeader: + """connect() emits/appends the shared X-Wherobots-Client hop. + + The header is an ordered, append-only, comma-separated list of hops; each + component appends its own `client=;ver=` hop on the right. DBAPI + appends `client=dbapi;ver=`. + """ + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_appends_dbapi_hop_to_inbound_chain(self, mock_post, mock_get): + """An inbound X-Wherobots-Client chain gets the dbapi hop appended on the right.""" + inbound = "client=claude_web, client=mcp;ver=0.9" + post_kwargs = _run_connect( + mock_post, + mock_get, + extra_headers={"X-Wherobots-Client": inbound}, + ) + assert post_kwargs["headers"]["X-Wherobots-Client"] == ( + f"{inbound}, client=dbapi;ver={DBAPI_VERSION}" + ) + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_sets_dbapi_hop_when_no_inbound_chain(self, mock_post, mock_get): + """With no inbound chain the header is exactly the dbapi hop.""" + post_kwargs = _run_connect(mock_post, mock_get) + assert ( + post_kwargs["headers"]["X-Wherobots-Client"] + == f"client=dbapi;ver={DBAPI_VERSION}" + ) + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_appends_case_insensitively(self, mock_post, mock_get): + """A differently-cased inbound header is matched and collapsed to one canonical header.""" + inbound = "client=claude_web" + post_kwargs = _run_connect( + mock_post, + mock_get, + extra_headers={"x-wherobots-client": inbound}, + ) + headers = post_kwargs["headers"] + # No differently-cased duplicate remains (would send two HTTP headers). + client_keys = [k for k in headers if k.lower() == "x-wherobots-client"] + assert client_keys == ["X-Wherobots-Client"] + assert headers["X-Wherobots-Client"] == ( + f"{inbound}, client=dbapi;ver={DBAPI_VERSION}" + ) + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_forwards_arbitrary_extra_headers(self, mock_post, mock_get): + """Arbitrary extra_headers keys are merged into the request headers.""" + post_kwargs = _run_connect( + mock_post, + mock_get, + extra_headers={"X-Trace-Id": "abc-123", "X-Custom": "yes"}, + ) + assert post_kwargs["headers"]["X-Trace-Id"] == "abc-123" + assert post_kwargs["headers"]["X-Custom"] == "yes" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_extra_headers_do_not_override_auth(self, mock_post, mock_get): + """The advisory header must never clobber the auth header.""" + post_kwargs = _run_connect( + mock_post, + mock_get, + extra_headers={"X-Wherobots-Client": "client=claude_web"}, + ) + # api_key auth is set by _run_connect; extra_headers must not remove it. + assert post_kwargs["headers"]["X-API-Key"] == "test-key" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_extra_headers_cannot_spoof_api_key(self, mock_post, mock_get): + """A hostile X-API-Key in extra_headers cannot override the real api_key. + + Auth headers are applied *after* extra_headers is merged, so the real + credential always wins. This is the security invariant: the advisory + header path can never be used to inject or swap credentials. + """ + # _run_connect passes api_key="test-key"; try to spoof it via extra_headers. + post_kwargs = _run_connect( + mock_post, + mock_get, + extra_headers={"X-API-Key": "evil-key"}, + ) + assert post_kwargs["headers"]["X-API-Key"] == "test-key" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_extra_headers_cannot_spoof_bearer_token(self, mock_post, mock_get): + """A hostile Authorization in extra_headers cannot override the real token.""" + post_resp = MagicMock() + post_resp.status_code = 200 + post_resp.url = "https://api.example.com/sql/session/test-id" + post_resp.raise_for_status = MagicMock() + mock_post.return_value = post_resp + + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.raise_for_status = MagicMock() + get_resp.json.return_value = { + "status": "READY", + "appMeta": {"url": "https://compute.example.com/sql/org/session-id"}, + } + mock_get.return_value = get_resp + + with patch("wherobots.db.driver.connect_direct") as mock_cd: + mock_cd.return_value = MagicMock() + connect( + token="real-token", + extra_headers={"Authorization": "Bearer evil-token"}, + ) + + _, post_kwargs = mock_post.call_args + assert post_kwargs["headers"]["Authorization"] == "Bearer real-token" + + @patch("wherobots.db.driver.requests.get") + @patch("wherobots.db.driver.requests.post") + def test_ws_upgrade_carries_appended_header(self, mock_post, mock_get): + """The appended X-Wherobots-Client also flows to the WebSocket upgrade path.""" + inbound = "client=claude_web, client=mcp;ver=0.9" + _post_kwargs, cd_kwargs = _run_connect_full( + mock_post, + mock_get, + extra_headers={"X-Wherobots-Client": inbound}, + ) + assert cd_kwargs["headers"]["X-Wherobots-Client"] == ( + f"{inbound}, client=dbapi;ver={DBAPI_VERSION}" + ) + + @patch("wherobots.db.driver.websockets.sync.client.connect") + def test_connect_direct_forwards_headers_to_ws(self, mock_ws): + """connect_direct forwards headers verbatim to the websocket upgrade.""" + headers = {"X-Wherobots-Client": f"client=dbapi;ver={DBAPI_VERSION}"} + mock_ws.return_value = MagicMock() + with patch("wherobots.db.driver.Connection") as mock_conn: + mock_conn.return_value = MagicMock() + connect_direct( + uri="wss://compute.example.com/sql/org/session-id", + headers=headers, + ) + _, ws_kwargs = mock_ws.call_args + assert ws_kwargs["additional_headers"]["X-Wherobots-Client"] == ( + f"client=dbapi;ver={DBAPI_VERSION}" + ) diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index ae2d0d1..2b9f40c 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -54,18 +54,59 @@ DEFAULT_HTTP_TIMEOUT = 30 -def gen_user_agent_header(): +def _resolve_dbapi_version() -> str: + """Resolve this package's version, or "unknown" if it can't be found.""" try: - package_version = metadata.version("wherobots-python-dbapi") + return metadata.version("wherobots-python-dbapi") except PackageNotFoundError: - package_version = "unknown" + return "unknown" + + +# Resolved once at import: `importlib.metadata.version` scans the installed +# package database on each call, and the version can't change within a process. +_DBAPI_VERSION: Final[str] = _resolve_dbapi_version() + + +def gen_user_agent_header(): python_version = platform.python_version() system = platform.system().lower() return { - "User-Agent": f"wherobots-python-dbapi/{package_version} os/{system} python/{python_version}" + "User-Agent": f"wherobots-python-dbapi/{_DBAPI_VERSION} os/{system} python/{python_version}" } +# Canonical name of the shared, cross-service client-chain header. +WHEROBOTS_CLIENT_HEADER: Final[str] = "X-Wherobots-Client" + + +def _append_wherobots_client_hop(headers: Dict[str, str]) -> None: + """Append this driver's hop to the shared ``X-Wherobots-Client`` header. + + ``X-Wherobots-Client`` is an ordered, append-only, comma-separated list of + hops; the leftmost is the origin and each component appends its own hop on + the right. DBAPI's hop is ``client=dbapi;ver=``. + + The lookup is case-insensitive because HTTP header names are, and any + differently-cased inbound key is collapsed into the canonical + ``X-Wherobots-Client`` key so the request carries exactly one such header. + + This header is advisory: it is informational only and must never affect + authentication. ``headers`` is mutated in place. + """ + dbapi_hop = f"client=dbapi;ver={_DBAPI_VERSION}" + + # Find any existing hop chain case-insensitively and remove differently + # cased duplicates so we don't emit two headers. + existing_chain: Union[str, None] = None + for key in [k for k in headers if k.lower() == WHEROBOTS_CLIENT_HEADER.lower()]: + existing_chain = headers.pop(key) + + if existing_chain: + headers[WHEROBOTS_CLIENT_HEADER] = f"{existing_chain}, {dbapi_hop}" + else: + headers[WHEROBOTS_CLIENT_HEADER] = dbapi_hop + + def connect( host: str = DEFAULT_ENDPOINT, token: Union[str, None] = None, @@ -82,6 +123,7 @@ def connect( data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, cancel_event: Union[threading.Event, None] = None, + extra_headers: Union[Dict[str, str], None] = None, ) -> Connection: """Create a connection to a Wherobots SQL session. @@ -96,6 +138,13 @@ def connect( your organization — only set this if you intend to use a specific region instead of the one your administrator has configured. When omitted, your organization's default region is used. + :param extra_headers: Optional extra HTTP headers to send on the session + requests and the WebSocket upgrade. Merged after the driver's own + headers, so callers can pass through tracing/correlation headers. If it + includes an ``X-Wherobots-Client`` hop chain, this driver appends its + own ``client=dbapi;ver=`` hop to the right of it. This header + is advisory only and never affects authentication; ``extra_headers`` + cannot be used to override the ``Authorization``/``X-API-Key`` headers. """ if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -103,11 +152,19 @@ def connect( raise ValueError("`token` and `api_key` can't be both provided") headers = gen_user_agent_header() + # Merge caller-supplied headers first so the driver's own auth headers, + # applied below, always win and can never be overridden. + if extra_headers: + headers.update(extra_headers) if token: headers["Authorization"] = f"Bearer {token}" elif api_key: headers["X-API-Key"] = api_key + # Append this driver's hop to the shared client-chain header. Done after the + # auth headers are set so it operates on the final, merged header set. + _append_wherobots_client_hop(headers) + host = host or DEFAULT_ENDPOINT session_type = session_type or DEFAULT_SESSION_TYPE