Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/opal-common/opal_common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ async def async_function():
)


async def run_sync_with_timeout(
    func: Callable[P_args, T_result],
    *args: P_args.args,
    timeout: Optional[float] = None,
    **kwargs: P_args.kwargs,
) -> T_result:
    """Like run_sync, but with an optional timeout.

    Args:
        func: The synchronous callable handed to ``run_sync``.
        *args: Positional arguments forwarded to ``func``.
        timeout: Maximum number of seconds to wait for the call. ``None``,
            ``0`` or a negative value means "no timeout", in which case this
            behaves identically to ``run_sync``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``run_sync(func, *args, **kwargs)`` returns.

    Raises:
        asyncio.TimeoutError: If ``timeout`` is positive and the call does
            not complete within that many seconds.

    Note:
        ``timeout`` is consumed by this wrapper and is never forwarded, so a
        ``func`` that takes its own ``timeout`` keyword cannot receive it
        through ``**kwargs``. Bind it beforehand (for example with
        ``functools.partial``) if the wrapped callable needs one.
    """
    # Only a strictly positive value enables the timeout; a negative value
    # would otherwise make asyncio.wait_for give up immediately.
    if timeout is None or timeout <= 0:
        return await run_sync(func, *args, **kwargs)
    # NOTE(review): wait_for cancels the awaitable on timeout; whether the
    # synchronous call itself stops depends on how run_sync runs it - confirm.
    return await asyncio.wait_for(run_sync(func, *args, **kwargs), timeout=timeout)


class TakeANumberQueue:
"""Enables a task to hold a place in queue prior to having the actual item
to be sent over the queue.
Expand Down
61 changes: 61 additions & 0 deletions packages/opal-common/opal_common/tests/test_async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio
import time

import pytest
from opal_common.async_utils import run_sync_with_timeout


@pytest.mark.asyncio
async def test_run_sync_with_timeout_completes_normally():
    """A quick callable finishes inside a generous timeout and yields its result."""

    def add(left, right):
        return left + right

    total = await run_sync_with_timeout(add, 2, 3, timeout=5.0)
    assert total == 5


@pytest.mark.asyncio
async def test_run_sync_with_timeout_raises_on_slow_function():
    """A slow function should raise asyncio.TimeoutError when it exceeds the timeout."""
    import threading

    # Block on an event rather than time.sleep(10): the worker can then be
    # released as soon as the assertion is done instead of lingering for the
    # full ten seconds (presumably on a worker thread - depends on run_sync).
    release = threading.Event()

    def slow_func():
        release.wait(10)
        return "done"

    try:
        with pytest.raises(asyncio.TimeoutError):
            await run_sync_with_timeout(slow_func, timeout=0.3)
    finally:
        # Always unblock the worker, even if the assertion above fails.
        release.set()


@pytest.mark.asyncio
async def test_run_sync_with_timeout_no_timeout_when_zero():
    """A timeout of 0 must still let the callable run and hand back its value."""
    value = await run_sync_with_timeout(lambda: 42, timeout=0)
    assert value == 42


@pytest.mark.asyncio
async def test_run_sync_with_timeout_no_timeout_when_none():
    """A timeout of None must still let the callable run and hand back its value."""
    value = await run_sync_with_timeout(lambda: "ok", timeout=None)
    assert value == "ok"


@pytest.mark.asyncio
async def test_run_sync_with_timeout_propagates_exceptions():
    """An error raised inside the wrapped callable reaches the awaiting caller."""

    def boom():
        raise ValueError("test error")

    with pytest.raises(ValueError, match="test error"):
        await run_sync_with_timeout(boom, timeout=5.0)
44 changes: 33 additions & 11 deletions packages/opal-server/opal_server/git_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pygit2
from ddtrace import tracer
from git import Repo
from opal_common.async_utils import run_sync
from opal_common.async_utils import run_sync, run_sync_with_timeout
from opal_common.git_utils.bundle_maker import BundleMaker
from opal_common.logger import logger
from opal_common.schemas.policy import PolicyBundle
Expand Down Expand Up @@ -60,7 +60,7 @@ def create_local_branch_ref(
if branch_name not in repo.branches.local:
base_remote_branch = f"{remote_name}/{base_branch}"
if repo.branches.remote.get(base_remote_branch) is not None:
(commit, _) = repo.resolve_refish(base_remote_branch)
commit, _ = repo.resolve_refish(base_remote_branch)
else:
raise RuntimeError("Base branch was not found on remote")
logger.debug(
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_local_branch(repo: Repository, branch: str) -> Optional[pygit2.Reference
@staticmethod
def get_commit_hash(repo: Repository, branch: str, remote: str) -> Optional[str]:
try:
(commit, _) = repo.resolve_refish(f"{remote}/{branch}")
commit, _ = repo.resolve_refish(f"{remote}/{branch}")
return commit.hex
except (pygit2.GitError, KeyError):
return None
Expand All @@ -109,7 +109,9 @@ def verify_found_repo_matches_remote(
f"found target repo url is referred by remote: {remote.name}, url={remote.url}"
)
return
error: str = f"Repo mismatch! No remote matches target url: {expected_remote_url}, found urls: {[remote.url for remote in repo.remotes]}"
error: str = (
f"Repo mismatch! No remote matches target url: {expected_remote_url}, found urls: {[remote.url for remote in repo.remotes]}"
)
logger.error(error)
raise ValueError(error)

Expand All @@ -126,6 +128,7 @@ def __init__(
source: GitPolicyScopeSource,
callbacks=PolicyFetcherCallbacks(),
remote_name: str = "origin",
clone_timeout: int = 0,
):
super().__init__(callbacks)
self._base_dir = GitPolicyFetcher.base_dir(base_dir)
Expand All @@ -134,6 +137,7 @@ def __init__(
self._repo_path = GitPolicyFetcher.repo_clone_path(base_dir, self._source)
self._remote = remote_name
self._scope_id = scope_id
self._clone_timeout = clone_timeout
logger.debug(
f"Initializing git fetcher: scope_id={scope_id}, url={source.url}, branch={self._source.branch}, path={GitPolicyFetcher.source_id(source)}"
)
Expand Down Expand Up @@ -187,13 +191,22 @@ async def fetch_and_notify_on_changes(
logger.debug(
f"Fetching remote (force_fetch={force_fetch}): {self._remote} ({self._source.url})"
)
GitPolicyFetcher.repos_last_fetched[
self.source_id
] = datetime.datetime.now()
await run_sync(
repo.remotes[self._remote].fetch,
callbacks=self._auth_callbacks,
GitPolicyFetcher.repos_last_fetched[self.source_id] = (
datetime.datetime.now()
)
try:
await run_sync_with_timeout(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The timeout should cover the entire sync procedure rather than being applied to fetch and clone individually; applying it to each step separately effectively doubles the actual timeout.

repo.remotes[self._remote].fetch,
callbacks=self._auth_callbacks,
timeout=self._clone_timeout or None,
)
except asyncio.TimeoutError:
logger.error(
"Fetch operation timed out after {timeout}s for {url}",
timeout=self._clone_timeout,
url=self._source.url,
)
raise
logger.debug(f"Fetch completed: {self._source.url}")

# New commits might be present because of a previous fetch made by another scope
Expand Down Expand Up @@ -222,14 +235,23 @@ async def _clone(self):
path=self._repo_path,
)
try:
repo: Repository = await run_sync(
repo: Repository = await run_sync_with_timeout(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to make sure that in case of a timeout, we actually terminate the Git subprocess

clone_repository,
self._source.url,
str(self._repo_path),
callbacks=self._auth_callbacks,
timeout=self._clone_timeout or None,
)
except asyncio.TimeoutError:
logger.error(
"Clone operation timed out after {timeout}s for {url}",
timeout=self._clone_timeout,
url=self._source.url,
)
raise
except pygit2.GitError:
logger.exception(f"Could not clone repo at {self._source.url}")
raise
else:
logger.info(f"Clone completed: {self._source.url}")
await self._notify_on_changes(repo)
Expand Down
2 changes: 2 additions & 0 deletions packages/opal-server/opal_server/scopes/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from opal_common.schemas.policy import PolicyUpdateMessageNotification
from opal_common.schemas.policy_source import GitPolicyScopeSource
from opal_common.topics.publisher import ScopedServerSideTopicPublisher
from opal_server.config import opal_server_config
from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks
from opal_server.policy.watcher.callbacks import (
create_policy_update,
Expand Down Expand Up @@ -147,6 +148,7 @@ async def sync_scope(
scope.scope_id,
source,
callbacks=callbacks,
clone_timeout=opal_server_config.POLICY_REPO_CLONE_TIMEOUT,
)

try:
Expand Down
166 changes: 166 additions & 0 deletions packages/opal-server/opal_server/tests/git_fetcher_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import asyncio
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pygit2
import pytest
from opal_common.schemas.policy_source import GitPolicyScopeSource, NoAuthData
from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks


def _make_source(url="https://example.com/repo.git", branch="main"):
    """Build a minimal GitPolicyScopeSource for the given url and branch."""
    # Collect the field values first so the constructor call stays short.
    fields = {
        "source_type": "GIT",
        "url": url,
        "branch": branch,
        "auth": NoAuthData(),
        "directories": ["."],
        "extensions": [".rego", ".json"],
        "manifest": ".manifest",
        "bundle_ignore": None,
    }
    return GitPolicyScopeSource(**fields)


def _make_fetcher(source=None, clone_timeout=0, callbacks=None):
    """Build a GitPolicyFetcher for scope "test-scope" under /tmp/test_opal.

    A default source and default callbacks are created when the caller does
    not supply them.
    """
    effective_source = _make_source() if source is None else source
    effective_callbacks = PolicyFetcherCallbacks() if callbacks is None else callbacks
    return GitPolicyFetcher(
        base_dir=Path("/tmp/test_opal"),
        scope_id="test-scope",
        source=effective_source,
        callbacks=effective_callbacks,
        clone_timeout=clone_timeout,
    )


@pytest.fixture(autouse=True)
def _clear_class_state():
    """Empty GitPolicyFetcher's class-level caches before and after each test."""

    def _reset():
        # Same three containers, in the same order, on both sides of the yield.
        GitPolicyFetcher.repo_locks.clear()
        GitPolicyFetcher.repos.clear()
        GitPolicyFetcher.repos_last_fetched.clear()

    _reset()
    yield
    _reset()


# ---------------------------------------------------------------------------
# Clone tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("opal_server.git_fetcher.GitCallback")
@patch("opal_server.git_fetcher.clone_repository")
async def test_clone_timeout_raises_on_slow_operation(mock_clone, mock_git_cb):
    """When clone_repository blocks longer than the timeout, asyncio.TimeoutError propagates."""
    import threading

    # Block on an event rather than time.sleep(10) so the fake clone can be
    # released right after the assertion instead of lingering for ten seconds.
    release = threading.Event()

    def slow_clone(*args, **kwargs):
        release.wait(10)

    mock_clone.side_effect = slow_clone
    mock_git_cb.return_value = MagicMock()

    fetcher = _make_fetcher(clone_timeout=0.5)

    try:
        with pytest.raises(asyncio.TimeoutError):
            await fetcher._clone()
    finally:
        # Always unblock the fake clone, even if the assertion above fails.
        release.set()


@pytest.mark.asyncio
@patch("opal_server.git_fetcher.GitCallback")
@patch("opal_server.git_fetcher.clone_repository")
async def test_clone_reraises_git_error(mock_clone, mock_git_cb):
    """A pygit2.GitError raised by clone_repository must surface from _clone()."""
    mock_git_cb.return_value = MagicMock()
    mock_clone.side_effect = pygit2.GitError("auth failed")

    fetcher_under_test = _make_fetcher()

    with pytest.raises(pygit2.GitError, match="auth failed"):
        await fetcher_under_test._clone()


@pytest.mark.asyncio
@patch("opal_server.git_fetcher.GitCallback")
@patch("opal_server.git_fetcher.clone_repository")
async def test_clone_success(mock_clone, mock_git_cb):
    """After a successful clone, _notify_on_changes is awaited with the cloned repo."""
    cloned_repo = MagicMock()
    mock_git_cb.return_value = MagicMock()
    mock_clone.return_value = cloned_repo

    fetcher_under_test = _make_fetcher()
    notify_spy = AsyncMock()
    fetcher_under_test._notify_on_changes = notify_spy

    await fetcher_under_test._clone()

    mock_clone.assert_called_once()
    notify_spy.assert_awaited_once_with(cloned_repo)


@pytest.mark.asyncio
@patch("opal_server.git_fetcher.GitCallback")
@patch("opal_server.git_fetcher.clone_repository")
async def test_clone_no_timeout_when_zero(mock_clone, mock_git_cb):
    """With clone_timeout=0 the clone runs to completion without a TimeoutError.

    The fake clone takes a short but non-zero time, so it would not survive a
    zero-second timeout if one were applied by mistake.
    """

    def brief_clone(*args, **kwargs):
        time.sleep(0.1)
        return MagicMock()

    mock_git_cb.return_value = MagicMock()
    mock_clone.side_effect = brief_clone

    fetcher_under_test = _make_fetcher(clone_timeout=0)
    fetcher_under_test._notify_on_changes = AsyncMock()

    # Must finish without raising asyncio.TimeoutError.
    await fetcher_under_test._clone()
    mock_clone.assert_called_once()


# ---------------------------------------------------------------------------
# Fetch tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("opal_server.git_fetcher.GitCallback")
@patch("opal_server.git_fetcher.discover_repository", return_value=True)
async def test_fetch_timeout_raises_on_slow_operation(mock_discover, mock_git_cb):
    """When the fetch operation blocks longer than the timeout, asyncio.TimeoutError propagates."""
    import threading

    mock_git_cb.return_value = MagicMock()

    source = _make_source()
    fetcher = _make_fetcher(source=source, clone_timeout=0.5)

    # Block on an event rather than time.sleep(10) so the fake fetch can be
    # released right after the assertion instead of lingering for ten seconds.
    release = threading.Event()

    def slow_fetch(**kwargs):
        release.wait(10)

    # Set up a mock repo whose "origin" remote has the slow fetch.
    mock_remote = MagicMock()
    mock_remote.fetch = slow_fetch

    mock_repo = MagicMock()
    mock_repo.remotes = {"origin": mock_remote}

    # Stub the fetcher's helpers so the call takes the fetch path.
    fetcher._discover_repository = MagicMock(return_value=True)
    fetcher._get_valid_repo = MagicMock(return_value=mock_repo)
    fetcher._should_fetch = AsyncMock(return_value=True)
    fetcher._notify_on_changes = AsyncMock()

    try:
        with pytest.raises(asyncio.TimeoutError):
            await fetcher.fetch_and_notify_on_changes(force_fetch=True)
    finally:
        # Always unblock the fake fetch, even if the assertion above fails.
        release.set()
Loading