-
Notifications
You must be signed in to change notification settings - Fork 1
Optimizations to reduce per-step I/O overhead and to reduce graph density #85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
|
|
||
| import itertools | ||
| import json | ||
| from collections import OrderedDict | ||
| from pathlib import Path | ||
|
|
||
| import biotite.structure as bts | ||
|
|
@@ -191,11 +192,24 @@ def _pad_atom_embeddings_for_mates( | |
| return torch.cat([asu_embedding, pad], dim=0) | ||
|
|
||
|
|
||
| def _load_torch_cache(path: Path, cache_load_mmap: bool = True) -> dict: | ||
| """Load a torch cache file, using mmap when supported by the file/runtime.""" | ||
| if not cache_load_mmap: | ||
| return torch.load(path, weights_only=False) | ||
|
|
||
| try: | ||
| return torch.load(path, weights_only=False, mmap=True) | ||
| except (TypeError, ValueError, RuntimeError, OSError) as exc: | ||
| logger.debug(f"mmap torch.load failed for {path}; falling back: {exc}") | ||
| return torch.load(path, weights_only=False) | ||
|
|
||
|
|
||
| def load_slae_embedding( | ||
| embedding_dir: Path, | ||
| cache_key: str, | ||
| num_asu_protein: int, | ||
| total_num_atoms: int, | ||
| cache_load_mmap: bool = True, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Load SLAE atom-level embeddings from cache. | ||
|
|
@@ -207,6 +221,7 @@ def load_slae_embedding( | |
| cache_key: Identifier for the cached embedding file | ||
| num_asu_protein: Expected number of ASU protein atoms | ||
| total_num_atoms: Total protein atoms including symmetry mates | ||
| cache_load_mmap: Use mmap-backed torch.load when supported | ||
|
|
||
| Returns: | ||
| (total_num_atoms, slae_dim) tensor with zeros padded for mate atoms | ||
|
|
@@ -221,7 +236,7 @@ def load_slae_embedding( | |
| f"SLAE cache file not found: {slae_cache_path}. " | ||
| "Generate embeddings with scripts/generate_slae_embeddings.py." | ||
| ) | ||
| slae_cached = torch.load(slae_cache_path, weights_only=False) | ||
| slae_cached = _load_torch_cache(slae_cache_path, cache_load_mmap=cache_load_mmap) | ||
| if "node_embeddings" not in slae_cached: | ||
| raise KeyError(f"Missing 'node_embeddings' in SLAE cache: {slae_cache_path}") | ||
| slae_emb = slae_cached["node_embeddings"] | ||
|
|
@@ -237,6 +252,7 @@ def load_esm_embedding( | |
| embedding_dir: Path, | ||
| cache_key: str, | ||
| num_protein_residues: int, | ||
| cache_load_mmap: bool = True, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Load ESM residue-level embeddings from cache. | ||
|
|
@@ -248,6 +264,7 @@ def load_esm_embedding( | |
| embedding_dir: Directory containing cached embedding files | ||
| cache_key: Identifier for the cached embedding file | ||
| num_protein_residues: Expected number of unique residues | ||
| cache_load_mmap: Use mmap-backed torch.load when supported | ||
|
|
||
| Returns: | ||
| (num_protein_residues, esm_dim) tensor of residue embeddings | ||
|
|
@@ -262,7 +279,7 @@ def load_esm_embedding( | |
| f"ESM cache file not found: {esm_cache_path}. " | ||
| "Generate embeddings with scripts/generate_esm_embeddings.py." | ||
| ) | ||
| esm_cached = torch.load(esm_cache_path, weights_only=False) | ||
| esm_cached = _load_torch_cache(esm_cache_path, cache_load_mmap=cache_load_mmap) | ||
| if "residue_embeddings" not in esm_cached: | ||
| raise KeyError(f"Missing 'residue_embeddings' in ESM cache: {esm_cache_path}") | ||
| residue_embeddings = esm_cached["residue_embeddings"] | ||
|
|
@@ -664,6 +681,7 @@ def __init__( | |
| encoder_type: str = "gvp", | ||
| base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data", | ||
| cutoff: float = 8.0, | ||
| max_neighbors: int = 256, | ||
| include_mates: bool = True, | ||
| geometry_cache_name: str = "geometry", | ||
| preprocess: bool = True, | ||
|
|
@@ -679,6 +697,8 @@ def __init__( | |
| filter_by_distance: bool = True, | ||
| filter_by_edia: bool = True, | ||
| filter_by_bfactor: bool = True, | ||
| sample_cache_size: int = 0, | ||
| cache_load_mmap: bool = False, | ||
| ): | ||
| """ | ||
| Args: | ||
|
|
@@ -690,6 +710,7 @@ def __init__( | |
| Embeddings are loaded only for the selected type. | ||
| base_pdb_dir: Base directory containing PDB subdirectories | ||
| cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms) | ||
| max_neighbors: Maximum neighbors per node for radius graph construction. | ||
| include_mates: If True, include symmetry mate atoms as protein nodes | ||
| geometry_cache_name: Base name for geometry cache directory. When | ||
| include_mates=True, "_mates" is appended automatically. | ||
|
|
@@ -717,14 +738,23 @@ def __init__( | |
| filter_by_edia: Enable/disable EDIA score filtering. | ||
| filter_by_bfactor: Enable/disable B-factor z-score filtering. | ||
| If a per-water filter is disabled, its threshold is ignored. | ||
| sample_cache_size: Number of fully built HeteroData samples to keep in a | ||
| per-process LRU cache. 0 disables sample caching. | ||
| cache_load_mmap: Use mmap-backed torch.load for cache files when supported. | ||
| """ | ||
|
|
||
| if sample_cache_size < 0: | ||
| raise ValueError("sample_cache_size must be >= 0") | ||
| if max_neighbors < 1: | ||
| raise ValueError("max_neighbors must be >= 1") | ||
|
|
||
| self.cache_dir = Path(processed_dir) | ||
| # Directory-based separation: geometry/ vs geometry_mates/ | ||
| cache_suffix = "_mates" if include_mates else "" | ||
| self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}" | ||
| self.base_pdb_dir = Path(base_pdb_dir) | ||
| self.cutoff = cutoff | ||
| self.max_neighbors = max_neighbors | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| self.encoder_type = encoder_type | ||
| if self.encoder_type in ("slae", "esm"): | ||
| self.embedding_dir = self.cache_dir / self.encoder_type | ||
|
|
@@ -745,6 +775,9 @@ def __init__( | |
| self.filter_by_distance = filter_by_distance | ||
| self.filter_by_edia = filter_by_edia | ||
| self.filter_by_bfactor = filter_by_bfactor | ||
| self.sample_cache_size = int(sample_cache_size) | ||
| self.cache_load_mmap = bool(cache_load_mmap) | ||
| self._sample_cache: OrderedDict[tuple[int, str], HeteroData] = OrderedDict() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does it need to be ordered?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LRU semantics require it to be ordered; it uses |
||
|
|
||
| if self.encoder_type not in {"gvp", "slae", "esm"}: | ||
| raise ValueError( | ||
|
|
@@ -1048,7 +1081,12 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
|
|
||
| # Compute PP edges and features | ||
| if final_protein_pos.size(0) > 0: | ||
| pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False) | ||
| pp_edge_index = radius_graph( | ||
| final_protein_pos, | ||
| r=self.cutoff, | ||
| loop=False, | ||
| max_num_neighbors=self.max_neighbors, | ||
| ) | ||
| pp_edge_index = _make_undirected(pp_edge_index) | ||
| pp_edge_unit_vectors, pp_edge_rbf = compute_edge_features( | ||
| final_protein_pos, | ||
|
|
@@ -1080,6 +1118,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
| # Metadata | ||
| "num_asu_protein": num_asu_protein, | ||
| "num_protein_residues": num_residues, | ||
| "max_neighbors": self.max_neighbors, | ||
| }, | ||
| cache_path, | ||
| ) | ||
|
|
@@ -1116,6 +1155,7 @@ def _annotate_data_with_embeddings( | |
| cache_key=cache_key, | ||
| num_asu_protein=num_asu_protein, | ||
| total_num_atoms=data["protein"].num_nodes, | ||
| cache_load_mmap=self.cache_load_mmap, | ||
| ) | ||
| data["protein"].embedding_type = "slae" | ||
| elif self.encoder_type == "esm": | ||
|
|
@@ -1124,6 +1164,7 @@ def _annotate_data_with_embeddings( | |
| embedding_dir=self.embedding_dir, | ||
| cache_key=cache_key, | ||
| num_protein_residues=num_protein_residues, | ||
| cache_load_mmap=self.cache_load_mmap, | ||
| ) | ||
| esm_atom_emb = residue_embeddings[asu_protein_res_idx] | ||
| data["protein"].embedding = _pad_atom_embeddings_for_mates( | ||
|
|
@@ -1150,6 +1191,13 @@ def __getitem__(self, idx: int) -> HeteroData: | |
|
|
||
| actual_idx = idx % len(self.entries) | ||
| entry = self.entries[actual_idx] | ||
| sample_cache_key = (actual_idx, entry["cache_key"]) | ||
| if self.sample_cache_size > 0: | ||
| cached_sample = self._sample_cache.get(sample_cache_key) | ||
| if cached_sample is not None: | ||
| self._sample_cache.move_to_end(sample_cache_key) | ||
| return cached_sample.clone() | ||
|
|
||
| cache_path = self.geometry_dir / f"{entry['cache_key']}.pt" | ||
|
|
||
| if not cache_path.exists(): | ||
|
|
@@ -1158,7 +1206,7 @@ def __getitem__(self, idx: int) -> HeteroData: | |
| f"Run with preprocess=True to generate it." | ||
| ) | ||
|
|
||
| cached = torch.load(cache_path, weights_only=False) | ||
| cached = _load_torch_cache(cache_path, cache_load_mmap=self.cache_load_mmap) | ||
|
|
||
| # load all data directly from cache (already includes mates if applicable) | ||
| protein_pos = cached["protein_pos"] | ||
|
|
@@ -1210,6 +1258,13 @@ def __getitem__(self, idx: int) -> HeteroData: | |
| data.pdb_id = entry["embedding_key"] | ||
| data.num_asu_protein_atoms = num_asu_protein | ||
|
|
||
| if self.sample_cache_size > 0: | ||
| self._sample_cache[sample_cache_key] = data | ||
| self._sample_cache.move_to_end(sample_cache_key) | ||
| while len(self._sample_cache) > self.sample_cache_size: | ||
| self._sample_cache.popitem(last=False) | ||
| return data.clone() | ||
|
|
||
| return data | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should you pipe through
weights_only?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the
weights_onlyparameter is True by default, I've set it to False everywhere torch.load is used to suppress the pickle concern warning since we load in self-generated .pt cache files at every spot we use torch.load. I do not expect to set it to True anywhere, hence not threading it as a parameter.