diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 35e682366..90db37908 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -304,20 +304,44 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str: try: # Open the tar.gz file with tarfile.open(targz_path, "r:gz") as tar: + cache_path = Path(cache_dir).resolve() + members = list(tar.getmembers()) + for member in members: + cls._validate_tar_member(member, cache_path) + # Extract all files into the cache directory - tar.extractall( - path=cache_dir, - ) - except tarfile.TarError as e: + try: + tar.extractall(path=cache_dir, members=members, filter="data") + except TypeError: # pragma: no cover - Python < 3.12 + tar.extractall(path=cache_dir, members=members) + except (tarfile.TarError, ValueError) as e: # If any error occurs while opening or extracting the tar.gz file, # delete the cache directory (if it was created in this function) # and raise the error again - if "tmp" in cache_dir: + if "tmp" in cache_dir and os.path.exists(cache_dir): shutil.rmtree(cache_dir) raise ValueError(f"An error occurred while decompressing {targz_path}: {e}") return cache_dir + @staticmethod + def _validate_tar_member(member: tarfile.TarInfo, cache_path: Path) -> None: + target_path = (cache_path / member.name).resolve() + if not target_path.is_relative_to(cache_path): + raise ValueError(f"Unsafe tar member path: {member.name}") + + if member.issym() or member.islnk(): + link_name = Path(member.linkname) + if member.issym(): + link_target = ( + link_name if link_name.is_absolute() else target_path.parent / link_name + ) + else: + link_target = link_name if link_name.is_absolute() else cache_path / link_name + + if not link_target.resolve().is_relative_to(cache_path): + raise ValueError(f"Unsafe tar link target: {member.name} -> {member.linkname}") + @classmethod def retrieve_model_gcs( cls, diff --git a/tests/test_model_management.py 
"""Tests for tar.gz extraction through ``ModelManagement.decompress_to_cache``."""

import io
import tarfile

import pytest

from fastembed.common.model_management import ModelManagement


def _add_file(tar: tarfile.TarFile, name: str, data: bytes) -> None:
    """Append a regular file called ``name`` containing ``data`` to the open archive."""
    entry = tarfile.TarInfo(name=name)
    entry.size = len(data)
    tar.addfile(entry, fileobj=io.BytesIO(data))


def _prepare_paths(tmp_path):
    """Create the cache directory under ``tmp_path`` and return ``(archive, cache)`` paths."""
    cache = tmp_path / "cache"
    cache.mkdir()
    return tmp_path / "model.tar.gz", cache


def test_decompress_to_cache_extracts_safe_members(tmp_path):
    """A well-formed archive is unpacked into the cache and the cache dir is returned."""
    archive, cache = _prepare_paths(tmp_path)

    with tarfile.open(archive, "w:gz") as bundle:
        _add_file(bundle, "model/config.json", b"{}")

    returned = ModelManagement.decompress_to_cache(str(archive), str(cache))

    assert returned == str(cache)
    extracted = cache / "model" / "config.json"
    assert extracted.read_bytes() == b"{}"


def test_decompress_to_cache_rejects_member_path_traversal(tmp_path):
    """A member named with a leading ``..`` is refused and nothing lands outside the cache."""
    archive, cache = _prepare_paths(tmp_path)
    escaped = tmp_path / "outside.txt"

    with tarfile.open(archive, "w:gz") as bundle:
        _add_file(bundle, "../outside.txt", b"owned")

    with pytest.raises(ValueError, match="Unsafe tar member path"):
        ModelManagement.decompress_to_cache(str(archive), str(cache))

    assert not escaped.exists()


def test_decompress_to_cache_rejects_symlink_path_traversal(tmp_path):
    """A symlink whose target climbs out of the cache directory is refused."""
    archive, cache = _prepare_paths(tmp_path)

    with tarfile.open(archive, "w:gz") as bundle:
        escape = tarfile.TarInfo(name="model/escape")
        escape.type = tarfile.SYMTYPE
        escape.linkname = "../../outside.txt"
        bundle.addfile(escape)

    with pytest.raises(ValueError, match="Unsafe tar link target"):
        ModelManagement.decompress_to_cache(str(archive), str(cache))