diff --git a/providers/common/io/docs/index.rst b/providers/common/io/docs/index.rst index 909d83bd41df5..6b3aae5bf744f 100644 --- a/providers/common/io/docs/index.rst +++ b/providers/common/io/docs/index.rst @@ -38,6 +38,7 @@ Transferring a file Operators Object Storage XCom Backend + Object Storage State Store Backend .. toctree:: :hidden: diff --git a/providers/common/io/docs/state_store_backend.rst b/providers/common/io/docs/state_store_backend.rst new file mode 100644 index 0000000000000..b61134ef95db8 --- /dev/null +++ b/providers/common/io/docs/state_store_backend.rst @@ -0,0 +1,81 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Object Storage State Store Backend +=================================== + +The default state store backend is :class:`~airflow.state.metastore.MetastoreStateBackend`, which persists +task and asset state in the Airflow metadata database via the API Server's Execution API. For larger values, +you may want to store state on object storage directly from the task instead. + +To enable object storage for task and asset state store, set ``state_store_backend`` in the ``[workers]`` +section to ``airflow.providers.common.io.state_store.backend.StateStoreObjectStorageBackend``, and set +``state_store_objectstorage_path`` to the desired base location. The connection id is obtained from the +user part of the URL, e.g. ``state_store_objectstorage_path = s3://conn_id@mybucket/task-state/``. + +Task state is stored under ``////`` and asset state under +``assets//`` beneath the configured base path. + +By default (``state_store_objectstorage_threshold = 0``) all serialized values are offloaded to object storage. +Set ``state_store_objectstorage_threshold`` to a positive number of bytes to only offload values whose +serialized size meets or exceeds the threshold, anything smaller are stored in the Airflow metadata database. + +Optionally set ``state_store_objectstorage_compression`` to an fsspec-supported compression algorithm such as +``gzip`` or ``snappy`` to compress values before writing. + +The following example stores all task and asset state in S3, compressed with gzip:: + + [workers] + state_store_backend = airflow.providers.common.io.state_store.backend.StateStoreObjectStorageBackend + + [common.io] + state_store_objectstorage_path = s3://conn_id@mybucket/task-state/ + state_store_objectstorage_compression = gzip + +To only offload values larger than 1 MB:: + + [workers] + state_store_backend = airflow.providers.common.io.state_store.backend.StateStoreObjectStorageBackend + + [common.io] + state_store_objectstorage_path = s3://conn_id@mybucket/task-state/ + state_store_objectstorage_threshold = 1048576 + +Using the local filesystem (useful for development):: + + [workers] + state_store_backend = airflow.providers.common.io.state_store.backend.StateStoreObjectStorageBackend + + [common.io] + state_store_objectstorage_path = file:///var/airflow/task-state/ + +.. note:: + + Compression requires the relevant library to be installed in your Python environment. + For example, ``snappy`` requires ``python-snappy``. Gzip and bz2 work out of the box. + +.. note:: + + ``expires_at`` is not enforced by this backend. Values written to object storage persist + indefinitely until explicitly deleted. Use your object storage provider's lifecycle policies + (e.g. S3 lifecycle rules, GCS object lifecycle) to automatically expire old state. + +.. note:: + + Task state paths are keyed on ``(dag_id, run_id, task_id, map_index)`` and are stable across + task retries. This makes this backend suitable for operators that use + :class:`~airflow.sdk.ResumableJobMixin` to reconnect to external jobs after a retry. diff --git a/providers/common/io/provider.yaml b/providers/common/io/provider.yaml index 7329ccde6602b..8831783d665e0 100644 --- a/providers/common/io/provider.yaml +++ b/providers/common/io/provider.yaml @@ -115,3 +115,31 @@ config: type: string example: "gz" default: "" + state_store_objectstorage_path: + description: | + Base path on object storage for the task/asset state store backend, in URL format. + When set, StateStoreObjectStorageBackend will persist task and asset state under this + prefix, organised as //// for tasks and + assets// for assets. + version_added: 1.8.0 + type: string + example: "s3://conn_id@bucket/task-state/" + default: "" + state_store_objectstorage_threshold: + description: | + Threshold in bytes for offloading serialized state store values to object storage. 0 means + always offload to object storage. Any positive number means values will be offloaded + only when their serialized size is at least that many bytes. Must be non-negative. + version_added: 1.8.0 + type: integer + example: "1000000" + default: "0" + state_store_objectstorage_compression: + description: | + Compression algorithm to use when writing task/asset state store values to object storage. + Supported algorithms are a.o.: gzip, bz2, lzma, and xz. If not specified, + no compression will be used. The same algorithm must be available on all workers. + version_added: 1.8.0 + type: string + example: "gzip" + default: "" diff --git a/providers/common/io/src/airflow/providers/common/io/get_provider_info.py b/providers/common/io/src/airflow/providers/common/io/get_provider_info.py index a08546d9ed5fa..f63cab0b8c1c2 100644 --- a/providers/common/io/src/airflow/providers/common/io/get_provider_info.py +++ b/providers/common/io/src/airflow/providers/common/io/get_provider_info.py @@ -84,6 +84,27 @@ def get_provider_info(): "example": "gz", "default": "", }, + "state_store_objectstorage_path": { + "description": "Base path on object storage for the task/asset state store backend, in URL format.\nWhen set, StateStoreObjectStorageBackend will persist task and asset state under this\nprefix, organised as //// for tasks and\nassets// for assets.\n", + "version_added": "1.8.0", + "type": "string", + "example": "s3://conn_id@bucket/task-state/", + "default": "", + }, + "state_store_objectstorage_threshold": { + "description": "Threshold in bytes for offloading serialized state store values to object storage. 0 means\nalways offload to object storage. Any positive number means values will be offloaded\nonly when their serialized size is at least that many bytes. Must be non-negative.\n", + "version_added": "1.8.0", + "type": "integer", + "example": "1000000", + "default": "0", + }, + "state_store_objectstorage_compression": { + "description": "Compression algorithm to use when writing task/asset state store values to object storage.\nSupported algorithms are a.o.: gzip, bz2, lzma, and xz. If not specified,\nno compression will be used. The same algorithm must be available on all workers.\n", + "version_added": "1.8.0", + "type": "string", + "example": "gzip", + "default": "", + }, }, } }, diff --git a/providers/common/io/src/airflow/providers/common/io/state_store/__init__.py b/providers/common/io/src/airflow/providers/common/io/state_store/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/common/io/src/airflow/providers/common/io/state_store/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/common/io/src/airflow/providers/common/io/state_store/backend.py b/providers/common/io/src/airflow/providers/common/io/state_store/backend.py new file mode 100644 index 0000000000000..dd9ad29cafe2d --- /dev/null +++ b/providers/common/io/src/airflow/providers/common/io/state_store/backend.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from functools import cache +from typing import TYPE_CHECKING +from urllib.parse import quote, urlsplit + +import fsspec.utils + +from airflow.providers.common.compat.sdk import conf + +if TYPE_CHECKING: + from datetime import datetime + + from pydantic import JsonValue + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + + +from airflow.sdk import ObjectStoragePath +from airflow.sdk._shared.state import AssetScope, BaseStoreBackend, StoreScope, TaskScope + +SECTION = "common.io" + + +@cache +def _get_base_path() -> ObjectStoragePath: + return ObjectStoragePath(conf.get_mandatory_value(SECTION, "state_store_objectstorage_path")) + + +@cache +def _get_compression() -> str | None: + value = conf.get(SECTION, "state_store_objectstorage_compression", fallback=None) + return value or None + + +@cache +def _get_threshold() -> int: + value = conf.getint(SECTION, "state_store_objectstorage_threshold", fallback=0) + if value < 0: + raise ValueError( + f"[{SECTION}] state_store_objectstorage_threshold must be non-negative, got {value}." + ) + return value + + +def _get_compression_suffix() -> str: + compression = _get_compression() + if not compression: + return "" + for suffix, c in fsspec.utils.compressions.items(): + if c == compression: + return f".{suffix}" + raise ValueError(f"Compression {compression!r} is not supported.") + + +def _sanitise_segment(value: str) -> str: + if not value or value in (".", ".."): + raise ValueError(f"Invalid path segment: {value!r}") + return quote(value, safe="") + + +def _build_task_path(scope: TaskScope, key: str) -> ObjectStoragePath: + suffix = _get_compression_suffix() + return ( + _get_base_path() + / _sanitise_segment(scope.dag_id) + / _sanitise_segment(scope.run_id) + / _sanitise_segment(scope.task_id) + / str(scope.map_index) + / f"{_sanitise_segment(key)}{suffix}" + ) + + +def _build_asset_path(scope: AssetScope, key: str) -> ObjectStoragePath: + suffix = _get_compression_suffix() + asset_identifier = _sanitise_segment(scope.name or scope.uri or str(scope.asset_id)) + return _get_base_path() / "assets" / asset_identifier / f"{_sanitise_segment(key)}{suffix}" + + +def _write_to_object_storage(path: ObjectStoragePath, value: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + compression = _get_compression() + with path.open(mode="wb", compression=compression) as f: + f.write(value.encode("utf-8")) + + +def _read_from_object_storage(path: ObjectStoragePath) -> str | None: + try: + with path.open(mode="rb", compression="infer") as f: + return f.read().decode("utf-8") + except FileNotFoundError: + return None + + +def _is_storage_ref(value: str) -> bool: + try: + if not urlsplit(value).scheme: + return False + return ObjectStoragePath(value).is_relative_to(_get_base_path()) + except Exception: + return False + + +def _scope_path(scope: StoreScope, key: str) -> ObjectStoragePath: + match scope: + case TaskScope(): + return _build_task_path(scope, key) + case AssetScope(): + return _build_asset_path(scope, key) + case _: + raise TypeError(f"Unknown scope type: {type(scope)}") + + +class StateStoreObjectStorageBackend(BaseStoreBackend): + """ + Object-storage backend for task and asset store. + + Config keys (all under ``[common.io]``): + + - ``state_store_objectstorage_path``: base path, e.g. ``s3://conn_id@bucket/task-state/`` + - ``state_store_objectstorage_compression``: optional compression, e.g. ``gzip`` + """ + + def get(self, scope: StoreScope, key: str, *, session: Session | None = None) -> str | None: + return _read_from_object_storage(_scope_path(scope, key)) + + def set( + self, + scope: StoreScope, + key: str, + value: str, + *, + expires_at: datetime | None = None, + session: Session | None = None, + ) -> None: + _write_to_object_storage(_scope_path(scope, key), value) + + def delete(self, scope: StoreScope, key: str, *, session: Session | None = None) -> None: + _scope_path(scope, key).unlink(missing_ok=True) + + def clear( + self, scope: StoreScope, *, all_map_indices: bool = False, session: Session | None = None + ) -> None: + match scope: + case TaskScope(): + if all_map_indices: + prefix = ( + _get_base_path() + / _sanitise_segment(scope.dag_id) + / _sanitise_segment(scope.run_id) + / _sanitise_segment(scope.task_id) + ) + for p in prefix.glob("*/*"): + p.unlink(missing_ok=True) + else: + prefix = ( + _get_base_path() + / _sanitise_segment(scope.dag_id) + / _sanitise_segment(scope.run_id) + / _sanitise_segment(scope.task_id) + / str(scope.map_index) + ) + for p in prefix.glob("*"): + p.unlink(missing_ok=True) + case AssetScope(): + asset_identifier = _sanitise_segment(scope.name or scope.uri or str(scope.asset_id)) + prefix = _get_base_path() / "assets" / asset_identifier + for p in prefix.glob("*"): + p.unlink(missing_ok=True) + case _: + raise TypeError(f"Unknown scope type: {type(scope)}") + + async def aget(self, scope: StoreScope, key: str, *, session: AsyncSession | None = None) -> str | None: + raise NotImplementedError + + async def aset( + self, + scope: StoreScope, + key: str, + value: str, + *, + expires_at: datetime | None = None, + session: AsyncSession | None = None, + ) -> None: + raise NotImplementedError + + async def adelete(self, scope: StoreScope, key: str, *, session: AsyncSession | None = None) -> None: + raise NotImplementedError + + async def aclear( + self, scope: StoreScope, *, all_map_indices: bool = False, session: AsyncSession | None = None + ) -> None: + raise NotImplementedError + + def serialize_task_state_store_to_ref(self, *, value: JsonValue, key: str, scope: TaskScope) -> str: + serialized = json.dumps(value) + if len(serialized.encode()) < _get_threshold(): + return serialized + path = _build_task_path(scope, key) + _write_to_object_storage(path, serialized) + return str(path) + + def deserialize_task_state_store_from_ref(self, stored: str) -> JsonValue: + if not stored: + return None + if _is_storage_ref(stored): + data = _read_from_object_storage(ObjectStoragePath(stored)) + if data is not None: + return json.loads(data) + return None + return json.loads(stored) + + def serialize_asset_state_store_to_ref(self, *, value: JsonValue, key: str, scope: AssetScope) -> str: + serialized = json.dumps(value) + if len(serialized.encode()) < _get_threshold(): + return serialized + path = _build_asset_path(scope, key) + _write_to_object_storage(path, serialized) + return str(path) + + def deserialize_asset_state_store_from_ref(self, stored: str) -> JsonValue: + if not stored: + return None + if _is_storage_ref(stored): + data = _read_from_object_storage(ObjectStoragePath(stored)) + if data is not None: + return json.loads(data) + return None + return json.loads(stored) diff --git a/providers/common/io/tests/unit/common/io/state_store/__init__.py b/providers/common/io/tests/unit/common/io/state_store/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/common/io/tests/unit/common/io/state_store/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/common/io/tests/unit/common/io/state_store/test_backend.py b/providers/common/io/tests/unit/common/io/state_store/test_backend.py new file mode 100644 index 0000000000000..6e7c634b85fa2 --- /dev/null +++ b/providers/common/io/tests/unit/common/io/state_store/test_backend.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Store backend requires Airflow >= 3.3", allow_module_level=True) + +from airflow.providers.common.io.state_store import backend +from airflow.providers.common.io.state_store.backend import ( + StateStoreObjectStorageBackend, + _build_asset_path, + _build_task_path, + _read_from_object_storage, + _write_to_object_storage, +) +from airflow.sdk import ObjectStoragePath +from airflow.sdk.state import AssetScope, TaskScope + +from tests_common.test_utils.config import conf_vars + + +@pytest.fixture(autouse=True) +def clear_caches(): + backend._get_base_path.cache_clear() + backend._get_compression.cache_clear() + backend._get_threshold.cache_clear() + yield + backend._get_base_path.cache_clear() + backend._get_compression.cache_clear() + backend._get_threshold.cache_clear() + + +@pytest.fixture +def base_path(tmp_path): + store_path = tmp_path / "store" + store_path.mkdir() + return f"file://{store_path}" + + +@pytest.fixture +def conf_overrides(base_path): + from tests_common.test_utils.config import conf_vars + + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_compression"): "", + } + ): + yield base_path + + +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task state store requires Airflow >= 3.3") +class TestPathBuilders: + def test_build_task_path_segments(self, conf_overrides): + scope = TaskScope(dag_id="my_dag", run_id="run_1", task_id="my_task", map_index=-1) + path = _build_task_path(scope, "job_id") + assert str(path).endswith("my_dag/run_1/my_task/-1/job_id") + + def test_build_task_path_encodes_special_chars(self, conf_overrides): + scope = TaskScope(dag_id="a/b", run_id="r/1", task_id="t/x", map_index=0) + path = _build_task_path(scope, "key/name") + assert str(path).endswith("a%2Fb/r%2F1/t%2Fx/0/key%2Fname") + assert "a/b" not in str(path) + assert "key/name" not in str(path) + + def test_build_task_path_distinct_keys_dont_collide(self, conf_overrides): + scope = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=0) + path_slash = _build_task_path(scope, "connector/offset") + path_under = _build_task_path(scope, "connector_offset") + assert str(path_slash) != str(path_under) + + def test_build_task_path_rejects_empty_and_dots(self, conf_overrides): + scope = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=0) + with pytest.raises(ValueError, match="Invalid path segment"): + _build_task_path(scope, "") + with pytest.raises(ValueError, match="Invalid path segment"): + _build_task_path(scope, "..") + + def test_build_asset_path_segments(self, conf_overrides): + scope = AssetScope(name="my_asset") + path = _build_asset_path(scope, "status") + assert str(path).endswith("assets/my_asset/status") + + def test_build_asset_path_uses_uri_when_no_name(self, conf_overrides): + scope = AssetScope(uri="s3://bucket/path") + path_slash = _build_asset_path(scope, "key") + scope2 = AssetScope(uri="s3://bucket_path") + path_under = _build_asset_path(scope2, "key") + assert str(path_slash) != str(path_under) + + def test_compression_suffix_appended(self, tmp_path): + store_path = tmp_path / "store" + store_path.mkdir() + + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): f"file://{store_path}", + ("common.io", "state_store_objectstorage_compression"): "gzip", + } + ): + backend._get_base_path.cache_clear() + backend._get_compression.cache_clear() + scope = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=-1) + path = _build_task_path(scope, "k") + assert str(path).endswith(".gz") + + +class TestIOPrimitives: + def test_write_and_read_roundtrip(self, conf_overrides): + path = ObjectStoragePath(f"{conf_overrides}/test_key") + _write_to_object_storage(path, '{"value": 42}') + result = _read_from_object_storage(path) + assert result == '{"value": 42}' + + def test_read_missing_returns_none(self, conf_overrides): + path = ObjectStoragePath(f"{conf_overrides}/nonexistent") + assert _read_from_object_storage(path) is None + + def test_write_creates_parent_dirs(self, conf_overrides): + path = ObjectStoragePath(f"{conf_overrides}/a/b/c/key") + _write_to_object_storage(path, "hello") + assert path.exists() + + +class TestStateStoreObjectStorageBackend: + @pytest.fixture + def store(self, conf_overrides): + return StateStoreObjectStorageBackend() + + @pytest.fixture + def task_scope(self): + return TaskScope(dag_id="my_dag", run_id="run_1", task_id="my_task", map_index=-1) + + @pytest.fixture + def asset_scope(self): + return AssetScope(name="my_asset") + + def test_set_and_get_task(self, store, task_scope): + store.set(task_scope, "k", "hello") + assert store.get(task_scope, "k") == "hello" + + def test_get_missing_returns_none(self, store, task_scope): + assert store.get(task_scope, "missing") is None + + def test_delete_task(self, store, task_scope): + store.set(task_scope, "k", "v") + store.delete(task_scope, "k") + assert store.get(task_scope, "k") is None + + def test_delete_missing_is_noop(self, store, task_scope): + store.delete(task_scope, "does_not_exist") + + def test_clear_task_single_map_index(self, store, task_scope): + store.set(task_scope, "k1", "v1") + store.set(task_scope, "k2", "v2") + store.clear(task_scope) + assert store.get(task_scope, "k1") is None + assert store.get(task_scope, "k2") is None + + def test_clear_task_all_map_indices(self, store): + scope0 = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=0) + scope1 = TaskScope(dag_id="d", run_id="r", task_id="t", map_index=1) + store.set(scope0, "k", "v0") + store.set(scope1, "k", "v1") + store.clear(scope0, all_map_indices=True) + assert store.get(scope0, "k") is None + assert store.get(scope1, "k") is None + + def test_set_and_get_asset(self, store, asset_scope): + store.set(asset_scope, "status", "ok") + assert store.get(asset_scope, "status") == "ok" + + def test_clear_asset(self, store, asset_scope): + store.set(asset_scope, "k1", "v1") + store.set(asset_scope, "k2", "v2") + store.clear(asset_scope) + assert store.get(asset_scope, "k1") is None + assert store.get(asset_scope, "k2") is None + + def test_serialize_and_deserialize_task(self, store, task_scope): + ref = store.serialize_task_state_store_to_ref(value={"x": 1}, key="job_id", scope=task_scope) + assert ref.startswith("file://") + result = store.deserialize_task_state_store_from_ref(ref) + assert result == {"x": 1} + + def test_serialize_and_deserialize_asset(self, store, asset_scope): + ref = store.serialize_asset_state_store_to_ref(value=[1, 2, 3], key="result", scope=asset_scope) + assert ref.startswith("file://") + result = store.deserialize_asset_state_store_from_ref(ref) + assert result == [1, 2, 3] + + def test_deserialize_missing_ref_returns_none(self, store, conf_overrides): + result = store.deserialize_task_state_store_from_ref(f"{conf_overrides}/no/such/path") + assert result is None + + def test_task_serialize_offloads_to_storage(self, task_scope, base_path): + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_threshold"): "0", + } + ): + backend._get_threshold.cache_clear() + store = StateStoreObjectStorageBackend() + ref = store.serialize_task_state_store_to_ref(value={"x": 1}, key="k", scope=task_scope) + assert ref.startswith("file://") + + def test_asset_serialize_offloads_to_storage(self, asset_scope, base_path): + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_threshold"): "0", + } + ): + backend._get_threshold.cache_clear() + store = StateStoreObjectStorageBackend() + ref = store.serialize_asset_state_store_to_ref(value={"x": 1}, key="k", scope=asset_scope) + assert ref.startswith("file://") + + def test_task_serialize_to_db_when_below_threshold(self, task_scope, base_path): + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_threshold"): "10000", + } + ): + backend._get_threshold.cache_clear() + store = StateStoreObjectStorageBackend() + ref = store.serialize_task_state_store_to_ref(value={"x": 1}, key="k", scope=task_scope) + assert not ref.startswith("file://") + assert store.deserialize_task_state_store_from_ref(ref) == {"x": 1} + + def test_asset_serialize_to_db_when_below_threshold(self, asset_scope, base_path): + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_threshold"): "10000", + } + ): + backend._get_threshold.cache_clear() + store = StateStoreObjectStorageBackend() + ref = store.serialize_asset_state_store_to_ref(value={"x": 1}, key="k", scope=asset_scope) + assert not ref.startswith("file://") + assert store.deserialize_asset_state_store_from_ref(ref) == {"x": 1} + + def test_negative_threshold_raises(self, base_path): + with conf_vars( + { + ("common.io", "state_store_objectstorage_path"): base_path, + ("common.io", "state_store_objectstorage_threshold"): "-1", + } + ): + backend._get_threshold.cache_clear() + with pytest.raises(ValueError, match="must be non-negative"): + backend._get_threshold()