Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/communication_protocols/gql/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "utcp-gql"
version = "1.1.3"
version = "1.1.4"
authors = [
{ name = "UTCP Contributors" },
]
Expand Down
19 changes: 17 additions & 2 deletions plugins/communication_protocols/gql/src/utcp_gql/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ def _same_origin(a: str, b: str) -> bool:
return False


def _scrub_cross_origin_credentials(kwargs: dict) -> None:
def _scrub_cross_origin_credentials(
kwargs: dict,
extra_auth_header_names: Optional[frozenset] = None,
) -> None:
"""Strip auth-bearing kwargs in place when crossing origins.

Mirrors ``utcp_http._security._scrub_cross_origin_credentials`` --
Expand All @@ -229,12 +232,15 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None:
``data``) so 307/308 redirects cannot resend an OAuth POST body
to a new origin.
"""
extra = extra_auth_header_names or frozenset()
headers = kwargs.get("headers")
if headers is not None:
scrubbed: Dict[str, Any] = {}
for k, v in dict(headers).items():
if _header_is_auth_sensitive(k):
continue
if isinstance(k, str) and k.lower() in extra:
continue
scrubbed[k] = v
kwargs["headers"] = scrubbed

Expand All @@ -254,6 +260,7 @@ async def safe_request_with_redirects(
*,
context: str,
max_redirects: int = 5,
auth_header_names: Optional[Any] = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Issue an aiohttp request that re-validates every redirect hop.
Expand Down Expand Up @@ -291,6 +298,12 @@ async def safe_request_with_redirects(
# We control redirect behavior ourselves; refuse to let callers override.
kwargs.pop("allow_redirects", None)

extra_auth_header_names = frozenset(
n.lower()
for n in (auth_header_names or [])
if isinstance(n, str)
)

current_url = url
current_method = method
hops = 0
Expand Down Expand Up @@ -335,7 +348,9 @@ async def safe_request_with_redirects(

# Strip auth-bearing kwargs on cross-origin redirect.
if not _same_origin(current_url, next_url):
_scrub_cross_origin_credentials(kwargs)
_scrub_cross_origin_credentials(
kwargs, extra_auth_header_names
)

if response.status == 303:
current_method = "GET"
Expand Down
2 changes: 1 addition & 1 deletion plugins/communication_protocols/http/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "utcp-http"
version = "1.1.6"
version = "1.1.7"
authors = [
{ name = "UTCP Contributors" },
]
Expand Down
32 changes: 30 additions & 2 deletions plugins/communication_protocols/http/src/utcp_http/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ def _same_origin(a: str, b: str) -> bool:
return False


def _scrub_cross_origin_credentials(kwargs: dict) -> None:
def _scrub_cross_origin_credentials(
kwargs: dict,
extra_auth_header_names: Optional[frozenset] = None,
) -> None:
"""Strip auth-bearing kwargs in place when crossing origins.

Aligns the redirect helper with browser / requests / curl
Expand Down Expand Up @@ -314,6 +317,7 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None:
Callers invoke this BEFORE issuing the next hop, only when the
redirect target's origin differs from the current URL's origin.
"""
extra = extra_auth_header_names or frozenset()
headers = kwargs.get("headers")
if headers is not None:
# Build a new dict so we never mutate the caller's headers
Expand All @@ -322,6 +326,11 @@ def _scrub_cross_origin_credentials(kwargs: dict) -> None:
for k, v in dict(headers).items():
if _header_is_auth_sensitive(k):
continue
# Strip caller-configured custom auth header names
# (e.g. ``ApiKeyAuth`` with ``var_name="X-MyApp"``)
# that don't match the auth-pattern regex on their own.
if isinstance(k, str) and k.lower() in extra:
continue
scrubbed[k] = v
kwargs["headers"] = scrubbed

Expand Down Expand Up @@ -351,6 +360,7 @@ async def safe_request_with_redirects(
*,
context: str,
max_redirects: int = 5,
auth_header_names: Optional[Any] = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Issue an aiohttp request that re-validates every redirect hop.
Expand Down Expand Up @@ -388,6 +398,22 @@ async def safe_request_with_redirects(
# We control redirect behavior ourselves; refuse to let callers override.
kwargs.pop("allow_redirects", None)

# Pull caller-configured auth header names so the cross-origin
# scrub can strip them too. Private contract -- callers attach
# this via ``_apply_auth`` to declare which header names they
# populated with a secret. Never sent on the wire.
# ``auth_header_names`` is the explicit declaration of which
# header names the caller populated with a secret. Used to extend
# the cross-origin scrub beyond the canonical set so a
# custom-named API-key header (e.g. ``X-MyApp``) configured via
# ``ApiKeyAuth`` / ``OAuth2UserAuth`` is also stripped on
# cross-origin redirect.
extra_auth_header_names = frozenset(
n.lower()
for n in (auth_header_names or [])
if isinstance(n, str)
)

current_url = url
current_method = method
hops = 0
Expand Down Expand Up @@ -437,7 +463,9 @@ async def safe_request_with_redirects(
# API key would be forwarded along. Mirrors browser /
# requests / curl behaviour.
if not _same_origin(current_url, next_url):
_scrub_cross_origin_credentials(kwargs)
_scrub_cross_origin_credentials(
kwargs, extra_auth_header_names
)

if response.status == 303:
current_method = "GET"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,37 +76,65 @@ def __init__(self, logger: Optional[Callable[[str], None]] = None):
self._session: Optional[aiohttp.ClientSession] = None
self._oauth_tokens: Dict[str, Dict[str, Any]] = {}

@staticmethod
def _assert_no_crlf(value: Optional[str], field_name: str) -> None:
"""Refuse CR/LF in attacker-influenceable strings that will land
in HTTP headers. aiohttp blocks these at request time, but
keeping the trust boundary inside UTCP means a transport swap
cannot silently regress.
"""
if not isinstance(value, str):
return
if "\r" in value or "\n" in value:
raise ValueError(
f"Refusing to construct request: {field_name} contains CR/LF, "
f"which would enable HTTP header injection."
)

def _apply_auth(self, provider: HttpCallTemplate, headers: Dict[str, str], query_params: Dict[str, Any]) -> tuple:
"""Apply authentication to the request based on the provider's auth configuration.

Returns:
tuple: (auth_obj, cookies) where auth_obj is for aiohttp basic auth and cookies is a dict
tuple ``(auth_obj, cookies, auth_header_names)``:
* ``auth_obj``: aiohttp BasicAuth for HTTP basic, or None.
* ``cookies``: dict of cookies to attach.
* ``auth_header_names``: list of header names that
received a secret. Threaded into
``safe_request_with_redirects`` so a custom-named auth
header (e.g. ``ApiKeyAuth(var_name="X-MyApp")``) is
also stripped on cross-origin redirect.
"""
auth = None
cookies = {}

auth_header_names: List[str] = []

if provider.auth:
if isinstance(provider.auth, ApiKeyAuth):
if provider.auth.api_key:
self._assert_no_crlf(provider.auth.var_name, "ApiKeyAuth.var_name")
if provider.auth.location == "header":
headers[provider.auth.var_name] = provider.auth.api_key
auth_header_names.append(provider.auth.var_name)
elif provider.auth.location == "query":
query_params[provider.auth.var_name] = provider.auth.api_key
elif provider.auth.location == "cookie":
cookies[provider.auth.var_name] = provider.auth.api_key
else:
logger.error("API key not found for ApiKeyAuth.")
raise ValueError("API key for ApiKeyAuth not found.")

elif isinstance(provider.auth, BasicAuth):
auth = AiohttpBasicAuth(provider.auth.username, provider.auth.password)

elif isinstance(provider.auth, OAuth2Auth):
# OAuth2 tokens are always sent in the Authorization header
# We'll handle this separately since it requires async token retrieval
pass

return auth, cookies
# We'll handle this separately since it requires async token retrieval.
# We DO declare ``Authorization`` here so the scrubber treats
# the resulting bearer header as cross-origin-sensitive even
# if it slipped past the regex.
auth_header_names.append("Authorization")

return auth, cookies, auth_header_names

async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult:
"""REQUIRED
Expand Down Expand Up @@ -136,7 +164,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R
query_params = {}

# Handle authentication
auth, cookies = self._apply_auth(manual_call_template, request_headers, query_params)
auth, cookies, auth_header_names = self._apply_auth(manual_call_template, request_headers, query_params)

# Handle OAuth2 separately since it requires async token retrieval
if manual_call_template.auth and isinstance(manual_call_template.auth, OAuth2Auth):
Expand Down Expand Up @@ -180,6 +208,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R
data=data,
cookies=cookies,
timeout=aiohttp.ClientTimeout(total=10.0),
auth_header_names=auth_header_names,
) as response:
response.raise_for_status() # Raise exception for 4XX/5XX responses

Expand Down Expand Up @@ -288,7 +317,7 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too
query_params = remaining_args

# Handle authentication
auth, cookies = self._apply_auth(tool_call_template, request_headers, query_params)
auth, cookies, auth_header_names = self._apply_auth(tool_call_template, request_headers, query_params)

# Handle OAuth2 separately since it requires async token retrieval
if tool_call_template.auth and isinstance(tool_call_template.auth, OAuth2Auth):
Expand Down Expand Up @@ -328,6 +357,7 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too
data=data,
cookies=cookies,
timeout=aiohttp.ClientTimeout(total=30.0),
auth_header_names=auth_header_names,
) as response:
response.raise_for_status()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,50 @@ class SseCommunicationProtocol(CommunicationProtocol):
def __init__(self, logger: Optional[Callable[[str], None]] = None):
self._oauth_tokens: Dict[str, Dict[str, Any]] = {}

@staticmethod
def _assert_no_crlf(value: Optional[str], field_name: str) -> None:
if not isinstance(value, str):
return
if "\r" in value or "\n" in value:
raise ValueError(
f"Refusing to construct request: {field_name} contains CR/LF, "
f"which would enable HTTP header injection."
)

def _apply_auth(self, provider: SseCallTemplate, headers: Dict[str, str], query_params: Dict[str, Any]) -> tuple:
"""Apply authentication to the request based on the provider's auth configuration.

Returns:
tuple: (auth_obj, cookies) where auth_obj is for aiohttp basic auth and cookies is a dict
tuple ``(auth_obj, cookies, auth_header_names)``.
"""
auth = None
cookies = {}

auth_header_names: List[str] = []

if provider.auth:
if isinstance(provider.auth, ApiKeyAuth):
if provider.auth.api_key:
self._assert_no_crlf(provider.auth.var_name, "ApiKeyAuth.var_name")
if provider.auth.location == "header":
headers[provider.auth.var_name] = provider.auth.api_key
auth_header_names.append(provider.auth.var_name)
elif provider.auth.location == "query":
query_params[provider.auth.var_name] = provider.auth.api_key
elif provider.auth.location == "cookie":
cookies[provider.auth.var_name] = provider.auth.api_key
else:
logger.error("API key not found for ApiKeyAuth.")
raise ValueError("API key for ApiKeyAuth not found.")

elif isinstance(provider.auth, BasicAuth):
auth = AiohttpBasicAuth(provider.auth.username, provider.auth.password)

elif isinstance(provider.auth, OAuth2Auth):
# OAuth2 tokens are always sent in the Authorization header
# We'll handle this separately since it requires async token retrieval
pass
return auth, cookies
# OAuth2 tokens are always sent in the Authorization header.
# Declared so cross-origin scrub recognises it.
auth_header_names.append("Authorization")

return auth, cookies, auth_header_names

async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult:
"""REQUIRED
Expand All @@ -90,7 +103,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R

# Handle authentication
query_params: Dict[str, Any] = {}
auth, cookies = self._apply_auth(manual_call_template, request_headers, query_params)
auth, cookies, auth_header_names = self._apply_auth(manual_call_template, request_headers, query_params)

# Handle OAuth2 separately as it's async
if isinstance(manual_call_template.auth, OAuth2Auth):
Expand Down Expand Up @@ -133,6 +146,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R
json=json_data,
data=data,
timeout=aiohttp.ClientTimeout(total=10.0),
auth_header_names=auth_header_names,
) as response:
response.raise_for_status()
response_data = await response.json()
Expand Down Expand Up @@ -199,7 +213,11 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str,
query_params = remaining_args

# Handle authentication
auth, cookies = self._apply_auth(tool_call_template, request_headers, query_params)
# ``auth_header_names`` unused in the streaming path because
# SSE handshake uses ``allow_redirects=False`` -- there is no
# redirect chain to scrub. Reserved for future use if
# streaming ever supports per-hop validation.
auth, cookies, _auth_header_names = self._apply_auth(tool_call_template, request_headers, query_params)

# Handle OAuth2 separately as it's async
if isinstance(tool_call_template.auth, OAuth2Auth):
Expand Down
Loading
Loading