From f6f51e9bc55b6906f20141f1afa63c9295e438b3 Mon Sep 17 00:00:00 2001 From: Razvan Radulescu <43811028+h3xxit@users.noreply.github.com> Date: Thu, 18 Jun 2026 01:23:56 +0200 Subject: [PATCH] security fixes --- .../gql/pyproject.toml | 2 +- .../gql/src/utcp_gql/_security.py | 19 ++- .../http/pyproject.toml | 2 +- .../http/src/utcp_http/_security.py | 32 ++++- .../utcp_http/http_communication_protocol.py | 52 ++++++-- .../utcp_http/sse_communication_protocol.py | 42 +++++-- .../streamable_http_communication_protocol.py | 38 ++++-- .../http/tests/test_redirect_security.py | 116 ++++++++++++++++++ .../websocket/pyproject.toml | 2 +- .../websocket/src/utcp_websocket/_security.py | 19 ++- 10 files changed, 280 insertions(+), 44 deletions(-) diff --git a/plugins/communication_protocols/gql/pyproject.toml b/plugins/communication_protocols/gql/pyproject.toml index 3a9010f..5327c6b 100644 --- a/plugins/communication_protocols/gql/pyproject.toml +++ b/plugins/communication_protocols/gql/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-gql" -version = "1.1.3" +version = "1.1.4" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/gql/src/utcp_gql/_security.py b/plugins/communication_protocols/gql/src/utcp_gql/_security.py index 089271c..98a4c53 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/_security.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/_security.py @@ -220,7 +220,10 @@ def _same_origin(a: str, b: str) -> bool: return False -def _scrub_cross_origin_credentials(kwargs: dict) -> None: +def _scrub_cross_origin_credentials( + kwargs: dict, + extra_auth_header_names: Optional[frozenset] = None, +) -> None: """Strip auth-bearing kwargs in place when crossing origins. Mirrors ``utcp_http._security._scrub_cross_origin_credentials`` -- @@ -229,12 +232,15 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None: ``data``) so 307/308 redirects cannot resend an OAuth POST body to a new origin. """ + extra = extra_auth_header_names or frozenset() headers = kwargs.get("headers") if headers is not None: scrubbed: Dict[str, Any] = {} for k, v in dict(headers).items(): if _header_is_auth_sensitive(k): continue + if isinstance(k, str) and k.lower() in extra: + continue scrubbed[k] = v kwargs["headers"] = scrubbed @@ -254,6 +260,7 @@ async def safe_request_with_redirects( *, context: str, max_redirects: int = 5, + auth_header_names: Optional[Any] = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Issue an aiohttp request that re-validates every redirect hop. @@ -291,6 +298,12 @@ async def safe_request_with_redirects( # We control redirect behavior ourselves; refuse to let callers override. kwargs.pop("allow_redirects", None) + extra_auth_header_names = frozenset( + n.lower() + for n in (auth_header_names or []) + if isinstance(n, str) + ) + current_url = url current_method = method hops = 0 @@ -335,7 +348,9 @@ async def safe_request_with_redirects( # Strip auth-bearing kwargs on cross-origin redirect. if not _same_origin(current_url, next_url): - _scrub_cross_origin_credentials(kwargs) + _scrub_cross_origin_credentials( + kwargs, extra_auth_header_names + ) if response.status == 303: current_method = "GET" diff --git a/plugins/communication_protocols/http/pyproject.toml b/plugins/communication_protocols/http/pyproject.toml index 8e35fa4..9d754a2 100644 --- a/plugins/communication_protocols/http/pyproject.toml +++ b/plugins/communication_protocols/http/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-http" -version = "1.1.6" +version = "1.1.7" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/http/src/utcp_http/_security.py b/plugins/communication_protocols/http/src/utcp_http/_security.py index 602b612..5e431cc 100644 --- a/plugins/communication_protocols/http/src/utcp_http/_security.py +++ b/plugins/communication_protocols/http/src/utcp_http/_security.py @@ -284,7 +284,10 @@ def _same_origin(a: str, b: str) -> bool: return False -def _scrub_cross_origin_credentials(kwargs: dict) -> None: +def _scrub_cross_origin_credentials( + kwargs: dict, + extra_auth_header_names: Optional[frozenset] = None, +) -> None: """Strip auth-bearing kwargs in place when crossing origins. Aligns the redirect helper with browser / requests / curl @@ -314,6 +317,7 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None: Callers invoke this BEFORE issuing the next hop, only when the redirect target's origin differs from the current URL's origin. """ + extra = extra_auth_header_names or frozenset() headers = kwargs.get("headers") if headers is not None: # Build a new dict so we never mutate the caller's headers @@ -322,6 +326,11 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None: for k, v in dict(headers).items(): if _header_is_auth_sensitive(k): continue + # Strip caller-configured custom auth header names + # (e.g. ``ApiKeyAuth`` with ``var_name="X-MyApp"``) + # that don't match the auth-pattern regex on their own. + if isinstance(k, str) and k.lower() in extra: + continue scrubbed[k] = v kwargs["headers"] = scrubbed @@ -351,6 +360,7 @@ async def safe_request_with_redirects( *, context: str, max_redirects: int = 5, + auth_header_names: Optional[Any] = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Issue an aiohttp request that re-validates every redirect hop. @@ -388,6 +398,22 @@ async def safe_request_with_redirects( # We control redirect behavior ourselves; refuse to let callers override. kwargs.pop("allow_redirects", None) + # Pull caller-configured auth header names so the cross-origin + # scrub can strip them too. Private contract -- callers attach + # this via ``_apply_auth`` to declare which header names they + # populated with a secret. Never sent on the wire. + # ``auth_header_names`` is the explicit declaration of which + # header names the caller populated with a secret. Used to extend + # the cross-origin scrub beyond the canonical set so a + # custom-named API-key header (e.g. ``X-MyApp``) configured via + # ``ApiKeyAuth`` / ``OAuth2UserAuth`` is also stripped on + # cross-origin redirect. + extra_auth_header_names = frozenset( + n.lower() + for n in (auth_header_names or []) + if isinstance(n, str) + ) + current_url = url current_method = method hops = 0 @@ -437,7 +463,9 @@ async def safe_request_with_redirects( # API key would be forwarded along. Mirrors browser / # requests / curl behaviour. if not _same_origin(current_url, next_url): - _scrub_cross_origin_credentials(kwargs) + _scrub_cross_origin_credentials( + kwargs, extra_auth_header_names + ) if response.status == 303: current_method = "GET" diff --git a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py index bab7214..8a3a187 100644 --- a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py @@ -76,20 +76,45 @@ def __init__(self, logger: Optional[Callable[[str], None]] = None): self._session: Optional[aiohttp.ClientSession] = None self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + @staticmethod + def _assert_no_crlf(value: Optional[str], field_name: str) -> None: + """Refuse CR/LF in attacker-influenceable strings that will land + in HTTP headers. aiohttp blocks these at request time, but + keeping the trust boundary inside UTCP means a transport swap + cannot silently regress. + """ + if not isinstance(value, str): + return + if "\r" in value or "\n" in value: + raise ValueError( + f"Refusing to construct request: {field_name} contains CR/LF, " + f"which would enable HTTP header injection." + ) + def _apply_auth(self, provider: HttpCallTemplate, headers: Dict[str, str], query_params: Dict[str, Any]) -> tuple: """Apply authentication to the request based on the provider's auth configuration. - + Returns: - tuple: (auth_obj, cookies) where auth_obj is for aiohttp basic auth and cookies is a dict + tuple ``(auth_obj, cookies, auth_header_names)``: + * ``auth_obj``: aiohttp BasicAuth for HTTP basic, or None. + * ``cookies``: dict of cookies to attach. + * ``auth_header_names``: list of header names that + received a secret. Threaded into + ``safe_request_with_redirects`` so a custom-named auth + header (e.g. ``ApiKeyAuth(var_name="X-MyApp")``) is + also stripped on cross-origin redirect. """ auth = None cookies = {} - + auth_header_names: List[str] = [] + if provider.auth: if isinstance(provider.auth, ApiKeyAuth): if provider.auth.api_key: + self._assert_no_crlf(provider.auth.var_name, "ApiKeyAuth.var_name") if provider.auth.location == "header": headers[provider.auth.var_name] = provider.auth.api_key + auth_header_names.append(provider.auth.var_name) elif provider.auth.location == "query": query_params[provider.auth.var_name] = provider.auth.api_key elif provider.auth.location == "cookie": @@ -97,16 +122,19 @@ def _apply_auth(self, provider: HttpCallTemplate, headers: Dict[str, str], query else: logger.error("API key not found for ApiKeyAuth.") raise ValueError("API key for ApiKeyAuth not found.") - + elif isinstance(provider.auth, BasicAuth): auth = AiohttpBasicAuth(provider.auth.username, provider.auth.password) - + elif isinstance(provider.auth, OAuth2Auth): # OAuth2 tokens are always sent in the Authorization header - # We'll handle this separately since it requires async token retrieval - pass - - return auth, cookies + # We'll handle this separately since it requires async token retrieval. + # We DO declare ``Authorization`` here so the scrubber treats + # the resulting bearer header as cross-origin-sensitive even + # if it slipped past the regex. + auth_header_names.append("Authorization") + + return auth, cookies, auth_header_names async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: """REQUIRED @@ -136,7 +164,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R query_params = {} # Handle authentication - auth, cookies = self._apply_auth(manual_call_template, request_headers, query_params) + auth, cookies, auth_header_names = self._apply_auth(manual_call_template, request_headers, query_params) # Handle OAuth2 separately since it requires async token retrieval if manual_call_template.auth and isinstance(manual_call_template.auth, OAuth2Auth): @@ -180,6 +208,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R data=data, cookies=cookies, timeout=aiohttp.ClientTimeout(total=10.0), + auth_header_names=auth_header_names, ) as response: response.raise_for_status() # Raise exception for 4XX/5XX responses @@ -288,7 +317,7 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too query_params = remaining_args # Handle authentication - auth, cookies = self._apply_auth(tool_call_template, request_headers, query_params) + auth, cookies, auth_header_names = self._apply_auth(tool_call_template, request_headers, query_params) # Handle OAuth2 separately since it requires async token retrieval if tool_call_template.auth and isinstance(tool_call_template.auth, OAuth2Auth): @@ -328,6 +357,7 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too data=data, cookies=cookies, timeout=aiohttp.ClientTimeout(total=30.0), + auth_header_names=auth_header_names, ) as response: response.raise_for_status() diff --git a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py index f742a8a..83afac2 100644 --- a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py @@ -38,20 +38,33 @@ class SseCommunicationProtocol(CommunicationProtocol): def __init__(self, logger: Optional[Callable[[str], None]] = None): self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + @staticmethod + def _assert_no_crlf(value: Optional[str], field_name: str) -> None: + if not isinstance(value, str): + return + if "\r" in value or "\n" in value: + raise ValueError( + f"Refusing to construct request: {field_name} contains CR/LF, " + f"which would enable HTTP header injection." + ) + def _apply_auth(self, provider: SseCallTemplate, headers: Dict[str, str], query_params: Dict[str, Any]) -> tuple: """Apply authentication to the request based on the provider's auth configuration. - + Returns: - tuple: (auth_obj, cookies) where auth_obj is for aiohttp basic auth and cookies is a dict + tuple ``(auth_obj, cookies, auth_header_names)``. """ auth = None cookies = {} - + auth_header_names: List[str] = [] + if provider.auth: if isinstance(provider.auth, ApiKeyAuth): if provider.auth.api_key: + self._assert_no_crlf(provider.auth.var_name, "ApiKeyAuth.var_name") if provider.auth.location == "header": headers[provider.auth.var_name] = provider.auth.api_key + auth_header_names.append(provider.auth.var_name) elif provider.auth.location == "query": query_params[provider.auth.var_name] = provider.auth.api_key elif provider.auth.location == "cookie": @@ -59,16 +72,16 @@ def _apply_auth(self, provider: SseCallTemplate, headers: Dict[str, str], query_ else: logger.error("API key not found for ApiKeyAuth.") raise ValueError("API key for ApiKeyAuth not found.") - + elif isinstance(provider.auth, BasicAuth): auth = AiohttpBasicAuth(provider.auth.username, provider.auth.password) - + elif isinstance(provider.auth, OAuth2Auth): - # OAuth2 tokens are always sent in the Authorization header - # We'll handle this separately since it requires async token retrieval - pass - - return auth, cookies + # OAuth2 tokens are always sent in the Authorization header. + # Declared so cross-origin scrub recognises it. + auth_header_names.append("Authorization") + + return auth, cookies, auth_header_names async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: """REQUIRED @@ -90,7 +103,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R # Handle authentication query_params: Dict[str, Any] = {} - auth, cookies = self._apply_auth(manual_call_template, request_headers, query_params) + auth, cookies, auth_header_names = self._apply_auth(manual_call_template, request_headers, query_params) # Handle OAuth2 separately as it's async if isinstance(manual_call_template.auth, OAuth2Auth): @@ -133,6 +146,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R json=json_data, data=data, timeout=aiohttp.ClientTimeout(total=10.0), + auth_header_names=auth_header_names, ) as response: response.raise_for_status() response_data = await response.json() @@ -199,7 +213,11 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, query_params = remaining_args # Handle authentication - auth, cookies = self._apply_auth(tool_call_template, request_headers, query_params) + # ``auth_header_names`` unused in the streaming path because + # SSE handshake uses ``allow_redirects=False`` -- there is no + # redirect chain to scrub. Reserved for future use if + # streaming ever supports per-hop validation. + auth, cookies, _auth_header_names = self._apply_auth(tool_call_template, request_headers, query_params) # Handle OAuth2 separately as it's async if isinstance(tool_call_template.auth, OAuth2Auth): diff --git a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py index 668735c..fde6cb5 100644 --- a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py @@ -35,20 +35,33 @@ class StreamableHttpCommunicationProtocol(CommunicationProtocol): def __init__(self): self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + @staticmethod + def _assert_no_crlf(value: Optional[str], field_name: str) -> None: + if not isinstance(value, str): + return + if "\r" in value or "\n" in value: + raise ValueError( + f"Refusing to construct request: {field_name} contains CR/LF, " + f"which would enable HTTP header injection." + ) + def _apply_auth(self, provider: StreamableHttpCallTemplate, headers: Dict[str, str], query_params: Dict[str, Any]) -> tuple: """Apply authentication to the request based on the provider's auth configuration. - + Returns: - tuple: (auth_obj, cookies) where auth_obj is for aiohttp basic auth and cookies is a dict + tuple ``(auth_obj, cookies, auth_header_names)``. """ auth = None cookies = {} - + auth_header_names: List[str] = [] + if provider.auth: if isinstance(provider.auth, ApiKeyAuth): if provider.auth.api_key: + self._assert_no_crlf(provider.auth.var_name, "ApiKeyAuth.var_name") if provider.auth.location == "header": headers[provider.auth.var_name] = provider.auth.api_key + auth_header_names.append(provider.auth.var_name) elif provider.auth.location == "query": query_params[provider.auth.var_name] = provider.auth.api_key elif provider.auth.location == "cookie": @@ -56,16 +69,14 @@ def _apply_auth(self, provider: StreamableHttpCallTemplate, headers: Dict[str, s else: logger.error("API key not found for ApiKeyAuth.") raise ValueError("API key for ApiKeyAuth not found.") - + elif isinstance(provider.auth, BasicAuth): auth = AiohttpBasicAuth(provider.auth.username, provider.auth.password) - + elif isinstance(provider.auth, OAuth2Auth): - # OAuth2 tokens are always sent in the Authorization header - # We'll handle this separately since it requires async token retrieval - pass - - return auth, cookies + auth_header_names.append("Authorization") + + return auth, cookies, auth_header_names async def close(self): """Close all active connections and clear internal state.""" @@ -93,7 +104,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R # Handle authentication query_params: Dict[str, Any] = {} - auth, cookies = self._apply_auth(manual_call_template, request_headers, query_params) + auth, cookies, auth_header_names = self._apply_auth(manual_call_template, request_headers, query_params) # Handle OAuth2 separately as it's async if isinstance(manual_call_template.auth, OAuth2Auth): @@ -136,6 +147,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R json=json_data, data=data, timeout=aiohttp.ClientTimeout(total=10.0), + auth_header_names=auth_header_names, ) as response: response.raise_for_status() response_data = await response.json() @@ -228,7 +240,9 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, query_params = remaining_args # Handle authentication - auth_handler, cookies = self._apply_auth(tool_call_template, request_headers, query_params) + # ``auth_header_names`` unused here -- streaming handshake uses + # ``allow_redirects=False`` (no redirect chain to scrub). + auth_handler, cookies, _auth_header_names = self._apply_auth(tool_call_template, request_headers, query_params) # Handle OAuth2 separately as it's async if isinstance(tool_call_template.auth, OAuth2Auth): diff --git a/plugins/communication_protocols/http/tests/test_redirect_security.py b/plugins/communication_protocols/http/tests/test_redirect_security.py index 178d36d..cbd7d8a 100644 --- a/plugins/communication_protocols/http/tests/test_redirect_security.py +++ b/plugins/communication_protocols/http/tests/test_redirect_security.py @@ -82,6 +82,88 @@ def test_malformed_port_returns_false_not_raise(self, a: str, b: str) -> None: assert _same_origin(a, b) is False +class TestAuthHeaderNamesThreadedToScrub: + """Pin the explicit ``auth_header_names`` channel that lets the + redirect helper strip caller-configured auth header names (e.g. + ``ApiKeyAuth(var_name="X-MyApp")``) on cross-origin redirect -- + such names don't match the auth-pattern regex and would otherwise + survive the scrub. + """ + + @pytest.mark.asyncio + async def test_custom_auth_header_name_stripped_via_explicit_list( + self, aiohttp_server + ) -> None: + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["x_myapp"] = request.headers.get("X-MyApp") + return web.json_response({"ok": True}) + + target_app = web.Application() + target_app.router.add_get("/landed", _capture) + target = await aiohttp_server(target_app) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound(str(target.make_url("/landed"))) + + attacker_app = web.Application() + attacker_app.router.add_get("/tool", _redirect) + attacker = await aiohttp_server(attacker_app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(attacker.make_url("/tool")), + context="tool invocation", + headers={"X-MyApp": "secret-key"}, + auth_header_names=["X-MyApp"], + ): + pass + + assert captured["x_myapp"] is None, ( + "Caller-declared auth header name leaked cross-origin -- the " + "explicit auth_header_names channel did not extend the scrub." + ) + + @pytest.mark.asyncio + async def test_custom_auth_header_kept_on_same_origin( + self, aiohttp_server + ) -> None: + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["x_myapp"] = request.headers.get("X-MyApp") + return web.json_response({"ok": True}) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/landed") + + app = web.Application() + app.router.add_get("/start", _redirect) + app.router.add_get("/landed", _capture) + server = await aiohttp_server(app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(server.make_url("/start")), + context="tool invocation", + headers={"X-MyApp": "secret-key"}, + auth_header_names=["X-MyApp"], + ): + pass + + # Same-origin: keep the auth header. + assert captured["x_myapp"] == "secret-key" + + class TestAuthHeaderClassifier: """Direct unit tests for ``_header_is_auth_sensitive``. The cross-origin integration tests cover the end-to-end scrub @@ -507,6 +589,40 @@ async def _redirect(request: web.Request) -> web.Response: # --------------------------------------------------------------------------- +class TestCrlfHeaderInjectionRejected: + """Defense-in-depth: refuse CR/LF in attacker-influenceable strings + that will land in HTTP headers. aiohttp blocks these at request + time, but enforcing inside UTCP means a transport swap cannot + silently regress. + """ + + @pytest.mark.parametrize( + "var_name", + [ + "X-Api-Key\r\nX-Injected: yes", + "X-Foo\nX-Other: bar", + "X-Foo\rX-Other: bar", + ], + ) + @pytest.mark.asyncio + async def test_apikey_with_crlf_var_name_rejected(self, var_name: str) -> None: + from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth + + proto = HttpCommunicationProtocol() + tpl = HttpCallTemplate( + name="x", + url="https://api.example.com/x", + http_method="GET", + auth=ApiKeyAuth( + api_key="secret", + var_name=var_name, + location="header", + ), + ) + with pytest.raises(ValueError, match="CR/LF"): + await proto.call_tool(None, "x", {}, tpl) + + class TestOAuth2TokenUrlValidation: @pytest.mark.asyncio async def test_internal_token_url_rejected_at_runtime(self) -> None: diff --git a/plugins/communication_protocols/websocket/pyproject.toml b/plugins/communication_protocols/websocket/pyproject.toml index 05aed34..5e3c521 100644 --- a/plugins/communication_protocols/websocket/pyproject.toml +++ b/plugins/communication_protocols/websocket/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-websocket" -version = "1.1.3" +version = "1.1.4" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py index 42b1862..74714f0 100644 --- a/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py @@ -274,17 +274,23 @@ def _same_origin(a: str, b: str) -> bool: return False -def _scrub_cross_origin_credentials(kwargs: dict) -> None: +def _scrub_cross_origin_credentials( + kwargs: dict, + extra_auth_header_names: Optional[frozenset] = None, +) -> None: """Strip auth-bearing kwargs in place when crossing origins. Mirrors ``utcp_http._security._scrub_cross_origin_credentials``. """ + extra = extra_auth_header_names or frozenset() headers = kwargs.get("headers") if headers is not None: scrubbed: Dict[str, Any] = {} for k, v in dict(headers).items(): if _header_is_auth_sensitive(k): continue + if isinstance(k, str) and k.lower() in extra: + continue scrubbed[k] = v kwargs["headers"] = scrubbed @@ -304,6 +310,7 @@ async def safe_request_with_redirects( *, context: str, max_redirects: int = 5, + auth_header_names: Optional[Any] = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Issue an aiohttp request that re-validates every redirect hop. @@ -341,6 +348,12 @@ async def safe_request_with_redirects( # We control redirect behavior ourselves; refuse to let callers override. kwargs.pop("allow_redirects", None) + extra_auth_header_names = frozenset( + n.lower() + for n in (auth_header_names or []) + if isinstance(n, str) + ) + current_url = url current_method = method hops = 0 @@ -385,7 +398,9 @@ async def safe_request_with_redirects( # Strip auth-bearing kwargs on cross-origin redirect. if not _same_origin(current_url, next_url): - _scrub_cross_origin_credentials(kwargs) + _scrub_cross_origin_credentials( + kwargs, extra_auth_header_names + ) if response.status == 303: current_method = "GET"