KNN edge construction fix + align residue indexing with ESM sanitization #82
vratins wants to merge 7 commits into
📝 Walkthrough
The PR adds shared ESM residue-name sanitization, uses it in embedding generation and dataset residue indexing, and updates KNN edge construction to emit source-to-destination directed edges. Related tests now assert the updated residue and edge semantics.
Changes
- ESM Residue Sanitization Alignment
- KNN Edge Direction Correction
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
🚥 Pre-merge checks: ✅ 5 passed
Pull request overview
This PR fixes KNN-based edge construction directionality in the flow model, hardens dataset preprocessing for edge cases, and aligns protein residue indexing with the same residue-name sanitization used when generating cached ESM embeddings.
Changes:
- Fix `build_knn_edges` to query per-destination point (and swap returned index rows) so edges align with intended src→dst semantics.
- Update dataset preprocessing to handle empty `atoms` inputs and to compute residue indices using ESM-style residue-name canonicalization before residue-boundary detection.
- Adjust tests by adding an `xfail` marker for a batched edge-connectivity test (though the WW coverage assertions likely need updating instead).
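The per-destination query described in the first change can be illustrated with a brute-force sketch. This is not the project's `build_knn_edges` and does not use its KNN backend; the function name `knn_edges_per_destination` and the toy coordinates are hypothetical, chosen only to show the src→dst row layout.

```python
import math


def knn_edges_per_destination(src_pos, dst_pos, k):
    """Brute-force KNN: for each destination, pick its k nearest sources.

    Returns a two-row edge index [src_row, dst_row], so edges point src -> dst.
    (Hypothetical helper for illustration; not the repository's implementation.)
    """
    src_row, dst_row = [], []
    for j, d in enumerate(dst_pos):
        # Rank every source by Euclidean distance to this destination.
        ranked = sorted(range(len(src_pos)), key=lambda i: math.dist(src_pos[i], d))
        for i in ranked[:k]:
            src_row.append(i)
            dst_row.append(j)
    return [src_row, dst_row]


# Three sources on a line; source 2 is far away from every destination.
src = [(0.0, 0.0), (1.0, 0.0), (10.0, 0.0)]
dst = [(0.1, 0.0), (0.9, 0.0)]
edge_index = knn_edges_per_destination(src, dst, k=1)
```

In this toy case every destination index appears in row 1, while source 2 never appears in row 0 because it is nobody's nearest neighbor. That asymmetry is why coverage assertions on a per-destination KNN graph are safer on the destination row.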
Reviewed changes
Copilot reviewed 3 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `uv.lock` | Updates locked dependencies (adds jaxtyping, removes mypy, adjusts some wheels/metadata). |
| `tests/test_flow.py` | Marks the batched water-edge connectivity test as `xfail(strict=True)` and updates rationale text. |
| `src/flow.py` | Fixes KNN query argument order and explicitly swaps index rows to preserve src→dst `edge_index` layout. |
| `src/dataset.py` | Handles empty `atoms` in coordinate matching; updates residue indexing to mirror ESM sanitization behavior. |
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/test_flow.py (1)
582-617: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Narrow the `xfail` scope so it doesn't hide `pw` regressions.
`xfail` is currently applied to the whole test, so failures in the protein-water assertions are also treated as expected. That drops useful coverage beyond the known `ww` issue.
Suggested split to preserve `pw` coverage:

    -    @pytest.mark.xfail(
    +    def test_batched_waters_have_protein_edges(self, batched_hetero_data):
    +        """Ensure all waters in a batched graph have protein-water edges."""
    +        updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1)
    +        edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3)
    +        pw_edges = edge_dict[("protein", "pw", "water")]
    +        n_water = batched_hetero_data["water"].num_nodes
    +        water_nodes_with_pw_edges = torch.unique(pw_edges[1])
    +        assert len(water_nodes_with_pw_edges) == n_water, (
    +            f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data"
    +        )
    +
    +    @pytest.mark.xfail(
             reason=(
                 "build_knn_edges' src/dst argument-order fix changes self-graph (ww) "
                 "edge direction: row 0 now holds discovered neighbors rather than query "
                 "points, so a point that is nobody's k-nearest neighbor can be dropped "
                 "from coverage. The fixed-degree k_pw/k_ww KNN approach is replaced by "
                 "radius-based edges + KNN-fallback-for-isolated-nodes in a future PR "
                 "(edge type flags & dynamic edge construction), which removes the "
                 "k_pw/k_ww params and fixes this guarantee structurally. will remove this "
                 "marker when that PR is created."
             ),
             strict=True,
         )
    -    def test_batched_waters_have_edges(self, batched_hetero_data):
    -        """Ensure all waters in a batched graph have edges."""
    +    def test_batched_waters_have_water_edges(self, batched_hetero_data):
    +        """Ensure all waters in a batched graph have water-water edges."""
             updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1)
    -
             edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3)
    -        pw_edges = edge_dict[("protein", "pw", "water")]
             ww_edges = edge_dict[("water", "ww", "water")]
    -
             n_water = batched_hetero_data["water"].num_nodes
    -
    -        # Check protein-water edges
    -        water_nodes_with_pw_edges = torch.unique(pw_edges[1])
    -        assert len(water_nodes_with_pw_edges) == n_water, (
    -            f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data"
    -        )
    -
    -        # Check water-water edges
             if n_water > 1:
                 water_nodes_with_ww_edges = torch.unique(ww_edges[0])
                 assert len(water_nodes_with_ww_edges) == n_water, (
                     f"Only {len(water_nodes_with_ww_edges)}/{n_water} waters have water-water edges in batched data"
                 )
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/dataset.py`:
- Around line 993-1000: The insertion code normalization is missing before
calculating residue starts, which causes misalignment with the cached ESM
embeddings. After the loop that sanitizes res_name for the sanitized_for_idx
object (which converts three-letter codes to one-letter and back), add code to
normalize the ins_code field by setting blank or non-standard insertion codes to
a consistent placeholder value (similar to how "X" is used for unknown
residues). This normalization must occur before calling
bts.get_residue_starts(sanitized_for_idx) to ensure the residue count and
protein_res_idx indices match what was computed in generate_esm_embeddings.py.
---
Outside diff comments:
In `@tests/test_flow.py`:
- Around line 582-617: The xfail marker is currently applied to the entire
test_batched_waters_have_edges function, which hides failures in the
protein-water edge assertions that should not be expected to fail. Remove the
xfail decorator from the function and instead apply it only to the water-water
edge checking section (the assertions checking water_nodes_with_ww_edges). This
can be done by either splitting the test into two separate test functions with
xfail only on the water-water test, or by wrapping just the water-water edge
assertion block with pytest.xfail() to preserve protein-water edge coverage
while still allowing the known water-water edge failure.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 65109962-a7f2-481b-b0ef-737262e6f23a
⛔ Files ignored due to path filters (1)
- `uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (3)
- `src/dataset.py`
- `src/flow.py`
- `tests/test_flow.py`
- `build_knn_edges` was calling `knn(x=dst_pos, y=src_pos)`, which queries each `src` point's nearest `dst` points instead of each `dst` point's nearest `src` points; swapped the call and the resulting index rows to fix this. This was previously masked by taking the union of edges on both sides. A future PR will use KNN as a fallback to radius graphs. Expanded the docstring to document that the query is per-destination, so every destination is guaranteed incoming edges (row 1) while a source that is nobody's nearest neighbor may be absent from row 0. Updated the water-water coverage tests to assert on the destination row (row 1) accordingly, removing the now-unnecessary `xfail` markers.
- `match_atoms_to_coords` now also handles an empty `atoms` array instead of only an empty `target_coords` array.
- Residue indexing now canonicalizes residue names as the ESM embedding script does (`THREE_TO_ONE` -> `ONE_TO_THREE`, unknowns -> `UNK`) before counting residue boundaries with biotite's `get_residue_starts`. Without this, two residues that share (`chain`, `resid`, `ins_code`) but had different original `res_name`s could get merged into one under ESM's sanitization but stay separate here, desyncing residue counts/indices from the stored ESM embeddings. Insertion codes are now also normalized (`normalize_ins_code`) before counting, since `get_residue_starts` splits on `ins_code` too: a blank vs. placeholder code (`''`/`'.'`/`'?'`) would otherwise split or merge residues differently from the ESM script's residue keys. Both the canonicalization (`sanitize_res_names_for_esm`) and the insertion-code normalization are now shared between `src/dataset.py` and `scripts/generate_esm_embeddings.py` to prevent the two paths from drifting apart.
- Removed the duplicate `build_knn_edges` in `src/utils.py` (the canonical one lives in `src/flow.py`) that still carried the old per-source semantics, and removed the orphaned `atom37_to_atoms`/`ATOM37_FILL` helper (unused outside its own tests; the SLAE pipeline uses its own implementation). Added regression tests covering residue-count alignment between the dataset and the ESM residue keys.
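The residue-count desync described above can be reproduced with a small pure-Python sketch. It does not call biotite or the repository's `sanitize_res_names_for_esm`; the three-entry `THREE_TO_ONE` table, the placeholder set in `normalize_ins_code`, and the `count_residues` helper are assumptions made for illustration, with residue boundaries approximated as any change in the (chain, resid, ins_code, res_name) key between consecutive atoms.

```python
# Illustrative subset only; the real tables in the repository are assumed to be larger.
THREE_TO_ONE = {"ALA": "A", "GLY": "G", "SER": "S"}
ONE_TO_THREE = {one: three for three, one in THREE_TO_ONE.items()}


def sanitize_res_name(res_name):
    """Round-trip three-letter -> one-letter -> three-letter; unknowns become UNK."""
    one = THREE_TO_ONE.get(res_name.upper(), "X")
    return ONE_TO_THREE.get(one, "UNK")


def normalize_ins_code(ins_code):
    """Map blank and placeholder insertion codes to one value (assumed behavior)."""
    return "" if ins_code.strip() in ("", ".", "?") else ins_code


def count_residues(atoms, sanitize):
    """Count residue starts over atoms given as (chain, resid, ins_code, res_name)."""
    count, prev = 0, None
    for chain, resid, ins, name in atoms:
        if sanitize:
            ins, name = normalize_ins_code(ins), sanitize_res_name(name)
        key = (chain, resid, ins, name)
        if key != prev:  # a change in any field starts a new residue
            count += 1
            prev = key
    return count


# Two atoms share chain/resid but differ in raw res_name and ins_code placeholder.
atoms = [("A", 1, "", "FOO"), ("A", 1, ".", "BAR"), ("A", 2, "", "ALA")]
raw_count = count_residues(atoms, sanitize=False)
sanitized_count = count_residues(atoms, sanitize=True)
```

Here the raw keys yield three residues while the sanitized keys yield two, which is the kind of mismatch that would shift residue indices relative to cached per-residue embeddings if only one of the two code paths applied the sanitization.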