Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 170 additions & 2 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import threading
import time
from importlib import metadata
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -17,9 +18,24 @@
from wherobots.db.region import Region
from wherobots.db.runtime import Runtime

# The version DBAPI stamps into its X-Wherobots-Client hop, resolved the same
# way as gen_user_agent_header().
DBAPI_VERSION = metadata.version("wherobots-python-dbapi")


def _run_connect(mock_post, mock_get, **connect_kwargs):
"""Drive a successful connect() and return the kwargs passed to requests.post."""
kwargs, _ = _run_connect_full(mock_post, mock_get, **connect_kwargs)
return kwargs


def _run_connect_full(mock_post, mock_get, **connect_kwargs):
"""Drive a successful connect().

Returns a tuple of (kwargs passed to requests.post, kwargs passed to the
patched connect_direct) so tests can assert on both the HTTP POST and the
WebSocket upgrade path.
"""
post_resp = MagicMock()
post_resp.status_code = 200
post_resp.url = "https://api.example.com/sql/session/test-id"
Expand All @@ -38,9 +54,10 @@ def _run_connect(mock_post, mock_get, **connect_kwargs):
with patch("wherobots.db.driver.connect_direct") as mock_cd:
mock_cd.return_value = MagicMock()
connect(api_key="test-key", **connect_kwargs)
_, cd_kwargs = mock_cd.call_args

_, kwargs = mock_post.call_args
return kwargs
_, post_kwargs = mock_post.call_args
return post_kwargs, cd_kwargs

Comment thread
salty-hambot[bot] marked this conversation as resolved.

class TestConnectRegionRuntime:
Expand Down Expand Up @@ -204,3 +221,154 @@ def test_cancel_before_ws_connect(self, mock_ws):
)

mock_ws.assert_not_called()


class TestWherobotsClientHeader:
"""connect() emits/appends the shared X-Wherobots-Client hop.

The header is an ordered, append-only, comma-separated list of hops; each
component appends its own `client=<token>;ver=<v>` hop on the right. DBAPI
appends `client=dbapi;ver=<version>`.
"""

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_appends_dbapi_hop_to_inbound_chain(self, mock_post, mock_get):
"""An inbound X-Wherobots-Client chain gets the dbapi hop appended on the right."""
inbound = "client=claude_web, client=mcp;ver=0.9"
post_kwargs = _run_connect(
mock_post,
mock_get,
extra_headers={"X-Wherobots-Client": inbound},
)
assert post_kwargs["headers"]["X-Wherobots-Client"] == (
f"{inbound}, client=dbapi;ver={DBAPI_VERSION}"
)

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_sets_dbapi_hop_when_no_inbound_chain(self, mock_post, mock_get):
"""With no inbound chain the header is exactly the dbapi hop."""
post_kwargs = _run_connect(mock_post, mock_get)
assert (
post_kwargs["headers"]["X-Wherobots-Client"]
== f"client=dbapi;ver={DBAPI_VERSION}"
)

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_appends_case_insensitively(self, mock_post, mock_get):
"""A differently-cased inbound header is matched and collapsed to one canonical header."""
inbound = "client=claude_web"
post_kwargs = _run_connect(
mock_post,
mock_get,
extra_headers={"x-wherobots-client": inbound},
)
headers = post_kwargs["headers"]
# No differently-cased duplicate remains (would send two HTTP headers).
client_keys = [k for k in headers if k.lower() == "x-wherobots-client"]
assert client_keys == ["X-Wherobots-Client"]
assert headers["X-Wherobots-Client"] == (
f"{inbound}, client=dbapi;ver={DBAPI_VERSION}"
)

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_forwards_arbitrary_extra_headers(self, mock_post, mock_get):
"""Arbitrary extra_headers keys are merged into the request headers."""
post_kwargs = _run_connect(
mock_post,
mock_get,
extra_headers={"X-Trace-Id": "abc-123", "X-Custom": "yes"},
)
assert post_kwargs["headers"]["X-Trace-Id"] == "abc-123"
assert post_kwargs["headers"]["X-Custom"] == "yes"

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_extra_headers_do_not_override_auth(self, mock_post, mock_get):
"""The advisory header must never clobber the auth header."""
post_kwargs = _run_connect(
mock_post,
mock_get,
extra_headers={"X-Wherobots-Client": "client=claude_web"},
)
# api_key auth is set by _run_connect; extra_headers must not remove it.
assert post_kwargs["headers"]["X-API-Key"] == "test-key"

Comment thread
salty-hambot[bot] marked this conversation as resolved.
@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_extra_headers_cannot_spoof_api_key(self, mock_post, mock_get):
"""A hostile X-API-Key in extra_headers cannot override the real api_key.

Auth headers are applied *after* extra_headers is merged, so the real
credential always wins. This is the security invariant: the advisory
header path can never be used to inject or swap credentials.
"""
# _run_connect passes api_key="test-key"; try to spoof it via extra_headers.
post_kwargs = _run_connect(
mock_post,
mock_get,
extra_headers={"X-API-Key": "evil-key"},
)
assert post_kwargs["headers"]["X-API-Key"] == "test-key"

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_extra_headers_cannot_spoof_bearer_token(self, mock_post, mock_get):
"""A hostile Authorization in extra_headers cannot override the real token."""
post_resp = MagicMock()
post_resp.status_code = 200
post_resp.url = "https://api.example.com/sql/session/test-id"
post_resp.raise_for_status = MagicMock()
mock_post.return_value = post_resp

get_resp = MagicMock()
get_resp.status_code = 200
get_resp.raise_for_status = MagicMock()
get_resp.json.return_value = {
"status": "READY",
"appMeta": {"url": "https://compute.example.com/sql/org/session-id"},
}
mock_get.return_value = get_resp

with patch("wherobots.db.driver.connect_direct") as mock_cd:
mock_cd.return_value = MagicMock()
connect(
token="real-token",
extra_headers={"Authorization": "Bearer evil-token"},
)

_, post_kwargs = mock_post.call_args
assert post_kwargs["headers"]["Authorization"] == "Bearer real-token"

@patch("wherobots.db.driver.requests.get")
@patch("wherobots.db.driver.requests.post")
def test_ws_upgrade_carries_appended_header(self, mock_post, mock_get):
"""The appended X-Wherobots-Client also flows to the WebSocket upgrade path."""
inbound = "client=claude_web, client=mcp;ver=0.9"
_post_kwargs, cd_kwargs = _run_connect_full(
mock_post,
mock_get,
extra_headers={"X-Wherobots-Client": inbound},
)
assert cd_kwargs["headers"]["X-Wherobots-Client"] == (
f"{inbound}, client=dbapi;ver={DBAPI_VERSION}"
)

@patch("wherobots.db.driver.websockets.sync.client.connect")
def test_connect_direct_forwards_headers_to_ws(self, mock_ws):
"""connect_direct forwards headers verbatim to the websocket upgrade."""
headers = {"X-Wherobots-Client": f"client=dbapi;ver={DBAPI_VERSION}"}
mock_ws.return_value = MagicMock()
with patch("wherobots.db.driver.Connection") as mock_conn:
mock_conn.return_value = MagicMock()
connect_direct(
uri="wss://compute.example.com/sql/org/session-id",
headers=headers,
)
_, ws_kwargs = mock_ws.call_args
assert ws_kwargs["additional_headers"]["X-Wherobots-Client"] == (
f"client=dbapi;ver={DBAPI_VERSION}"
)
65 changes: 61 additions & 4 deletions wherobots/db/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,59 @@
DEFAULT_HTTP_TIMEOUT = 30


def gen_user_agent_header():
def _resolve_dbapi_version() -> str:
"""Resolve this package's version, or "unknown" if it can't be found."""
try:
package_version = metadata.version("wherobots-python-dbapi")
return metadata.version("wherobots-python-dbapi")
except PackageNotFoundError:
package_version = "unknown"
return "unknown"


# Resolved once at import: `importlib.metadata.version` scans the installed
# package database on each call, and the version can't change within a process.
_DBAPI_VERSION: Final[str] = _resolve_dbapi_version()


def gen_user_agent_header():
python_version = platform.python_version()
system = platform.system().lower()
return {
"User-Agent": f"wherobots-python-dbapi/{package_version} os/{system} python/{python_version}"
"User-Agent": f"wherobots-python-dbapi/{_DBAPI_VERSION} os/{system} python/{python_version}"
}


# Canonical name of the shared, cross-service client-chain header.
WHEROBOTS_CLIENT_HEADER: Final[str] = "X-Wherobots-Client"


def _append_wherobots_client_hop(headers: Dict[str, str]) -> None:
"""Append this driver's hop to the shared ``X-Wherobots-Client`` header.

``X-Wherobots-Client`` is an ordered, append-only, comma-separated list of
hops; the leftmost is the origin and each component appends its own hop on
the right. DBAPI's hop is ``client=dbapi;ver=<version>``.

The lookup is case-insensitive because HTTP header names are, and any
differently-cased inbound key is collapsed into the canonical
``X-Wherobots-Client`` key so the request carries exactly one such header.

This header is advisory: it is informational only and must never affect
authentication. ``headers`` is mutated in place.
"""
dbapi_hop = f"client=dbapi;ver={_DBAPI_VERSION}"

# Find any existing hop chain case-insensitively and remove differently
# cased duplicates so we don't emit two headers.
existing_chain: Union[str, None] = None
for key in [k for k in headers if k.lower() == WHEROBOTS_CLIENT_HEADER.lower()]:
existing_chain = headers.pop(key)

if existing_chain:
headers[WHEROBOTS_CLIENT_HEADER] = f"{existing_chain}, {dbapi_hop}"
else:
headers[WHEROBOTS_CLIENT_HEADER] = dbapi_hop


def connect(
host: str = DEFAULT_ENDPOINT,
token: Union[str, None] = None,
Expand All @@ -82,6 +123,7 @@ def connect(
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
cancel_event: Union[threading.Event, None] = None,
extra_headers: Union[Dict[str, str], None] = None,
) -> Connection:
"""Create a connection to a Wherobots SQL session.

Expand All @@ -96,18 +138,33 @@ def connect(
your organization — only set this if you intend to use a specific region
instead of the one your administrator has configured. When omitted, your
organization's default region is used.
:param extra_headers: Optional extra HTTP headers to send on the session
requests and the WebSocket upgrade. Merged after the driver's own
headers, so callers can pass through tracing/correlation headers. If it
includes an ``X-Wherobots-Client`` hop chain, this driver appends its
own ``client=dbapi;ver=<version>`` hop to the right of it. This header
is advisory only and never affects authentication; ``extra_headers``
cannot be used to override the ``Authorization``/``X-API-Key`` headers.
"""
if not token and not api_key:
raise ValueError("At least one of `token` or `api_key` is required")
if token and api_key:
raise ValueError("`token` and `api_key` can't be both provided")

headers = gen_user_agent_header()
# Merge caller-supplied headers first so the driver's own auth headers,
# applied below, always win and can never be overridden.
if extra_headers:
headers.update(extra_headers)
if token:
headers["Authorization"] = f"Bearer {token}"
elif api_key:
headers["X-API-Key"] = api_key
Comment thread
salty-hambot[bot] marked this conversation as resolved.

# Append this driver's hop to the shared client-chain header. Done after the
# auth headers are set so it operates on the final, merged header set.
_append_wherobots_client_hop(headers)

host = host or DEFAULT_ENDPOINT
session_type = session_type or DEFAULT_SESSION_TYPE

Expand Down
Loading