Optimizations to reduce per-step I/O overhead and to reduce graph density#85
Conversation
There was a problem hiding this comment.
Pull request overview
This PR reduces runtime overhead in ProteinWaterDataset by optimizing cache loading, adding an optional per-worker in-process sample cache, and capping preprocessing-time graph neighborhood density to limit graph size.
Changes:
- Added
_load_torch_cache()wrapper to optionally usetorch.load(..., mmap=True)with a safe fallback, and threaded the option through geometry + embedding cache loads. - Added an optional per-process LRU cache for fully built
HeteroDatasamples in__getitem__(sample_cache_size), returning mutation-safe clones on cache hits. - Added a
max_neighborscap forradius_graphduring preprocessing and stored it in cache metadata for traceability; exposed new runtime flags inscripts/train.py.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
tests/test_dataset.py |
Adds tests for mutation-safe sample caching and for passing the mmap flag through dataset geometry loading. |
src/dataset.py |
Implements mmap-backed cache loading, per-process sample LRU caching, and radius_graph neighbor capping + metadata. |
scripts/train.py |
Exposes --sample_cache_size and --cache_load_mmap and threads them into dataset configuration. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Warning Review limit reached
More reviews will be available in 54 minutes and 48 seconds. Learn how PR review limits work. Your organization has used up its prepaid credits, and credit purchases are no longer available. Enable the review add-on in the billing tab to keep reviews running — you're only billed for reviews past your plan's rate limits ($0.25/file). ⌛ How to resolve this issue?After more reviews become available, a review can be triggered using the To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based credits. 🚦 How do rate limits work?CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability. For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window. Please see our Fair Usage Limits Policy for further information. 📝 WalkthroughWalkthroughAdds mmap-backed ChangesDataset caching and loading optimizations
Sequence Diagram(s)sequenceDiagram
participant DL as DataLoader
participant DS as ProteinWaterDataset.__getitem__
participant Cache as _sample_cache
participant LTC as _load_torch_cache
participant Ann as _annotate_data_with_embeddings
DL->>DS: __getitem__(idx)
DS->>Cache: lookup (actual_idx, cache_key)
alt cache hit
Cache-->>DS: HeteroData
DS-->>DL: clone(HeteroData)
else cache miss
DS->>LTC: geometry_path, cache_load_mmap
LTC-->>DS: geometry dict (mmap or fallback)
DS->>Ann: data, cache_load_mmap
Ann-->>DS: annotated HeteroData
DS->>Cache: store + evict LRU if over capacity
DS-->>DL: clone(HeteroData)
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Review ran into problems🔥 ProblemsStopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer or run a Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/dataset.py (1)
849-853:⚠️ Potential issue | 🟠 Major | ⚡ Quick winReject stale geometry caches when
max_neighborschanges.Line 1119 records the cap, but Lines 849-853 still reuse any existing
.ptsolely by path. Re-running with the sameprocessed_dirand a differentmax_neighborssilently serves oldpp_edge_indexand edge features, so the dataset no longer matches its configuration.🐛 Minimal fail-fast guard
cached = _load_torch_cache(cache_path, cache_load_mmap=self.cache_load_mmap) + cached_max_neighbors = cached.get("max_neighbors") + if cached_max_neighbors != self.max_neighbors: + raise ValueError( + f"Geometry cache {cache_path} was generated with " + f"max_neighbors={cached_max_neighbors}, but this dataset was " + f"configured with max_neighbors={self.max_neighbors}. " + "Regenerate the geometry cache or use a distinct geometry_cache_name." + ) # load all data directly from cache (already includes mates if applicable)Also applies to: 1119-1119, 1207-1215
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/dataset.py` around lines 849 - 853, The list comprehension filtering to_process entries at lines 849-853 only checks if the geometry cache file exists by path, but does not validate that the cached file was created with the same max_neighbors configuration. When max_neighbors changes on re-run with the same processed_dir, stale cache files are silently reused despite no longer matching the dataset configuration. Enhance the condition that checks for file existence to also validate that the cached geometry was created with the current max_neighbors value (as recorded at line 1119), ensuring entries with incompatible caches are included in to_process and re-processed.
🧹 Nitpick comments (1)
tests/test_dataset.py (1)
656-698: ⚡ Quick winCover the mmap opt-in path too.
This test only asserts
cache_load_mmap=False; if__getitem__accidentally hard-codedFalse, the new opt-in behavior would still pass. Parameterize both values.🧪 Proposed test tightening
- def test_getitem_passes_mmap_flag_to_geometry_loader(self, tmp_path, monkeypatch): + `@pytest.mark.parametrize`("cache_load_mmap", [False, True]) + def test_getitem_passes_mmap_flag_to_geometry_loader( + self, tmp_path, monkeypatch, cache_load_mmap + ): """Dataset geometry loading should use the configured mmap option.""" @@ include_mates=False, preprocess=False, - cache_load_mmap=False, + cache_load_mmap=cache_load_mmap, ) @@ assert data["protein"].num_nodes == 1 - assert calls == [(cache_path, False)] + assert calls == [(cache_path, cache_load_mmap)]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_dataset.py` around lines 656 - 698, The test test_getitem_passes_mmap_flag_to_geometry_loader only covers the case where cache_load_mmap=False, which means if the code accidentally hard-coded False in __getitem__, the test would still pass. Parameterize this test using pytest.mark.parametrize to run with both cache_load_mmap=True and cache_load_mmap=False, and update the corresponding assertion on the calls variable to verify the correct mmap flag value is passed to the geometry loader in each case.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/dataset.py`:
- Around line 746-755: Add validation for the max_neighbors parameter in the
constructor to ensure it contains only positive values, similar to the existing
validation for sample_cache_size. After the sample_cache_size validation check,
add a check that max_neighbors is greater than 0 and raise a ValueError with an
appropriate message if it is not. This will prevent invalid graph topology when
max_neighbors is later passed to the radius_graph call and ensures consistency
with other parameter validations in the constructor.
---
Outside diff comments:
In `@src/dataset.py`:
- Around line 849-853: The list comprehension filtering to_process entries at
lines 849-853 only checks if the geometry cache file exists by path, but does
not validate that the cached file was created with the same max_neighbors
configuration. When max_neighbors changes on re-run with the same processed_dir,
stale cache files are silently reused despite no longer matching the dataset
configuration. Enhance the condition that checks for file existence to also
validate that the cached geometry was created with the current max_neighbors
value (as recorded at line 1119), ensuring entries with incompatible caches are
included in to_process and re-processed.
---
Nitpick comments:
In `@tests/test_dataset.py`:
- Around line 656-698: The test test_getitem_passes_mmap_flag_to_geometry_loader
only covers the case where cache_load_mmap=False, which means if the code
accidentally hard-coded False in __getitem__, the test would still pass.
Parameterize this test using pytest.mark.parametrize to run with both
cache_load_mmap=True and cache_load_mmap=False, and update the corresponding
assertion on the calls variable to verify the correct mmap flag value is passed
to the geometry loader in each case.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 55892353-48b4-479c-b491-bfebf0ad51ca
📒 Files selected for processing (3)
scripts/train.pysrc/dataset.pytests/test_dataset.py
marcuscollins
left a comment
There was a problem hiding this comment.
A couple small things for you to consider before merging, but otherwise approving.
| def _load_torch_cache(path: Path, cache_load_mmap: bool = True) -> dict: | ||
| """Load a torch cache file, using mmap when supported by the file/runtime.""" | ||
| if not cache_load_mmap: | ||
| return torch.load(path, weights_only=False) |
There was a problem hiding this comment.
should you pipe through weights_only?
There was a problem hiding this comment.
the weights_only parameter is True by default, I've set it to False everywhere torch.load is used to suppress the pickle concern warning since we load in self-generated .pt cache files at every spot we use torch.load. I do not expect to set it to True anywhere, hence not threading it as a parameter.
| self.filter_by_bfactor = filter_by_bfactor | ||
| self.sample_cache_size = int(sample_cache_size) | ||
| self.cache_load_mmap = bool(cache_load_mmap) | ||
| self._sample_cache: OrderedDict[tuple[int, str], HeteroData] = OrderedDict() |
There was a problem hiding this comment.
does it need to be ordered?
There was a problem hiding this comment.
LRU semantics require it to be ordered; it uses move_to_end to mark recently-used keys and popitem(last=False) to evict the oldest, neither of which a regular dict supports.
mmap-backed cache loading (
_load_torch_cache)torch.loadwithmmap=Truesogeometry/embedding .ptfiles are memory-mapped. Falls back to regular load if the runtime or file format doesn't support it.load_slae_embedding,load_esm_embedding, and__getitem__.--cache_load_mmap(default: off).Per-worker sample LRU cache (
sample_cache_size)OrderedDict-backed LRU cache in__getitem__. When a sample is already in cache, skips all I/O and returns a.clone(). Evicts least-recently-used entries when the cache is full.sample_cache_sizeN to hold N samples per worker.max_neighborscap onradius_graphradius_graphat preprocessing timeSummary by CodeRabbit
Release Notes
New Features
--sample_cache_sizecommand-line argument to enable per-worker in-process sample caching with configurable capacity--cache_load_mmapflag to enable optimized dataset cache loadingTests