diff --git a/packages/opal-common/opal_common/async_utils.py b/packages/opal-common/opal_common/async_utils.py
index fbed27cf9..e682c2991 100644
--- a/packages/opal-common/opal_common/async_utils.py
+++ b/packages/opal-common/opal_common/async_utils.py
@@ -35,6 +35,38 @@ async def async_function():
     )
 
 
+async def run_sync_with_timeout(
+    func: Callable[P_args, T_result],
+    *args: P_args.args,
+    timeout: Optional[float] = None,
+    **kwargs: P_args.kwargs,
+) -> T_result:
+    """Like run_sync, but with an optional timeout.
+
+    Args:
+        func: Synchronous callable handed to run_sync.
+        *args: Positional arguments forwarded to func.
+        timeout: Seconds to wait for the result. None, 0 or a negative value
+            disables the timeout, making this behave identically to run_sync.
+        **kwargs: Keyword arguments forwarded to func. A keyword named
+            "timeout" is consumed by this wrapper and never reaches func.
+
+    Returns:
+        The value returned by func.
+
+    Raises:
+        asyncio.TimeoutError: If timeout > 0 and the call does not finish in
+            time. NOTE(review): wait_for only abandons the await; if run_sync
+            runs func on a worker thread (confirm against run_sync), that
+            thread keeps running until func returns on its own.
+    """
+    # A bare truthiness test would pass negative values to wait_for, which
+    # then times out immediately; only strictly positive values arm the timer.
+    if timeout is not None and timeout > 0:
+        return await asyncio.wait_for(run_sync(func, *args, **kwargs), timeout=timeout)
+    return await run_sync(func, *args, **kwargs)
+
+
 class TakeANumberQueue:
     """Enables a task to hold a place in queue prior to having the actual item
     to be sent over the queue.
diff --git a/packages/opal-common/opal_common/tests/test_async_utils.py b/packages/opal-common/opal_common/tests/test_async_utils.py
new file mode 100644
index 000000000..e1668df73
--- /dev/null
+++ b/packages/opal-common/opal_common/tests/test_async_utils.py
@@ -0,0 +1,69 @@
+import asyncio
+import threading
+
+import pytest
+from opal_common.async_utils import run_sync_with_timeout
+
+
+@pytest.mark.asyncio
+async def test_run_sync_with_timeout_completes_normally():
+    """A fast function should complete within the timeout."""
+
+    def fast_func(x, y):
+        return x + y
+
+    result = await run_sync_with_timeout(fast_func, 2, 3, timeout=5.0)
+    assert result == 5
+
+
+@pytest.mark.asyncio
+async def test_run_sync_with_timeout_raises_on_slow_function():
+    """A slow function should raise asyncio.TimeoutError when it exceeds the timeout."""
+
+    # The timeout only abandons the await; the blocking call itself keeps
+    # running (presumably on an executor thread). Gate it on an Event so it is
+    # released right after the assertion instead of blocking for 10 seconds.
+    release = threading.Event()
+
+    def slow_func():
+        release.wait(10)
+        return "done"
+
+    try:
+        with pytest.raises(asyncio.TimeoutError):
+            await run_sync_with_timeout(slow_func, timeout=0.3)
+    finally:
+        release.set()
+
+
+@pytest.mark.asyncio
+async def test_run_sync_with_timeout_no_timeout_when_zero():
+    """When timeout is 0, no timeout should be applied (behaves like run_sync)."""
+
+    def quick_func():
+        return 42
+
+    result = await run_sync_with_timeout(quick_func, timeout=0)
+    assert result == 42
+
+
+@pytest.mark.asyncio
+async def test_run_sync_with_timeout_no_timeout_when_none():
+    """When timeout is None, no timeout should be applied (behaves like run_sync)."""
+
+    def quick_func():
+        return "ok"
+
+    result = await run_sync_with_timeout(quick_func, timeout=None)
+    assert result == "ok"
+
+
+@pytest.mark.asyncio
+async def test_run_sync_with_timeout_propagates_exceptions():
+    """Exceptions from the wrapped function should propagate normally."""
+
+    def failing_func():
+        raise ValueError("test error")
+
+    with pytest.raises(ValueError, match="test error"):
+        await run_sync_with_timeout(failing_func, timeout=5.0)
diff --git a/packages/opal-server/opal_server/git_fetcher.py
b/packages/opal-server/opal_server/git_fetcher.py index 67e1016e9..dff0149d8 100644 --- a/packages/opal-server/opal_server/git_fetcher.py +++ b/packages/opal-server/opal_server/git_fetcher.py @@ -11,7 +11,7 @@ import pygit2 from ddtrace import tracer from git import Repo -from opal_common.async_utils import run_sync +from opal_common.async_utils import run_sync, run_sync_with_timeout from opal_common.git_utils.bundle_maker import BundleMaker from opal_common.logger import logger from opal_common.schemas.policy import PolicyBundle @@ -60,7 +60,7 @@ def create_local_branch_ref( if branch_name not in repo.branches.local: base_remote_branch = f"{remote_name}/{base_branch}" if repo.branches.remote.get(base_remote_branch) is not None: - (commit, _) = repo.resolve_refish(base_remote_branch) + commit, _ = repo.resolve_refish(base_remote_branch) else: raise RuntimeError("Base branch was not found on remote") logger.debug( @@ -91,7 +91,7 @@ def get_local_branch(repo: Repository, branch: str) -> Optional[pygit2.Reference @staticmethod def get_commit_hash(repo: Repository, branch: str, remote: str) -> Optional[str]: try: - (commit, _) = repo.resolve_refish(f"{remote}/{branch}") + commit, _ = repo.resolve_refish(f"{remote}/{branch}") return commit.hex except (pygit2.GitError, KeyError): return None @@ -109,7 +109,9 @@ def verify_found_repo_matches_remote( f"found target repo url is referred by remote: {remote.name}, url={remote.url}" ) return - error: str = f"Repo mismatch! No remote matches target url: {expected_remote_url}, found urls: {[remote.url for remote in repo.remotes]}" + error: str = ( + f"Repo mismatch! 
No remote matches target url: {expected_remote_url}, found urls: {[remote.url for remote in repo.remotes]}" + ) logger.error(error) raise ValueError(error) @@ -126,6 +128,7 @@ def __init__( source: GitPolicyScopeSource, callbacks=PolicyFetcherCallbacks(), remote_name: str = "origin", + clone_timeout: int = 0, ): super().__init__(callbacks) self._base_dir = GitPolicyFetcher.base_dir(base_dir) @@ -134,6 +137,7 @@ def __init__( self._repo_path = GitPolicyFetcher.repo_clone_path(base_dir, self._source) self._remote = remote_name self._scope_id = scope_id + self._clone_timeout = clone_timeout logger.debug( f"Initializing git fetcher: scope_id={scope_id}, url={source.url}, branch={self._source.branch}, path={GitPolicyFetcher.source_id(source)}" ) @@ -187,13 +191,22 @@ async def fetch_and_notify_on_changes( logger.debug( f"Fetching remote (force_fetch={force_fetch}): {self._remote} ({self._source.url})" ) - GitPolicyFetcher.repos_last_fetched[ - self.source_id - ] = datetime.datetime.now() - await run_sync( - repo.remotes[self._remote].fetch, - callbacks=self._auth_callbacks, + GitPolicyFetcher.repos_last_fetched[self.source_id] = ( + datetime.datetime.now() ) + try: + await run_sync_with_timeout( + repo.remotes[self._remote].fetch, + callbacks=self._auth_callbacks, + timeout=self._clone_timeout or None, + ) + except asyncio.TimeoutError: + logger.error( + "Fetch operation timed out after {timeout}s for {url}", + timeout=self._clone_timeout, + url=self._source.url, + ) + raise logger.debug(f"Fetch completed: {self._source.url}") # New commits might be present because of a previous fetch made by another scope @@ -222,14 +235,23 @@ async def _clone(self): path=self._repo_path, ) try: - repo: Repository = await run_sync( + repo: Repository = await run_sync_with_timeout( clone_repository, self._source.url, str(self._repo_path), callbacks=self._auth_callbacks, + timeout=self._clone_timeout or None, + ) + except asyncio.TimeoutError: + logger.error( + "Clone operation timed out 
after {timeout}s for {url}",
+                timeout=self._clone_timeout,
+                url=self._source.url,
             )
+            raise
         except pygit2.GitError:
             logger.exception(f"Could not clone repo at {self._source.url}")
+            raise
         else:
             logger.info(f"Clone completed: {self._source.url}")
             await self._notify_on_changes(repo)
diff --git a/packages/opal-server/opal_server/scopes/service.py b/packages/opal-server/opal_server/scopes/service.py
index d3df4972f..d9ceddc49 100644
--- a/packages/opal-server/opal_server/scopes/service.py
+++ b/packages/opal-server/opal_server/scopes/service.py
@@ -12,6 +12,7 @@
 from opal_common.schemas.policy import PolicyUpdateMessageNotification
 from opal_common.schemas.policy_source import GitPolicyScopeSource
 from opal_common.topics.publisher import ScopedServerSideTopicPublisher
+from opal_server.config import opal_server_config
 from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks
 from opal_server.policy.watcher.callbacks import (
     create_policy_update,
@@ -147,6 +148,7 @@ async def sync_scope(
             scope.scope_id,
             source,
             callbacks=callbacks,
+            clone_timeout=opal_server_config.POLICY_REPO_CLONE_TIMEOUT,
         )
 
         try:
diff --git a/packages/opal-server/opal_server/tests/git_fetcher_test.py b/packages/opal-server/opal_server/tests/git_fetcher_test.py
new file mode 100644
index 000000000..ad1a9f6a3
--- /dev/null
+++ b/packages/opal-server/opal_server/tests/git_fetcher_test.py
@@ -0,0 +1,180 @@
+import asyncio
+import threading
+import time
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pygit2
+import pytest
+from opal_common.schemas.policy_source import GitPolicyScopeSource, NoAuthData
+from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks
+
+
+def _make_source(url="https://example.com/repo.git", branch="main"):
+    """Create a minimal GitPolicyScopeSource for testing."""
+    return GitPolicyScopeSource(
+        source_type="GIT",
+        url=url,
+        branch=branch,
+        auth=NoAuthData(),
+        directories=["."],
+        extensions=[".rego", ".json"],
+        manifest=".manifest",
+        bundle_ignore=None,
+    )
+
+
+def _make_fetcher(source=None, clone_timeout=0, callbacks=None):
+    """Create a GitPolicyFetcher with mocked paths and config."""
+    if source is None:
+        source = _make_source()
+    if callbacks is None:
+        callbacks = PolicyFetcherCallbacks()
+    return GitPolicyFetcher(
+        base_dir=Path("/tmp/test_opal"),
+        scope_id="test-scope",
+        source=source,
+        callbacks=callbacks,
+        clone_timeout=clone_timeout,
+    )
+
+
+@pytest.fixture(autouse=True)
+def _clear_class_state():
+    """Clear GitPolicyFetcher class-level caches between tests."""
+    GitPolicyFetcher.repo_locks.clear()
+    GitPolicyFetcher.repos.clear()
+    GitPolicyFetcher.repos_last_fetched.clear()
+    yield
+    GitPolicyFetcher.repo_locks.clear()
+    GitPolicyFetcher.repos.clear()
+    GitPolicyFetcher.repos_last_fetched.clear()
+
+
+# ---------------------------------------------------------------------------
+# Clone tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@patch("opal_server.git_fetcher.GitCallback")
+@patch("opal_server.git_fetcher.clone_repository")
+async def test_clone_timeout_raises_on_slow_operation(mock_clone, mock_git_cb):
+    """When clone_repository blocks longer than the timeout, asyncio.TimeoutError propagates."""
+
+    # The timeout abandons the await but cannot interrupt the blocking call,
+    # so gate it on an Event and release it once the assertion is done rather
+    # than leaving a worker blocked for the full 10 seconds.
+    release = threading.Event()
+
+    def slow_clone(*args, **kwargs):
+        release.wait(10)
+
+    mock_clone.side_effect = slow_clone
+    mock_git_cb.return_value = MagicMock()
+
+    fetcher = _make_fetcher(clone_timeout=0.5)
+
+    try:
+        with pytest.raises(asyncio.TimeoutError):
+            await fetcher._clone()
+    finally:
+        release.set()
+
+
+@pytest.mark.asyncio
+@patch("opal_server.git_fetcher.GitCallback")
+@patch("opal_server.git_fetcher.clone_repository")
+async def test_clone_reraises_git_error(mock_clone, mock_git_cb):
+    """pygit2.GitError from clone_repository must NOT be swallowed; it should propagate."""
+    mock_clone.side_effect = pygit2.GitError("auth failed")
+    mock_git_cb.return_value = MagicMock()
+
+    fetcher = _make_fetcher()
+
+    with pytest.raises(pygit2.GitError, match="auth failed"):
+        await fetcher._clone()
+
+
+@pytest.mark.asyncio
+@patch("opal_server.git_fetcher.GitCallback")
+@patch("opal_server.git_fetcher.clone_repository")
+async def test_clone_success(mock_clone, mock_git_cb):
+    """A successful clone should call _notify_on_changes."""
+    mock_repo = MagicMock()
+    mock_clone.return_value = mock_repo
+    mock_git_cb.return_value = MagicMock()
+
+    fetcher = _make_fetcher()
+    fetcher._notify_on_changes = AsyncMock()
+
+    await fetcher._clone()
+
+    mock_clone.assert_called_once()
+    fetcher._notify_on_changes.assert_awaited_once_with(mock_repo)
+
+
+@pytest.mark.asyncio
+@patch("opal_server.git_fetcher.GitCallback")
+@patch("opal_server.git_fetcher.clone_repository")
+async def test_clone_no_timeout_when_zero(mock_clone, mock_git_cb):
+    """When clone_timeout=0, no asyncio.wait_for timeout is applied.
+
+    We verify this by having a function that takes a short time but would fail
+    with a zero-second timeout if one were mistakenly applied.
+    """
+
+    def quick_clone(*args, **kwargs):
+        time.sleep(0.1)
+        return MagicMock()
+
+    mock_clone.side_effect = quick_clone
+    mock_git_cb.return_value = MagicMock()
+
+    fetcher = _make_fetcher(clone_timeout=0)
+    fetcher._notify_on_changes = AsyncMock()
+
+    # Should complete without TimeoutError
+    await fetcher._clone()
+    mock_clone.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Fetch tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@patch("opal_server.git_fetcher.GitCallback")
+@patch("opal_server.git_fetcher.discover_repository", return_value=True)
+async def test_fetch_timeout_raises_on_slow_operation(mock_discover, mock_git_cb):
+    """When the fetch operation blocks longer than the timeout, asyncio.TimeoutError propagates."""
+    mock_git_cb.return_value = MagicMock()
+
+    source = _make_source()
+    fetcher = _make_fetcher(source=source, clone_timeout=0.5)
+
+    # Set up a mock repo whose fetch blocks until released (10s at most)
+    mock_remote = MagicMock()
+    release = threading.Event()
+
+    def slow_fetch(**kwargs):
+        release.wait(10)
+
+    mock_remote.fetch = slow_fetch
+
+    mock_repo = MagicMock()
+    mock_repo.remotes = {"origin": mock_remote}
+
+    # Patch _discover_repository and _get_valid_repo to enter the fetch path
+    fetcher._discover_repository = MagicMock(return_value=True)
+    fetcher._get_valid_repo = MagicMock(return_value=mock_repo)
+    fetcher._should_fetch = AsyncMock(return_value=True)
+    fetcher._notify_on_changes = AsyncMock()
+
+    try:
+        with pytest.raises(asyncio.TimeoutError):
+            await fetcher.fetch_and_notify_on_changes(force_fetch=True)
+    finally:
+        # Let the blocked fetch return so it does not outlive the test
+        release.set()