From 83485d92cb8248bbb2dbebf157dca1f28ff011a3 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Sun, 14 Jun 2026 22:40:57 -0400 Subject: [PATCH] Resolve serialized exception nodes without importing the stored class name When deserializing AIRFLOW_EXC_SER / BASE_EXC_SER nodes, BaseSerialization resolved the exception class with import_string() on a name taken from the serialized blob. Resolve it against in-memory classes instead, so a stored DAG never imports a class named in the blob: - AIRFLOW_EXC_SER: look the name up in a map of loaded AirflowException subclasses, built once from the in-memory subclass tree; a name that is not a registered AirflowException subclass is rejected. - BASE_EXC_SER: resolve against the fixed {KeyError, AttributeError} set that the encoder is the only producer of. Unknown or disallowed names raise DeserializationError instead of being imported. The trigger-node branch is handled separately. Generated-by: Claude Opus 4.8 following the guidelines at https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions --- .../serialization/serialized_objects.py | 61 +++++++++++++++++-- .../serialization/test_dag_serialization.py | 58 ++++++++++++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 4ba0bbefdce7a..03ffd7de4d95e 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -29,7 +29,7 @@ import math import sys import weakref -from collections.abc import Collection, Iterable, Mapping +from collections.abc import Collection, Iterable, Iterator, Mapping from functools import cache, cached_property, lru_cache from inspect import signature from textwrap import dedent @@ -41,7 +41,7 @@ from dateutil import relativedelta from pendulum.tz.timezone import 
FixedTimezone, Timezone -from airflow._shared.module_loading import import_string, qualname +from airflow._shared.module_loading import qualname from airflow._shared.timezones.timezone import from_timestamp, parse_timezone, utcnow from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.exceptions import AirflowException, DeserializationError, SerializationError @@ -236,6 +236,54 @@ def _decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy: return priority_weight_strategy_class() +# Builtin exceptions the serializer emits as ``BASE_EXC_SER``. Only these are ever +# serialized (see the encode side), so deserialization resolves the stored name against +# this fixed map instead of importing it -- ``builtins.eval`` / ``builtins.exec`` and any +# other name are rejected without importing anything. +_DESERIALIZABLE_BUILTIN_EXCEPTIONS: dict[str, type[BaseException]] = { + "KeyError": KeyError, + "AttributeError": AttributeError, +} + + +def _iter_subclasses(cls: type) -> Iterator[type]: + """Yield every (transitive) subclass of ``cls``.""" + for sub in cls.__subclasses__(): + yield sub + yield from _iter_subclasses(sub) + + +@cache +def _serializable_airflow_exceptions() -> dict[str, type[AirflowException]]: + """ + Map ``"module.ClassName" -> AirflowException subclass``, used to resolve ``AIRFLOW_EXC_SER`` nodes. + + Built once, from the in-memory ``AirflowException`` subclass tree (never from the + attacker-controlled stored name), and never rebuilt -- a name absent from it is rejected, not + imported. ``airflow.exceptions`` is imported by this module, so every built-in Airflow exception + is registered by the time this is first called; exceptions defined later are not added. 
+ """ + return { + f"{cls.__module__}.{cls.__name__}": cls + for cls in (AirflowException, *_iter_subclasses(AirflowException)) + } + + +def _resolve_airflow_exception(exc_cls_name: str) -> type[AirflowException]: + """ + Resolve a serialized ``AirflowException`` class name to the loaded class, without importing it. + + The name is matched against the once-built ``AirflowException`` subclass map, so deserializing a + stored DAG never runs the top-level code of a module named in the blob. A name that is not a + registered ``AirflowException`` subclass -- e.g. an attacker's ``subprocess.check_output`` -- is + rejected rather than imported. + """ + exc_cls = _serializable_airflow_exceptions().get(exc_cls_name) + if exc_cls is None: + raise DeserializationError(f"Refusing to deserialize unknown exception class {exc_cls_name!r}") + return exc_cls + + def _encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]: """Encode a StartTriggerArgs.""" @@ -658,9 +706,14 @@ def deserialize(cls, encoded_var: Any) -> Any: kwargs = deser["kwargs"] del deser if type_ == DAT.AIRFLOW_EXC_SER: - exc_cls = import_string(exc_cls_name) + exc_cls: type[BaseException] = _resolve_airflow_exception(exc_cls_name) else: - exc_cls = import_string(f"builtins.{exc_cls_name}") + builtin_exc_cls = _DESERIALIZABLE_BUILTIN_EXCEPTIONS.get(exc_cls_name) + if builtin_exc_cls is None: + raise DeserializationError( + f"Refusing to deserialize disallowed builtin exception {exc_cls_name!r}" + ) + exc_cls = builtin_exc_cls return exc_cls(*args, **kwargs) elif type_ == DAT.SET: return {cls.deserialize(v) for v in var} diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 07459a2011711..80762d6c45e47 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -2790,6 +2790,64 @@ def 
test_create_dagrun_accepts_partition_key_for_partitioned_at_runtime_dag(self dr = dag_maker.create_dagrun(partition_key="runtime-key") assert dr.partition_key == "runtime-key" + def test_airflow_exc_deserialization_rejects_unknown_class(self): + """An AIRFLOW_EXC_SER name that is not a loaded AirflowException subclass is rejected. + + The name is resolved against the in-memory subclass tree, so an attacker-controlled + ``subprocess.check_output`` is never imported. + """ + from airflow.exceptions import DeserializationError + from airflow.serialization.enums import DagAttributeTypes + + encoded = BaseSerialization._encode( + BaseSerialization.serialize( + {"exc_cls_name": "subprocess.check_output", "args": [], "kwargs": {}} + ), + type_=DagAttributeTypes.AIRFLOW_EXC_SER, + ) + with pytest.raises(DeserializationError, match="Refusing to deserialize unknown exception class"): + BaseSerialization.deserialize(encoded) + + def test_airflow_exc_deserialization_roundtrips_airflow_exception(self): + """A genuine AirflowException subclass round-trips via the registry, without importing.""" + from airflow.exceptions import AirflowException + + result = BaseSerialization.deserialize(BaseSerialization.serialize(AirflowException("boom"))) + assert isinstance(result, AirflowException) + assert result.args == ("boom",) + + def test_base_exc_deserialization_rejects_non_allowlisted_builtin(self): + """A BASE_EXC_SER name outside the {KeyError, AttributeError} the encoder emits is rejected.""" + from airflow.exceptions import DeserializationError + from airflow.serialization.enums import DagAttributeTypes + + # ``eval`` is the weaponisable case; ``ValueError`` is a harmless builtin the encoder + # never emits as BASE_EXC_SER -- both must be rejected. 
+ for name in ("eval", "ValueError"): + encoded = BaseSerialization._encode( + BaseSerialization.serialize({"exc_cls_name": name, "args": ["1"], "kwargs": {}}), + type_=DagAttributeTypes.BASE_EXC_SER, + ) + with pytest.raises( + DeserializationError, match="Refusing to deserialize disallowed builtin exception" + ): + BaseSerialization.deserialize(encoded) + + def test_base_exc_deserialization_roundtrips_builtin_exception(self): + """The builtin exceptions the encoder emits (KeyError / AttributeError) still deserialize.""" + from airflow.serialization.enums import DagAttributeTypes + + for exc_type in (KeyError, AttributeError): + encoded = BaseSerialization._encode( + BaseSerialization.serialize( + {"exc_cls_name": exc_type.__name__, "args": ["boom"], "kwargs": {}} + ), + type_=DagAttributeTypes.BASE_EXC_SER, + ) + result = BaseSerialization.deserialize(encoded) + assert isinstance(result, exc_type) + assert result.args == ("boom",) + def test_kubernetes_optional(): """Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""