Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ def parse_args():
action="store_true",
help="Keep workers alive between epochs",
)
p.add_argument(
"--sample_cache_size",
type=int,
default=0,
help="Per-worker in-process dataset sample LRU cache size (0 disables caching)",
)
p.add_argument(
"--cache_load_mmap",
action="store_true",
default=False,
help="Use mmap-backed torch.load for dataset cache files when supported",
)

# scheduler
p.add_argument(
Expand Down Expand Up @@ -305,6 +317,8 @@ def parse_args():
args = p.parse_args()
if args.encoder_type == "gvp" and args.embedding_dim is not None:
p.error("--embedding_dim is only valid for cached encoders: slae or esm")
if args.sample_cache_size < 0:
p.error("--sample_cache_size must be >= 0")
return args


Expand Down Expand Up @@ -351,6 +365,8 @@ def _build_dataset_config(args: argparse.Namespace) -> tuple[dict, dict, dict]:
"base_pdb_dir": args.base_pdb_dir,
"geometry_cache_name": args.geometry_cache_name,
"include_mates": args.include_mates,
"sample_cache_size": args.sample_cache_size,
"cache_load_mmap": args.cache_load_mmap,
**quality_kwargs,
**water_filter_kwargs,
}
Expand Down
63 changes: 59 additions & 4 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import itertools
import json
from collections import OrderedDict
from pathlib import Path

import biotite.structure as bts
Expand Down Expand Up @@ -191,11 +192,24 @@ def _pad_atom_embeddings_for_mates(
return torch.cat([asu_embedding, pad], dim=0)


def _load_torch_cache(path: Path, cache_load_mmap: bool = True) -> dict:
"""Load a torch cache file, using mmap when supported by the file/runtime."""
if not cache_load_mmap:
return torch.load(path, weights_only=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should you pipe through weights_only?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the weights_only parameter is True by default, I've set it to False everywhere torch.load is used to suppress the pickle concern warning since we load in self-generated .pt cache files at every spot we use torch.load. I do not expect to set it to True anywhere, hence not threading it as a parameter.


try:
return torch.load(path, weights_only=False, mmap=True)
except (TypeError, ValueError, RuntimeError, OSError) as exc:
logger.debug(f"mmap torch.load failed for {path}; falling back: {exc}")
return torch.load(path, weights_only=False)


def load_slae_embedding(
embedding_dir: Path,
cache_key: str,
num_asu_protein: int,
total_num_atoms: int,
cache_load_mmap: bool = True,
) -> torch.Tensor:
"""
Load SLAE atom-level embeddings from cache.
Expand All @@ -207,6 +221,7 @@ def load_slae_embedding(
cache_key: Identifier for the cached embedding file
num_asu_protein: Expected number of ASU protein atoms
total_num_atoms: Total protein atoms including symmetry mates
cache_load_mmap: Use mmap-backed torch.load when supported

Returns:
(total_num_atoms, slae_dim) tensor with zeros padded for mate atoms
Expand All @@ -221,7 +236,7 @@ def load_slae_embedding(
f"SLAE cache file not found: {slae_cache_path}. "
"Generate embeddings with scripts/generate_slae_embeddings.py."
)
slae_cached = torch.load(slae_cache_path, weights_only=False)
slae_cached = _load_torch_cache(slae_cache_path, cache_load_mmap=cache_load_mmap)
if "node_embeddings" not in slae_cached:
raise KeyError(f"Missing 'node_embeddings' in SLAE cache: {slae_cache_path}")
slae_emb = slae_cached["node_embeddings"]
Expand All @@ -237,6 +252,7 @@ def load_esm_embedding(
embedding_dir: Path,
cache_key: str,
num_protein_residues: int,
cache_load_mmap: bool = True,
) -> torch.Tensor:
"""
Load ESM residue-level embeddings from cache.
Expand All @@ -248,6 +264,7 @@ def load_esm_embedding(
embedding_dir: Directory containing cached embedding files
cache_key: Identifier for the cached embedding file
num_protein_residues: Expected number of unique residues
cache_load_mmap: Use mmap-backed torch.load when supported

Returns:
(num_protein_residues, esm_dim) tensor of residue embeddings
Expand All @@ -262,7 +279,7 @@ def load_esm_embedding(
f"ESM cache file not found: {esm_cache_path}. "
"Generate embeddings with scripts/generate_esm_embeddings.py."
)
esm_cached = torch.load(esm_cache_path, weights_only=False)
esm_cached = _load_torch_cache(esm_cache_path, cache_load_mmap=cache_load_mmap)
if "residue_embeddings" not in esm_cached:
raise KeyError(f"Missing 'residue_embeddings' in ESM cache: {esm_cache_path}")
residue_embeddings = esm_cached["residue_embeddings"]
Expand Down Expand Up @@ -664,6 +681,7 @@ def __init__(
encoder_type: str = "gvp",
base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data",
cutoff: float = 8.0,
max_neighbors: int = 256,
include_mates: bool = True,
geometry_cache_name: str = "geometry",
preprocess: bool = True,
Expand All @@ -679,6 +697,8 @@ def __init__(
filter_by_distance: bool = True,
filter_by_edia: bool = True,
filter_by_bfactor: bool = True,
sample_cache_size: int = 0,
cache_load_mmap: bool = False,
):
"""
Args:
Expand All @@ -690,6 +710,7 @@ def __init__(
Embeddings are loaded only for the selected type.
base_pdb_dir: Base directory containing PDB subdirectories
cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms)
max_neighbors: Maximum neighbors per node for radius graph construction.
include_mates: If True, include symmetry mate atoms as protein nodes
geometry_cache_name: Base name for geometry cache directory. When
include_mates=True, "_mates" is appended automatically.
Expand Down Expand Up @@ -717,14 +738,23 @@ def __init__(
filter_by_edia: Enable/disable EDIA score filtering.
filter_by_bfactor: Enable/disable B-factor z-score filtering.
If a per-water filter is disabled, its threshold is ignored.
sample_cache_size: Number of fully built HeteroData samples to keep in a
per-process LRU cache. 0 disables sample caching.
cache_load_mmap: Use mmap-backed torch.load for cache files when supported.
"""

if sample_cache_size < 0:
raise ValueError("sample_cache_size must be >= 0")
if max_neighbors < 1:
raise ValueError("max_neighbors must be >= 1")

self.cache_dir = Path(processed_dir)
# Directory-based separation: geometry/ vs geometry_mates/
cache_suffix = "_mates" if include_mates else ""
self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}"
self.base_pdb_dir = Path(base_pdb_dir)
self.cutoff = cutoff
self.max_neighbors = max_neighbors
Comment thread
coderabbitai[bot] marked this conversation as resolved.
self.encoder_type = encoder_type
if self.encoder_type in ("slae", "esm"):
self.embedding_dir = self.cache_dir / self.encoder_type
Expand All @@ -745,6 +775,9 @@ def __init__(
self.filter_by_distance = filter_by_distance
self.filter_by_edia = filter_by_edia
self.filter_by_bfactor = filter_by_bfactor
self.sample_cache_size = int(sample_cache_size)
self.cache_load_mmap = bool(cache_load_mmap)
self._sample_cache: OrderedDict[tuple[int, str], HeteroData] = OrderedDict()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it need to be ordered?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LRU semantics require it to be ordered; it uses move_to_end to mark recently-used keys and popitem(last=False) to evict the oldest, neither of which a regular dict supports.


if self.encoder_type not in {"gvp", "slae", "esm"}:
raise ValueError(
Expand Down Expand Up @@ -1048,7 +1081,12 @@ def _preprocess_one(self, entry: dict, cache_path: Path):

# Compute PP edges and features
if final_protein_pos.size(0) > 0:
pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False)
pp_edge_index = radius_graph(
final_protein_pos,
r=self.cutoff,
loop=False,
max_num_neighbors=self.max_neighbors,
)
pp_edge_index = _make_undirected(pp_edge_index)
pp_edge_unit_vectors, pp_edge_rbf = compute_edge_features(
final_protein_pos,
Expand Down Expand Up @@ -1080,6 +1118,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
# Metadata
"num_asu_protein": num_asu_protein,
"num_protein_residues": num_residues,
"max_neighbors": self.max_neighbors,
},
cache_path,
)
Expand Down Expand Up @@ -1116,6 +1155,7 @@ def _annotate_data_with_embeddings(
cache_key=cache_key,
num_asu_protein=num_asu_protein,
total_num_atoms=data["protein"].num_nodes,
cache_load_mmap=self.cache_load_mmap,
)
data["protein"].embedding_type = "slae"
elif self.encoder_type == "esm":
Expand All @@ -1124,6 +1164,7 @@ def _annotate_data_with_embeddings(
embedding_dir=self.embedding_dir,
cache_key=cache_key,
num_protein_residues=num_protein_residues,
cache_load_mmap=self.cache_load_mmap,
)
esm_atom_emb = residue_embeddings[asu_protein_res_idx]
data["protein"].embedding = _pad_atom_embeddings_for_mates(
Expand All @@ -1150,6 +1191,13 @@ def __getitem__(self, idx: int) -> HeteroData:

actual_idx = idx % len(self.entries)
entry = self.entries[actual_idx]
sample_cache_key = (actual_idx, entry["cache_key"])
if self.sample_cache_size > 0:
cached_sample = self._sample_cache.get(sample_cache_key)
if cached_sample is not None:
self._sample_cache.move_to_end(sample_cache_key)
return cached_sample.clone()

cache_path = self.geometry_dir / f"{entry['cache_key']}.pt"

if not cache_path.exists():
Expand All @@ -1158,7 +1206,7 @@ def __getitem__(self, idx: int) -> HeteroData:
f"Run with preprocess=True to generate it."
)

cached = torch.load(cache_path, weights_only=False)
cached = _load_torch_cache(cache_path, cache_load_mmap=self.cache_load_mmap)

# load all data directly from cache (already includes mates if applicable)
protein_pos = cached["protein_pos"]
Expand Down Expand Up @@ -1210,6 +1258,13 @@ def __getitem__(self, idx: int) -> HeteroData:
data.pdb_id = entry["embedding_key"]
data.num_asu_protein_atoms = num_asu_protein

if self.sample_cache_size > 0:
self._sample_cache[sample_cache_key] = data
self._sample_cache.move_to_end(sample_cache_key)
while len(self._sample_cache) > self.sample_cache_size:
self._sample_cache.popitem(last=False)
return data.clone()

return data


Expand Down
66 changes: 66 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,72 @@ def test_duplicate_single_sample(
data_5 = dataset[5]
assert torch.allclose(data_0["protein"].pos, data_5["protein"].pos)

def test_sample_cache_returns_mutation_safe_clones(
    self, single_pdb_list_file, tmp_processed_dir, pdb_base_dir
):
    """Cached samples should not be corrupted by mutations to returned data.

    Exercises both return paths of ``__getitem__``: the first fetch (cache
    miss, sample is built and stored) and later fetches (cache hit).
    """
    dataset = ProteinWaterDataset(
        pdb_list_file=single_pdb_list_file,
        processed_dir=str(tmp_processed_dir),
        base_pdb_dir=str(pdb_base_dir),
        preprocess=True,
        sample_cache_size=1,
    )

    first = dataset[0]
    original_water_pos = first["water"].pos.clone()
    assert original_water_pos.numel() > 0

    # Mutating the sample handed out on the cache miss must not leak into
    # the stored copy.
    first["water"].pos.add_(100.0)
    second = dataset[0]

    assert torch.allclose(second["water"].pos, original_water_pos)
    assert not torch.allclose(second["water"].pos, first["water"].pos)

    # Mutating the sample handed out on the cache hit must not leak either.
    second["water"].pos.add_(100.0)
    third = dataset[0]

    assert torch.allclose(third["water"].pos, original_water_pos)
    assert not torch.allclose(third["water"].pos, second["water"].pos)

def test_getitem_passes_mmap_flag_to_geometry_loader(self, tmp_path, monkeypatch):
    """Dataset geometry loading should use the configured mmap option.

    Both flag values are exercised so that a value hard-coded to the dataset
    default (False) cannot pass unnoticed.
    """
    list_file = tmp_path / "list.txt"
    list_file.write_text("test_final\n")
    processed_dir = tmp_path / "processed"
    geometry_dir = processed_dir / "geometry"
    geometry_dir.mkdir(parents=True)
    cache_path = geometry_dir / "test_final.pt"
    # The file only needs to exist; its content is supplied by fake_load.
    cache_path.touch()

    def make_cached_geometry():
        # Fresh tensors per load so one dataset instance cannot affect the next.
        return {
            "protein_pos": torch.zeros((1, 3), dtype=torch.float32),
            "protein_x": torch.zeros((1, len(ELEMENT_VOCAB) + 1), dtype=torch.float32),
            "protein_res_idx": torch.zeros(1, dtype=torch.long),
            "pp_edge_index": torch.empty((2, 0), dtype=torch.long),
            "pp_edge_unit_vectors": torch.empty((0, 3), dtype=torch.float32),
            "pp_edge_rbf": torch.empty((0, 16), dtype=torch.float32),
            "num_asu_protein": 1,
            "num_protein_residues": 1,
            "water_pos": torch.zeros((1, 3), dtype=torch.float32),
            "water_x": torch.zeros((1, len(ELEMENT_VOCAB) + 1), dtype=torch.float32),
        }

    calls = []

    def fake_load(path, *, cache_load_mmap=True):
        calls.append((path, cache_load_mmap))
        return make_cached_geometry()

    monkeypatch.setattr("src.dataset._load_torch_cache", fake_load)

    for mmap_flag in (False, True):
        calls.clear()
        dataset = ProteinWaterDataset(
            pdb_list_file=str(list_file),
            processed_dir=str(processed_dir),
            base_pdb_dir=str(tmp_path / "pdb"),
            include_mates=False,
            preprocess=False,
            cache_load_mmap=mmap_flag,
        )

        data = dataset[0]

        assert data["protein"].num_nodes == 1
        assert calls == [(cache_path, mmap_flag)]

def test_cached_file_created(
self, single_pdb_list_file, tmp_processed_dir, pdb_base_dir
):
Expand Down
Loading