Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import math
import sys
import weakref
from collections.abc import Collection, Iterable, Mapping
from collections.abc import Collection, Iterable, Iterator, Mapping
from functools import cache, cached_property, lru_cache
from inspect import signature
from textwrap import dedent
Expand All @@ -41,7 +41,7 @@
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow._shared.module_loading import import_string, qualname
from airflow._shared.module_loading import qualname
from airflow._shared.timezones.timezone import from_timestamp, parse_timezone, utcnow
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.exceptions import AirflowException, DeserializationError, SerializationError
Expand Down Expand Up @@ -236,6 +236,54 @@ def _decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
return priority_weight_strategy_class()


# Builtin exceptions the serializer emits as ``BASE_EXC_SER``. Only these are ever
# serialized (see the encode side), so deserialization resolves the stored name against
# this fixed map instead of importing it -- ``builtins.eval`` / ``builtins.exec`` and any
# other name are rejected without importing anything.
_DESERIALIZABLE_BUILTIN_EXCEPTIONS: dict[str, type[BaseException]] = {
"KeyError": KeyError,
"AttributeError": AttributeError,
}


def _iter_subclasses(cls: type) -> Iterator[type]:
"""Yield every (transitive) subclass of ``cls``."""
for sub in cls.__subclasses__():
yield sub
yield from _iter_subclasses(sub)


@cache
def _serializable_airflow_exceptions() -> dict[str, type[AirflowException]]:
"""
Map ``"<module>.<name>" -> AirflowException subclass``, used to resolve ``AIRFLOW_EXC_SER`` nodes.

Built once, from the in-memory ``AirflowException`` subclass tree (never from the
attacker-controlled stored name), and never rebuilt -- a name absent from it is rejected, not
imported. ``airflow.exceptions`` is imported by this module, so every built-in Airflow exception
is registered by the time this is first called; exceptions defined later are not added.
"""
return {
f"{cls.__module__}.{cls.__name__}": cls
for cls in (AirflowException, *_iter_subclasses(AirflowException))
}


def _resolve_airflow_exception(exc_cls_name: str) -> type[AirflowException]:
"""
Resolve a serialized ``AirflowException`` class name to the loaded class, without importing it.

The name is matched against the once-built ``AirflowException`` subclass map, so deserializing a
stored DAG never runs the top-level code of a module named in the blob. A name that is not a
registered ``AirflowException`` subclass -- e.g. an attacker's ``subprocess.check_output`` -- is
rejected rather than imported.
"""
exc_cls = _serializable_airflow_exceptions().get(exc_cls_name)
if exc_cls is None:
raise DeserializationError(f"Refusing to deserialize unknown exception class {exc_cls_name!r}")
return exc_cls


def _encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
"""Encode a StartTriggerArgs."""

Expand Down Expand Up @@ -658,9 +706,14 @@ def deserialize(cls, encoded_var: Any) -> Any:
kwargs = deser["kwargs"]
del deser
if type_ == DAT.AIRFLOW_EXC_SER:
exc_cls = import_string(exc_cls_name)
exc_cls: type[BaseException] = _resolve_airflow_exception(exc_cls_name)
else:
exc_cls = import_string(f"builtins.{exc_cls_name}")
builtin_exc_cls = _DESERIALIZABLE_BUILTIN_EXCEPTIONS.get(exc_cls_name)
if builtin_exc_cls is None:
raise DeserializationError(
f"Refusing to deserialize disallowed builtin exception {exc_cls_name!r}"
)
exc_cls = builtin_exc_cls
return exc_cls(*args, **kwargs)
elif type_ == DAT.SET:
return {cls.deserialize(v) for v in var}
Expand Down
58 changes: 58 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2790,6 +2790,64 @@ def test_create_dagrun_accepts_partition_key_for_partitioned_at_runtime_dag(self
dr = dag_maker.create_dagrun(partition_key="runtime-key")
assert dr.partition_key == "runtime-key"

def test_airflow_exc_deserialization_rejects_unknown_class(self):
"""An AIRFLOW_EXC_SER name that is not a loaded AirflowException subclass is rejected.

The name is resolved against the in-memory subclass tree, so an attacker-controlled
``subprocess.check_output`` is never imported.
"""
from airflow.exceptions import DeserializationError
from airflow.serialization.enums import DagAttributeTypes

encoded = BaseSerialization._encode(
BaseSerialization.serialize(
{"exc_cls_name": "subprocess.check_output", "args": [], "kwargs": {}}
),
type_=DagAttributeTypes.AIRFLOW_EXC_SER,
)
with pytest.raises(DeserializationError, match="Refusing to deserialize unknown exception class"):
BaseSerialization.deserialize(encoded)

def test_airflow_exc_deserialization_roundtrips_airflow_exception(self):
"""A genuine AirflowException subclass round-trips via the registry, without importing."""
from airflow.exceptions import AirflowException

result = BaseSerialization.deserialize(BaseSerialization.serialize(AirflowException("boom")))
assert isinstance(result, AirflowException)
assert result.args == ("boom",)

def test_base_exc_deserialization_rejects_non_allowlisted_builtin(self):
"""A BASE_EXC_SER name outside the {KeyError, AttributeError} the encoder emits is rejected."""
from airflow.exceptions import DeserializationError
from airflow.serialization.enums import DagAttributeTypes

# ``eval`` is the weaponisable case; ``ValueError`` is a harmless builtin the encoder
# never emits as BASE_EXC_SER -- both must be rejected.
for name in ("eval", "ValueError"):
encoded = BaseSerialization._encode(
BaseSerialization.serialize({"exc_cls_name": name, "args": ["1"], "kwargs": {}}),
type_=DagAttributeTypes.BASE_EXC_SER,
)
with pytest.raises(
DeserializationError, match="Refusing to deserialize disallowed builtin exception"
):
BaseSerialization.deserialize(encoded)

def test_base_exc_deserialization_roundtrips_builtin_exception(self):
"""The builtin exceptions the encoder emits (KeyError / AttributeError) still deserialize."""
from airflow.serialization.enums import DagAttributeTypes

for exc_type in (KeyError, AttributeError):
encoded = BaseSerialization._encode(
BaseSerialization.serialize(
{"exc_cls_name": exc_type.__name__, "args": ["boom"], "kwargs": {}}
),
type_=DagAttributeTypes.BASE_EXC_SER,
)
result = BaseSerialization.deserialize(encoded)
assert isinstance(result, exc_type)
assert result.args == ("boom",)


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
Loading