feat(rewards): add reciprocal-space reward (Fprotein from SFC)#272
feat(rewards): add reciprocal-space reward (Fprotein from SFC)#272DorisMai wants to merge 8 commits into
Conversation
…fixing in prepare()
📝 WalkthroughWalkthroughIntroduces ChangesStructure Factor Reward Function
atomarray_to_gemmi shared utility
Synthetic SF generation multi-label MTZ refactor
Reward function contract test suite
Sequence Diagram(s)sequenceDiagram
participant Caller
participant StructureFactorRewardFunction
participant _detect_mtz_metadata
participant atomarray_to_gemmi
participant SFcalculator
rect rgba(100, 149, 237, 0.5)
note over Caller,StructureFactorRewardFunction: Phase 1 — Configuration
Caller->>StructureFactorRewardFunction: __init__(mtzfile, resolution, loss, ...)
StructureFactorRewardFunction->>_detect_mtz_metadata: read unit_cell, expcolumns from MTZ
_detect_mtz_metadata-->>StructureFactorRewardFunction: unit_cell, space_group, expcolumns
end
rect rgba(60, 179, 113, 0.5)
note over Caller,SFcalculator: Phase 2 — Preparation
Caller->>StructureFactorRewardFunction: prepare(atom_array)
StructureFactorRewardFunction->>atomarray_to_gemmi: convert AtomArray → gemmi.Structure
atomarray_to_gemmi-->>StructureFactorRewardFunction: gemmi.Structure
StructureFactorRewardFunction->>SFcalculator: __init__(pdbmodel=PDBParser(gemmi_structure))
StructureFactorRewardFunction->>SFcalculator: inspect_data()
SFcalculator-->>StructureFactorRewardFunction: Fo, Eo, outlier flags
StructureFactorRewardFunction-->>Caller: ready (reflection mask built)
end
rect rgba(210, 105, 30, 0.5)
note over Caller,SFcalculator: Phase 3 — Forward pass
Caller->>StructureFactorRewardFunction: __call__(coordinates, elements, b_factors, occupancies)
StructureFactorRewardFunction->>SFcalculator: write occupancy/B, calc_fprotein_batch
SFcalculator-->>StructureFactorRewardFunction: per-reflection complex SFs
StructureFactorRewardFunction->>StructureFactorRewardFunction: sum ensemble, compute amplitudes, apply loss mask
StructureFactorRewardFunction-->>Caller: scalar loss tensor
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 7
🧹 Nitpick comments (2)
src/sampleworks/eval/generate_synthetic_sf.py (1)
137-145: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winUse a NumPy-style docstring for this new helper.
The summary is useful, but this new function is missing the required
Parameters/Returnssections. As per coding guidelines, “Always include NumPy-style docstrings for every function and class.”Proposed docstring update
- """Build a one-amplitude rs.DataSet with labelled F / SIGF / PHIF columns. - - ``sfc.prepare_dataset`` returns an amplitude column and a phase column (degrees) - for the given ``structure_factor_column`` attribute. We auto-detect those by MTZ - dtype (rather than assuming the unexposed ``FMODEL`` / ``PHIFMODEL`` names), - rename them to ``F{label}`` / ``PHIF{label}``, and synthesize a ``SIGF{label}`` - column so several structure-factor sets (e.g. protein and total) can coexist in - one MTZ. - """ + """Build a one-amplitude dataset with labelled F / SIGF / PHIF columns. + + Parameters + ---------- + sfc + Structure-factor calculator containing the requested ASU amplitudes. + label + Output column label suffix, e.g. ``protein`` or ``total``. + structure_factor_column + SFcalculator attribute passed to ``prepare_dataset``. + miller_index_column + SFcalculator attribute containing Miller indices. + sigma_f_scale + Scale factor used to synthesize dummy SIGF values from amplitudes. + + Returns + ------- + rs.DataSet + Dataset containing ``F{label}``, ``SIGF{label}``, and ``PHIF{label}``. + """🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 137 - 145, The new helper in generate_synthetic_sf.py has a summary docstring but is missing the required NumPy-style structure. Update the docstring for the helper that builds the rs.DataSet to use NumPy format with explicit Parameters and Returns sections, documenting each input and the returned dataset/columns clearly; keep the existing behavior unchanged and ensure the docstring matches the function name and its role in auto-detecting and renaming the structure-factor columns.Source: Coding guidelines
tests/rewards/test_reward_function_contract.py (1)
47-66: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winFreeze
RewardCasebefore sharing it across tests.This bundle is passed around as shared state, so leaving the dataclass mutable makes accidental test-side mutation hard to spot.
@dataclass(frozen=True)matches the repo's immutable-state convention and still works withbatch(). As per coding guidelines, "Use frozen dataclasses with functional updates for immutable state management."Suggested change
-@dataclass +@dataclass(frozen=True) class RewardCase:🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/rewards/test_reward_function_contract.py` around lines 47 - 66, Make RewardCase immutable by marking the RewardCase dataclass as frozen so shared test state cannot be mutated accidentally. Update the RewardCase definition to use a frozen dataclass while keeping batch() unchanged, since it only reads fields and still works with immutable instances. Use the existing RewardCase symbol in tests/rewards/test_reward_function_contract.py to locate the class and apply the repo’s immutable-state convention.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/sampleworks/core/rewards/structure_factor.py`:
- Around line 152-163: Validate the batch_partition argument eagerly in the
structure_factor constructor before assigning it on self, and raise a clear
ValueError when it is zero or negative. Add the check in the initializer where
device, mtzfile, and batch_partition are set so invalid OOM-tuning input is
rejected immediately instead of failing later in calc_fprotein_batch().
- Around line 69-78: The auto-detection in structure_factor.py is ambiguous
because it independently chooses the first amplitude and sigma columns, which
can silently pair the wrong MTZ labels in multi-label datasets. Update the
selection logic in the structure-factor loading path to either require explicit
expcolumns when more than one StructureFactorAmplitudeDtype or
StandardDeviationDtype candidate exists, or ensure the chosen sigma column is
matched to the selected amplitude label in the same ds.select_mtzdtype/return
path.
In `@src/sampleworks/eval/generate_synthetic_sf.py`:
- Around line 543-547: The help text in generate_synthetic_sf should not escape
the braces around label because it is a plain string, not an f-string. Update
the help text in the argument definition near the existing bulk solvent option
so users see F{label}/SIGF{label}/PHIF{label} in --help, and verify the same
wording is used consistently wherever that option text is defined.
In `@tests/eval/test_generate_synthetic_sf.py`:
- Around line 135-139: The round-trip test in test_generate_synthetic_sf is only
checking that each non-blank altloc label exists, so it can miss cases where
some atoms lose their altloc assignment. Update the assertion near
find_all_altloc_ids/loaded to verify multiplicity as well, either by comparing
per-label counts for altloc_id or by comparing the full altloc_id annotation
when order is stable. Keep the check black-box by asserting observable
annotation behavior rather than implementation details.
In `@tests/rewards/reward_input_helpers.py`:
- Around line 7-13: The shared helper build_scattering_indices() still has a
prose-only docstring, so update it to the repository’s NumPy-style format. Add
the standard sections for parameters, returns, and any relevant notes/details so
the contract is explicit for the reward tests that reuse it, while keeping the
behavior unchanged.
In `@tests/rewards/test_reward_function_contract.py`:
- Around line 248-270: The debug-only `test_gradient_descent_loss_trace` in
`test_reward_function_contract.py` should be removed from collected tests
because it only prints loss values and never asserts anything. Delete this
`test_...` method from `TestRewardFunctionContract`, or if you want to keep the
trace, rename it to a non-test helper so pytest won’t run it; make sure no
`print`-only optimization loop remains in the test suite.
- Around line 39-44: The module-level test marker in the reward contract suite
should include both GPU and slow tagging. Update the existing pytestmark
assignment in the test module so the suite remains GPU-marked via
pytest.mark.gpu while also adding pytest.mark.slow, keeping the change localized
to the module-level marker used by these reward contract tests.
---
Nitpick comments:
In `@src/sampleworks/eval/generate_synthetic_sf.py`:
- Around line 137-145: The new helper in generate_synthetic_sf.py has a summary
docstring but is missing the required NumPy-style structure. Update the
docstring for the helper that builds the rs.DataSet to use NumPy format with
explicit Parameters and Returns sections, documenting each input and the
returned dataset/columns clearly; keep the existing behavior unchanged and
ensure the docstring matches the function name and its role in auto-detecting
and renaming the structure-factor columns.
In `@tests/rewards/test_reward_function_contract.py`:
- Around line 47-66: Make RewardCase immutable by marking the RewardCase
dataclass as frozen so shared test state cannot be mutated accidentally. Update
the RewardCase definition to use a frozen dataclass while keeping batch()
unchanged, since it only reads fields and still works with immutable instances.
Use the existing RewardCase symbol in
tests/rewards/test_reward_function_contract.py to locate the class and apply the
repo’s immutable-state convention.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b9ae1d00-b4d9-46d0-87da-383baa1d031f
📒 Files selected for processing (11)
src/sampleworks/core/rewards/structure_factor.pysrc/sampleworks/eval/generate_synthetic_sf.pysrc/sampleworks/eval/synthetic_utils.pytests/conftest.pytests/eval/test_generate_synthetic_sf.pytests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.ciftests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.mtztests/rewards/reward_input_helpers.pytests/rewards/test_real_space_density_reward.pytests/rewards/test_reward_function_contract.pytests/rewards/test_structure_factor_reward.py
| amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns | ||
| sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns | ||
| if len(amplitude_cols) == 0 or len(sigma_cols) == 0: | ||
| raise ValueError( | ||
| f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a " | ||
| f"standard-deviation column; found amplitudes={list(amplitude_cols)}, " | ||
| f"sigmas={list(sigma_cols)}." | ||
| ) | ||
| spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None | ||
| return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]] |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Reject ambiguous auto-detection for multi-label MTZs.
This now silently picks amplitude_cols[0] and sigma_cols[0] independently. With the new MTZ layout containing both Fprotein and Ftotal, expcolumns=None becomes column-order dependent and can bind the wrong target pair without any warning. Please either require explicit expcolumns whenever multiple amplitude/sigma candidates exist, or match the sigma column to the selected amplitude label instead of taking the first one.
Proposed fix
amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns
sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns
if len(amplitude_cols) == 0 or len(sigma_cols) == 0:
raise ValueError(
f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a "
f"standard-deviation column; found amplitudes={list(amplitude_cols)}, "
f"sigmas={list(sigma_cols)}."
)
+ if len(amplitude_cols) > 1 or len(sigma_cols) > 1:
+ raise ValueError(
+ f"MTZ '{mtzfile}' contains multiple amplitude/sigma columns; pass "
+ "`expcolumns=[amplitude, sigma]` explicitly."
+ )
spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None
return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns | |
| sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns | |
| if len(amplitude_cols) == 0 or len(sigma_cols) == 0: | |
| raise ValueError( | |
| f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a " | |
| f"standard-deviation column; found amplitudes={list(amplitude_cols)}, " | |
| f"sigmas={list(sigma_cols)}." | |
| ) | |
| spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None | |
| return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]] | |
| amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns | |
| sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns | |
| if len(amplitude_cols) == 0 or len(sigma_cols) == 0: | |
| raise ValueError( | |
| f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a " | |
| f"standard-deviation column; found amplitudes={list(amplitude_cols)}, " | |
| f"sigmas={list(sigma_cols)}." | |
| ) | |
| if len(amplitude_cols) > 1 or len(sigma_cols) > 1: | |
| raise ValueError( | |
| f"MTZ '{mtzfile}' contains multiple amplitude/sigma columns; pass " | |
| "`expcolumns=[amplitude, sigma]` explicitly." | |
| ) | |
| spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None | |
| return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/sampleworks/core/rewards/structure_factor.py` around lines 69 - 78, The
auto-detection in structure_factor.py is ambiguous because it independently
chooses the first amplitude and sigma columns, which can silently pair the wrong
MTZ labels in multi-label datasets. Update the selection logic in the
structure-factor loading path to either require explicit expcolumns when more
than one StructureFactorAmplitudeDtype or StandardDeviationDtype candidate
exists, or ensure the chosen sigma column is matched to the selected amplitude
label in the same ds.select_mtzdtype/return path.
| batch_partition: int = 10, | ||
| device: torch.device | None = None, | ||
| sfcalculator_kwargs: dict | None = None, | ||
| ): | ||
| if device is None: | ||
| device = try_gpu() | ||
| self.device = device | ||
| self.mtzfile = str(mtzfile) | ||
| self.resolution = resolution | ||
| self.scattering_factor_mode = scattering_factor_mode | ||
| self.exclude_free_reflections = exclude_free_reflections | ||
| self.batch_partition = batch_partition |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win
Validate batch_partition at construction time.
A zero or negative partition size is accepted here and only fails later inside calc_fprotein_batch(). Since this knob is exposed specifically for OOM handling, it should be rejected eagerly with a clear ValueError.
Proposed fix
- self.batch_partition = batch_partition
+ if batch_partition <= 0:
+ raise ValueError("batch_partition must be a positive integer")
+ self.batch_partition = batch_partition📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| batch_partition: int = 10, | |
| device: torch.device | None = None, | |
| sfcalculator_kwargs: dict | None = None, | |
| ): | |
| if device is None: | |
| device = try_gpu() | |
| self.device = device | |
| self.mtzfile = str(mtzfile) | |
| self.resolution = resolution | |
| self.scattering_factor_mode = scattering_factor_mode | |
| self.exclude_free_reflections = exclude_free_reflections | |
| self.batch_partition = batch_partition | |
| batch_partition: int = 10, | |
| device: torch.device | None = None, | |
| sfcalculator_kwargs: dict | None = None, | |
| ): | |
| if device is None: | |
| device = try_gpu() | |
| self.device = device | |
| self.mtzfile = str(mtzfile) | |
| self.resolution = resolution | |
| self.scattering_factor_mode = scattering_factor_mode | |
| self.exclude_free_reflections = exclude_free_reflections | |
| if batch_partition <= 0: | |
| raise ValueError("batch_partition must be a positive integer") | |
| self.batch_partition = batch_partition |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/sampleworks/core/rewards/structure_factor.py` around lines 152 - 163,
Validate the batch_partition argument eagerly in the structure_factor
constructor before assigning it on self, and raise a clear ValueError when it is
zero or negative. Add the check in the initializer where device, mtzfile, and
batch_partition are set so invalid OOM-tuning input is rejected immediately
instead of failing later in calc_fprotein_batch().
| help=( | ||
| "Compute bulk solvent and overall scale factors and write both protein and " | ||
| "total structure factor in one MTZ. Without this flag, protein only. Each " | ||
| "set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}." | ||
| ), |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Remove the escaped braces from the help text.
This is not an f-string, so F\\{label\\} will render to users as F\{label\} in --help.
Proposed fix
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
- "set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
+ "set contains F{label}/SIGF{label}/PHIF{label}."
),📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| help=( | |
| "Compute bulk solvent and overall scale factors and write both protein and " | |
| "total structure factor in one MTZ. Without this flag, protein only. Each " | |
| "set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}." | |
| ), | |
| help=( | |
| "Compute bulk solvent and overall scale factors and write both protein and " | |
| "total structure factor in one MTZ. Without this flag, protein only. Each " | |
| "set contains F{label}/SIGF{label}/PHIF{label}." | |
| ), |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 543 - 547, The
help text in generate_synthetic_sf should not escape the braces around label
because it is a plain string, not an f-string. Update the help text in the
argument definition near the existing bulk solvent option so users see
F{label}/SIGF{label}/PHIF{label} in --help, and verify the same wording is used
consistently wherever that option text is defined.
| assert len(loaded) == len(stripped_atom_array) | ||
| assert "altloc_id" in loaded.get_annotation_categories() | ||
| # every real (non-blank) altloc label from the source must survive the round trip, | ||
| # not just one. find_all_altloc_ids already strips blank-altloc sentinels. | ||
| assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Assert altloc multiplicity, not just label presence.
Line 139 only compares the set of non-blank altloc labels, so this still passes if some atoms lose their altloc assignment as long as one atom with that label survives elsewhere. Compare per-label counts (or the full altloc_id annotation if order is stable) to make the round-trip guard catch partial label loss.
🧪 Tighten the regression check
loaded = load_structure_with_altlocs(out)
assert len(loaded) == len(stripped_atom_array)
assert "altloc_id" in loaded.get_annotation_categories()
- # every real (non-blank) altloc label from the source must survive the round trip,
- # not just one. find_all_altloc_ids already strips blank-altloc sentinels.
- assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array)
+ expected_altloc_ids = find_all_altloc_ids(stripped_atom_array)
+ assert find_all_altloc_ids(loaded) == expected_altloc_ids
+ for altloc_id in expected_altloc_ids:
+ assert np.count_nonzero(loaded.altloc_id == altloc_id) == np.count_nonzero(
+ stripped_atom_array.altloc_id == altloc_id
+ )As per coding guidelines, tests/**/*.py should "write black-box tests that verify behavior, not implementation."
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| assert len(loaded) == len(stripped_atom_array) | |
| assert "altloc_id" in loaded.get_annotation_categories() | |
| # every real (non-blank) altloc label from the source must survive the round trip, | |
| # not just one. find_all_altloc_ids already strips blank-altloc sentinels. | |
| assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array) | |
| assert len(loaded) == len(stripped_atom_array) | |
| assert "altloc_id" in loaded.get_annotation_categories() | |
| expected_altloc_ids = find_all_altloc_ids(stripped_atom_array) | |
| assert find_all_altloc_ids(loaded) == expected_altloc_ids | |
| for altloc_id in expected_altloc_ids: | |
| assert np.count_nonzero(loaded.altloc_id == altloc_id) == np.count_nonzero( | |
| stripped_atom_array.altloc_id == altloc_id | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/eval/test_generate_synthetic_sf.py` around lines 135 - 139, The
round-trip test in test_generate_synthetic_sf is only checking that each
non-blank altloc label exists, so it can miss cases where some atoms lose their
altloc assignment. Update the assertion near find_all_altloc_ids/loaded to
verify multiplicity as well, either by comparing per-label counts for altloc_id
or by comparing the full altloc_id annotation when order is stable. Keep the
check black-box by asserting observable annotation behavior rather than
implementation details.
Source: Coding guidelines
| def build_scattering_indices(atom_array, device: torch.device) -> torch.Tensor: | ||
| """Map biotite element symbols to scattering-tensor indices (production path). | ||
|
|
||
| Mirrors ``RewardInputs.from_atom_array``: uses ``elements_to_scattering_indices`` | ||
| (so ionic forms resolve correctly) and ``dtype=torch.long``. Shared by the reward | ||
| test files (contract + structure-factor) so they build ``elements`` identically. | ||
| """ |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Use a NumPy-style docstring for this shared helper.
build_scattering_indices() is now the common input path for multiple reward tests, but its docstring is still prose-only. Please expand it to the repo's required NumPy-style format so the helper's contract is explicit where it is reused. As per coding guidelines, "Always include NumPy-style docstrings for every function and class."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/rewards/reward_input_helpers.py` around lines 7 - 13, The shared helper
build_scattering_indices() still has a prose-only docstring, so update it to the
repository’s NumPy-style format. Add the standard sections for parameters,
returns, and any relevant notes/details so the contract is explicit for the
reward tests that reuse it, while keeping the behavior unchanged.
Source: Coding guidelines
| # Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the | ||
| # whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s | ||
| # (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of | ||
| # fixed cost is one-time import + session-scoped reward construction, which is paid once per | ||
| # pytest invocation and cannot be skipped by `slow`-marking these tests. | ||
| pytestmark = pytest.mark.gpu |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Mark this GPU contract suite as slow.
This module explicitly depends on GPU-targeted reward execution, but it only applies pytest.mark.gpu. Please add slow as well so fast CI can exclude it consistently. As per coding guidelines, "Mark any test requiring a GPU or model checkpoint with @pytest.mark.slow so it is excluded from fast CI runs."
Suggested change
-pytestmark = pytest.mark.gpu
+pytestmark = [pytest.mark.gpu, pytest.mark.slow]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the | |
| # whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s | |
| # (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of | |
| # fixed cost is one-time import + session-scoped reward construction, which is paid once per | |
| # pytest invocation and cannot be skipped by `slow`-marking these tests. | |
| pytestmark = pytest.mark.gpu | |
| # Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the | |
| # whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s | |
| # (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of | |
| # fixed cost is one-time import + session-scoped reward construction, which is paid once per | |
| # pytest invocation and cannot be skipped by `slow`-marking these tests. | |
| pytestmark = [pytest.mark.gpu, pytest.mark.slow] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/rewards/test_reward_function_contract.py` around lines 39 - 44, The
module-level test marker in the reward contract suite should include both GPU
and slow tagging. Update the existing pytestmark assignment in the test module
so the suite remains GPU-marked via pytest.mark.gpu while also adding
pytest.mark.slow, keeping the change localized to the module-level marker used
by these reward contract tests.
Source: Coding guidelines
| def test_gradient_descent_loss_trace(self, reward_case): | ||
| """TEMP (exploration): print the per-step loss for the 10-step descent.""" | ||
| torch.manual_seed(42) | ||
| perturbation = torch.randn_like(reward_case.coords) * 0.5 | ||
| coords_opt = (reward_case.coords + perturbation).unsqueeze(0).requires_grad_(True) | ||
| optimizer = torch.optim.Adam([coords_opt], lr=0.01) | ||
|
|
||
| def loss_fn(): | ||
| return reward_case.reward_function( | ||
| coordinates=coords_opt, | ||
| elements=reward_case.elements.unsqueeze(0), | ||
| b_factors=reward_case.b_factors.unsqueeze(0), | ||
| occupancies=reward_case.occupancies.unsqueeze(0), | ||
| ) | ||
|
|
||
| print(f"\n[{reward_case.name}] initial: {loss_fn().item():.6e}") | ||
| for i in range(10): | ||
| optimizer.zero_grad() | ||
| loss = loss_fn() | ||
| loss.backward() | ||
| optimizer.step() | ||
| with torch.no_grad(): | ||
| print(f"[{reward_case.name}] after step {i + 1:2d}: {loss_fn().item():.6e}") |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Remove the print-only debug test before merge.
test_gradient_descent_loss_trace() never asserts anything, so it always passes while adding another 10-step GPU optimization loop and debug output to every run. If you still want this trace, keep it as a non-collected helper instead. As per coding guidelines, "No dead code and no compatibility shims for hypothetical users."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/rewards/test_reward_function_contract.py` around lines 248 - 270, The
debug-only `test_gradient_descent_loss_trace` in
`test_reward_function_contract.py` should be removed from collected tests
because it only prints loss values and never asserts anything. Delete this
`test_...` method from `TestRewardFunctionContract`, or if you want to keep the
trace, rename it to a non-test helper so pytest won’t run it; make sure no
`print`-only optimization loop remains in the test suite.
Source: Coding guidelines
What changed
StructureFactorRewardFunction(core/rewards/structure_factor.py) via SFC. Two-phase construction (__init__config +prepare(atom_array)) as SFC requires knowing topology.|Fprotein|normalize_amplitudefor testing andbatch_partitionin case of OOMeval/generate_synthetic_sf.pyto generate test data with both Fprotein and Ftotal in the same mtz for debugging/development and support the round trip of cif --> atomarray --> gemmi --> cif --> atomarray.test_real_space_density_reward.pytotest_reward_function_contract.py. Real space specific tests that are unused (e.g. vmap related) remain untouched.test_structure_factor_reward.pyand synthetic test data (1vme cif and mtz) generated fromeval/generate_synthetic_sf.py.Next steps
|Ftotal|(should be trivial), blocked by SFC PR mergingSummary by CodeRabbit
New Features
Bug Fixes
Tests