Skip to content

feat(rewards): add reciprocal-space reward (Fprotein from SFC)#272

Open
DorisMai wants to merge 8 commits into
mainfrom
dm/add-sfc-reward
Open

feat(rewards): add reciprocal-space reward (Fprotein from SFC)#272
DorisMai wants to merge 8 commits into
mainfrom
dm/add-sfc-reward

Conversation

@DorisMai

@DorisMai DorisMai commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

What changed

  • Added StructureFactorRewardFunction (core/rewards/structure_factor.py) via SFC. Two-phase construction (__init__ config + prepare(atom_array)) as SFC requires knowing topology.
    • ensembles combine as a complex sum.: v1 fits |Fprotein|
    • restricted by SFC to occupancy and B-factor shared across batch
    • exposes normalize_amplitude for testing and batch_partition in case of OOM
  • Patched eval/generate_synthetic_sf.py to generate test data with both Fprotein and Ftotal in the same mtz for debugging/development and support the round trip of cif --> atomarray --> gemmi --> cif --> atomarray.
  • Refactored and added reward tests.
    • Extracted common functions from test_real_space_density_reward.py to test_reward_function_contract.py. Real space specific tests that are unused (e.g. vmap related) remain untouched.
    • Added test_structure_factor_reward.py and synthetic test data (1vme cif and mtz) generated from eval/generate_synthetic_sf.py.

Next steps

  • Add unscaled |Ftotal| (should be trivial), blocked by SFC PR merging
  • Wire the reward into a guidance step scaler for initial 40 experiment
  • Add scaling optimization

Summary by CodeRabbit

  • New Features

    • Added support for evaluating reciprocal-space structure-factor rewards from experimental targets.
    • Synthetic structure-factor generation can now write richer MTZ outputs with multiple labeled amplitude/phase sets.
    • Improved structure conversion preserves altlocs and round-trips more reliably.
  • Bug Fixes

    • Better validation and clearer errors when batch inputs for occupancies or B-factors are inconsistent.
    • Improved handling of crystal metadata and target-column detection.
  • Tests

    • Expanded coverage for structure-factor workflows, altloc round-trips, and reward-function behavior.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Introduces StructureFactorRewardFunction, a two-phase reciprocal-space reward calculator that scores multi-conformer ensembles against experimental MTZ amplitudes. Refactors synthetic SF generation to emit multi-label MTZ files with F{label}/SIGF{label}/PHIF{label} column sets. Extracts atomarray_to_gemmi into a shared utility. Adds a unified reward-function contract test suite alongside SF-specific occupancy/B validation tests.

Changes

Structure Factor Reward Function

Layer / File(s) Summary
MTZ metadata detection and resolution
src/sampleworks/core/rewards/structure_factor.py
_detect_mtz_metadata reads MTZ columns and unit cell by dtype; _resolve_mtz_metadata merges caller-supplied overrides, emitting warnings on mismatches. Module-level AmplitudeLoss type alias and comparison tolerances are defined.
StructureFactorRewardFunction init and prepare
src/sampleworks/core/rewards/structure_factor.py
__init__ stores config, resolves crystal metadata, and builds SFcalculator kwargs. prepare(atom_array) converts the atom array to gemmi, instantiates SFcalculator, calls inspect_data(), validates normalization availability, and computes the reflection mask.
StructureFactorRewardFunction __call__ forward pass
src/sampleworks/core/rewards/structure_factor.py
Validates batch-identical occupancy/B-factors, writes them into SFcalculator, computes ensemble-summed protein structure factors, selects normalized (Ec/Eo) or raw (Fprotein/Fo) amplitude path, and applies the configured scalar loss over masked reflections.

atomarray_to_gemmi shared utility

Layer / File(s) Summary
atomarray_to_gemmi implementation
src/sampleworks/eval/synthetic_utils.py
Adds gemmi/numpy imports and atomarray_to_gemmi, handling altloc normalization, zero aniso B-factors, model name normalization, and residue subchain fixup. Optionally sets unit cell and space group.
Round-trip altloc test
tests/eval/test_generate_synthetic_sf.py
New test writes the converted structure to mmCIF, reloads with altloc support, and asserts atom count and altloc label preservation via find_all_altloc_ids.

Synthetic SF generation multi-label MTZ refactor

Layer / File(s) Summary
Multi-label MTZ dataset assembly
src/sampleworks/eval/generate_synthetic_sf.py
Adds _amplitude_phase_columns helper that auto-detects and renames amplitude/phase columns and synthesizes SIGF. Refactors process_amplitudes_to_dataset to accept a structure_factor_columns dict and merge per-label datasets. Updates _process_single_row to build the mapping (always protein, conditionally total) and updates CLI help text.

Reward function contract test suite

Layer / File(s) Summary
RewardCase dataclass, fixtures, and shared helpers
tests/rewards/reward_input_helpers.py, tests/rewards/test_reward_function_contract.py
Defines RewardCase with a batch() helper, _REWARD_BUNDLES, _LOSS_THRESHOLDS, the parameterized reward_case fixture, and build_scattering_indices utility used across reward tests.
Contract test classes
tests/rewards/test_reward_function_contract.py
TestRewardFunctionInterface, TestRewardCorrelation, TestRewardGradientFlow, TestRewardBatchHandling, and TestRewardEdgeCases assert shared semantics (scalar output, loss ordering, finite gradients, Adam descent, batch sizes, numerical stability) across real-space and SF reward implementations.
SF reward fixtures and occupancy/B test
tests/conftest.py, tests/rewards/test_structure_factor_reward.py
Adds session-scoped 1vme MTZ/CIF/coordinate/reward fixtures with skip-on-missing; adds TestStructureFactorOccupancy asserting ValueError on non-broadcast occupancy or B-factors across the batch.
Real-space test refactor
tests/rewards/test_real_space_density_reward.py
Removes tests now covered by the contract suite; retains TestVmapCompatibility (vmap correctness, shape, consistency vs sequential) and TestEdgeCases (single-atom, structure_to_reward_input round-trip).

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant StructureFactorRewardFunction
  participant _detect_mtz_metadata
  participant atomarray_to_gemmi
  participant SFcalculator

  rect rgba(100, 149, 237, 0.5)
    note over Caller,StructureFactorRewardFunction: Phase 1 — Configuration
    Caller->>StructureFactorRewardFunction: __init__(mtzfile, resolution, loss, ...)
    StructureFactorRewardFunction->>_detect_mtz_metadata: read unit_cell, expcolumns from MTZ
    _detect_mtz_metadata-->>StructureFactorRewardFunction: unit_cell, space_group, expcolumns
  end

  rect rgba(60, 179, 113, 0.5)
    note over Caller,SFcalculator: Phase 2 — Preparation
    Caller->>StructureFactorRewardFunction: prepare(atom_array)
    StructureFactorRewardFunction->>atomarray_to_gemmi: convert AtomArray → gemmi.Structure
    atomarray_to_gemmi-->>StructureFactorRewardFunction: gemmi.Structure
    StructureFactorRewardFunction->>SFcalculator: __init__(pdbmodel=PDBParser(gemmi_structure))
    StructureFactorRewardFunction->>SFcalculator: inspect_data()
    SFcalculator-->>StructureFactorRewardFunction: Fo, Eo, outlier flags
    StructureFactorRewardFunction-->>Caller: ready (reflection mask built)
  end

  rect rgba(210, 105, 30, 0.5)
    note over Caller,SFcalculator: Phase 3 — Forward pass
    Caller->>StructureFactorRewardFunction: __call__(coordinates, elements, b_factors, occupancies)
    StructureFactorRewardFunction->>SFcalculator: write occupancy/B, calc_fprotein_batch
    SFcalculator-->>StructureFactorRewardFunction: per-reflection complex SFs
    StructureFactorRewardFunction->>StructureFactorRewardFunction: sum ensemble, compute amplitudes, apply loss mask
    StructureFactorRewardFunction-->>Caller: scalar loss tensor
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • marcuscollins

🐇 A new reward hops into the space,
Where structure factors find their place,
MTZ columns mapped with care,
Ensemble SFs floating there—
The rabbit scores each conformation's grace! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 71.79% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding a reciprocal-space reward based on Fprotein from SFC.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch dm/add-sfc-reward

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@DorisMai DorisMai marked this pull request as ready for review June 30, 2026 01:11

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Nitpick comments (2)
src/sampleworks/eval/generate_synthetic_sf.py (1)

137-145: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Use a NumPy-style docstring for this new helper.

The summary is useful, but this new function is missing the required Parameters / Returns sections. As per coding guidelines, “Always include NumPy-style docstrings for every function and class.”

Proposed docstring update
-    """Build a one-amplitude rs.DataSet with labelled F / SIGF / PHIF columns.
-
-    ``sfc.prepare_dataset`` returns an amplitude column and a phase column (degrees)
-    for the given ``structure_factor_column`` attribute. We auto-detect those by MTZ
-    dtype (rather than assuming the unexposed ``FMODEL`` / ``PHIFMODEL`` names),
-    rename them to ``F{label}`` / ``PHIF{label}``, and synthesize a ``SIGF{label}``
-    column so several structure-factor sets (e.g. protein and total) can coexist in
-    one MTZ.
-    """
+    """Build a one-amplitude dataset with labelled F / SIGF / PHIF columns.
+
+    Parameters
+    ----------
+    sfc
+        Structure-factor calculator containing the requested ASU amplitudes.
+    label
+        Output column label suffix, e.g. ``protein`` or ``total``.
+    structure_factor_column
+        SFcalculator attribute passed to ``prepare_dataset``.
+    miller_index_column
+        SFcalculator attribute containing Miller indices.
+    sigma_f_scale
+        Scale factor used to synthesize dummy SIGF values from amplitudes.
+
+    Returns
+    -------
+    rs.DataSet
+        Dataset containing ``F{label}``, ``SIGF{label}``, and ``PHIF{label}``.
+    """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 137 - 145, The
new helper in generate_synthetic_sf.py has a summary docstring but is missing
the required NumPy-style structure. Update the docstring for the helper that
builds the rs.DataSet to use NumPy format with explicit Parameters and Returns
sections, documenting each input and the returned dataset/columns clearly; keep
the existing behavior unchanged and ensure the docstring matches the function
name and its role in auto-detecting and renaming the structure-factor columns.

Source: Coding guidelines

tests/rewards/test_reward_function_contract.py (1)

47-66: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Freeze RewardCase before sharing it across tests.

This bundle is passed around as shared state, so leaving the dataclass mutable makes accidental test-side mutation hard to spot. @dataclass(frozen=True) matches the repo's immutable-state convention and still works with batch(). As per coding guidelines, "Use frozen dataclasses with functional updates for immutable state management."

Suggested change
-@dataclass
+@dataclass(frozen=True)
 class RewardCase:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/rewards/test_reward_function_contract.py` around lines 47 - 66, Make
RewardCase immutable by marking the RewardCase dataclass as frozen so shared
test state cannot be mutated accidentally. Update the RewardCase definition to
use a frozen dataclass while keeping batch() unchanged, since it only reads
fields and still works with immutable instances. Use the existing RewardCase
symbol in tests/rewards/test_reward_function_contract.py to locate the class and
apply the repo’s immutable-state convention.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/sampleworks/core/rewards/structure_factor.py`:
- Around line 152-163: Validate the batch_partition argument eagerly in the
structure_factor constructor before assigning it on self, and raise a clear
ValueError when it is zero or negative. Add the check in the initializer where
device, mtzfile, and batch_partition are set so invalid OOM-tuning input is
rejected immediately instead of failing later in calc_fprotein_batch().
- Around line 69-78: The auto-detection in structure_factor.py is ambiguous
because it independently chooses the first amplitude and sigma columns, which
can silently pair the wrong MTZ labels in multi-label datasets. Update the
selection logic in the structure-factor loading path to either require explicit
expcolumns when more than one StructureFactorAmplitudeDtype or
StandardDeviationDtype candidate exists, or ensure the chosen sigma column is
matched to the selected amplitude label in the same ds.select_mtzdtype/return
path.

In `@src/sampleworks/eval/generate_synthetic_sf.py`:
- Around line 543-547: The help text in generate_synthetic_sf should not escape
the braces around label because it is a plain string, not an f-string. Update
the help text in the argument definition near the existing bulk solvent option
so users see F{label}/SIGF{label}/PHIF{label} in --help, and verify the same
wording is used consistently wherever that option text is defined.

In `@tests/eval/test_generate_synthetic_sf.py`:
- Around line 135-139: The round-trip test in test_generate_synthetic_sf is only
checking that each non-blank altloc label exists, so it can miss cases where
some atoms lose their altloc assignment. Update the assertion near
find_all_altloc_ids/loaded to verify multiplicity as well, either by comparing
per-label counts for altloc_id or by comparing the full altloc_id annotation
when order is stable. Keep the check black-box by asserting observable
annotation behavior rather than implementation details.

In `@tests/rewards/reward_input_helpers.py`:
- Around line 7-13: The shared helper build_scattering_indices() still has a
prose-only docstring, so update it to the repository’s NumPy-style format. Add
the standard sections for parameters, returns, and any relevant notes/details so
the contract is explicit for the reward tests that reuse it, while keeping the
behavior unchanged.

In `@tests/rewards/test_reward_function_contract.py`:
- Around line 248-270: The debug-only `test_gradient_descent_loss_trace` in
`test_reward_function_contract.py` should be removed from collected tests
because it only prints loss values and never asserts anything. Delete this
`test_...` method from `TestRewardFunctionContract`, or if you want to keep the
trace, rename it to a non-test helper so pytest won’t run it; make sure no
`print`-only optimization loop remains in the test suite.
- Around line 39-44: The module-level test marker in the reward contract suite
should include both GPU and slow tagging. Update the existing pytestmark
assignment in the test module so the suite remains GPU-marked via
pytest.mark.gpu while also adding pytest.mark.slow, keeping the change localized
to the module-level marker used by these reward contract tests.

---

Nitpick comments:
In `@src/sampleworks/eval/generate_synthetic_sf.py`:
- Around line 137-145: The new helper in generate_synthetic_sf.py has a summary
docstring but is missing the required NumPy-style structure. Update the
docstring for the helper that builds the rs.DataSet to use NumPy format with
explicit Parameters and Returns sections, documenting each input and the
returned dataset/columns clearly; keep the existing behavior unchanged and
ensure the docstring matches the function name and its role in auto-detecting
and renaming the structure-factor columns.

In `@tests/rewards/test_reward_function_contract.py`:
- Around line 47-66: Make RewardCase immutable by marking the RewardCase
dataclass as frozen so shared test state cannot be mutated accidentally. Update
the RewardCase definition to use a frozen dataclass while keeping batch()
unchanged, since it only reads fields and still works with immutable instances.
Use the existing RewardCase symbol in
tests/rewards/test_reward_function_contract.py to locate the class and apply the
repo’s immutable-state convention.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b9ae1d00-b4d9-46d0-87da-383baa1d031f

📥 Commits

Reviewing files that changed from the base of the PR and between 2777630 and f7c0f62.

📒 Files selected for processing (11)
  • src/sampleworks/core/rewards/structure_factor.py
  • src/sampleworks/eval/generate_synthetic_sf.py
  • src/sampleworks/eval/synthetic_utils.py
  • tests/conftest.py
  • tests/eval/test_generate_synthetic_sf.py
  • tests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.cif
  • tests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.mtz
  • tests/rewards/reward_input_helpers.py
  • tests/rewards/test_real_space_density_reward.py
  • tests/rewards/test_reward_function_contract.py
  • tests/rewards/test_structure_factor_reward.py

Comment on lines +69 to +78
amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns
sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns
if len(amplitude_cols) == 0 or len(sigma_cols) == 0:
raise ValueError(
f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a "
f"standard-deviation column; found amplitudes={list(amplitude_cols)}, "
f"sigmas={list(sigma_cols)}."
)
spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None
return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Reject ambiguous auto-detection for multi-label MTZs.

This now silently picks amplitude_cols[0] and sigma_cols[0] independently. With the new MTZ layout containing both Fprotein and Ftotal, expcolumns=None becomes column-order dependent and can bind the wrong target pair without any warning. Please either require explicit expcolumns whenever multiple amplitude/sigma candidates exist, or match the sigma column to the selected amplitude label instead of taking the first one.

Proposed fix
     amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns
     sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns
     if len(amplitude_cols) == 0 or len(sigma_cols) == 0:
         raise ValueError(
             f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a "
             f"standard-deviation column; found amplitudes={list(amplitude_cols)}, "
             f"sigmas={list(sigma_cols)}."
         )
+    if len(amplitude_cols) > 1 or len(sigma_cols) > 1:
+        raise ValueError(
+            f"MTZ '{mtzfile}' contains multiple amplitude/sigma columns; pass "
+            "`expcolumns=[amplitude, sigma]` explicitly."
+        )
     spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None
     return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns
sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns
if len(amplitude_cols) == 0 or len(sigma_cols) == 0:
raise ValueError(
f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a "
f"standard-deviation column; found amplitudes={list(amplitude_cols)}, "
f"sigmas={list(sigma_cols)}."
)
spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None
return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]]
amplitude_cols = ds.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns
sigma_cols = ds.select_mtzdtype(rs.StandardDeviationDtype()).columns
if len(amplitude_cols) == 0 or len(sigma_cols) == 0:
raise ValueError(
f"MTZ '{mtzfile}' needs a structure-factor-amplitude column and a "
f"standard-deviation column; found amplitudes={list(amplitude_cols)}, "
f"sigmas={list(sigma_cols)}."
)
if len(amplitude_cols) > 1 or len(sigma_cols) > 1:
raise ValueError(
f"MTZ '{mtzfile}' contains multiple amplitude/sigma columns; pass "
"`expcolumns=[amplitude, sigma]` explicitly."
)
spacegroup = ds.spacegroup.hm if ds.spacegroup is not None else None
return ds.cell, spacegroup, [amplitude_cols[0], sigma_cols[0]]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/core/rewards/structure_factor.py` around lines 69 - 78, The
auto-detection in structure_factor.py is ambiguous because it independently
chooses the first amplitude and sigma columns, which can silently pair the wrong
MTZ labels in multi-label datasets. Update the selection logic in the
structure-factor loading path to either require explicit expcolumns when more
than one StructureFactorAmplitudeDtype or StandardDeviationDtype candidate
exists, or ensure the chosen sigma column is matched to the selected amplitude
label in the same ds.select_mtzdtype/return path.

Comment on lines +152 to +163
batch_partition: int = 10,
device: torch.device | None = None,
sfcalculator_kwargs: dict | None = None,
):
if device is None:
device = try_gpu()
self.device = device
self.mtzfile = str(mtzfile)
self.resolution = resolution
self.scattering_factor_mode = scattering_factor_mode
self.exclude_free_reflections = exclude_free_reflections
self.batch_partition = batch_partition

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Validate batch_partition at construction time.

A zero or negative partition size is accepted here and only fails later inside calc_fprotein_batch(). Since this knob is exposed specifically for OOM handling, it should be rejected eagerly with a clear ValueError.

Proposed fix
-        self.batch_partition = batch_partition
+        if batch_partition <= 0:
+            raise ValueError("batch_partition must be a positive integer")
+        self.batch_partition = batch_partition
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
batch_partition: int = 10,
device: torch.device | None = None,
sfcalculator_kwargs: dict | None = None,
):
if device is None:
device = try_gpu()
self.device = device
self.mtzfile = str(mtzfile)
self.resolution = resolution
self.scattering_factor_mode = scattering_factor_mode
self.exclude_free_reflections = exclude_free_reflections
self.batch_partition = batch_partition
batch_partition: int = 10,
device: torch.device | None = None,
sfcalculator_kwargs: dict | None = None,
):
if device is None:
device = try_gpu()
self.device = device
self.mtzfile = str(mtzfile)
self.resolution = resolution
self.scattering_factor_mode = scattering_factor_mode
self.exclude_free_reflections = exclude_free_reflections
if batch_partition <= 0:
raise ValueError("batch_partition must be a positive integer")
self.batch_partition = batch_partition
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/core/rewards/structure_factor.py` around lines 152 - 163,
Validate the batch_partition argument eagerly in the structure_factor
constructor before assigning it on self, and raise a clear ValueError when it is
zero or negative. Add the check in the initializer where device, mtzfile, and
batch_partition are set so invalid OOM-tuning input is rejected immediately
instead of failing later in calc_fprotein_batch().

Comment on lines +543 to +547
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Remove the escaped braces from the help text.

This is not an f-string, so F\\{label\\} will render to users as F\{label\} in --help.

Proposed fix
         help=(
             "Compute bulk solvent and overall scale factors and write both protein and "
             "total structure factor in one MTZ. Without this flag, protein only. Each "
-            "set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
+            "set contains F{label}/SIGF{label}/PHIF{label}."
         ),
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
),
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F{label}/SIGF{label}/PHIF{label}."
),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 543 - 547, The
help text in generate_synthetic_sf should not escape the braces around label
because it is a plain string, not an f-string. Update the help text in the
argument definition near the existing bulk solvent option so users see
F{label}/SIGF{label}/PHIF{label} in --help, and verify the same wording is used
consistently wherever that option text is defined.

Comment on lines +135 to +139
assert len(loaded) == len(stripped_atom_array)
assert "altloc_id" in loaded.get_annotation_categories()
# every real (non-blank) altloc label from the source must survive the round trip,
# not just one. find_all_altloc_ids already strips blank-altloc sentinels.
assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Assert altloc multiplicity, not just label presence.

Line 139 only compares the set of non-blank altloc labels, so this still passes if some atoms lose their altloc assignment as long as one atom with that label survives elsewhere. Compare per-label counts (or the full altloc_id annotation if order is stable) to make the round-trip guard catch partial label loss.

🧪 Tighten the regression check
         loaded = load_structure_with_altlocs(out)

         assert len(loaded) == len(stripped_atom_array)
         assert "altloc_id" in loaded.get_annotation_categories()
-        # every real (non-blank) altloc label from the source must survive the round trip,
-        # not just one. find_all_altloc_ids already strips blank-altloc sentinels.
-        assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array)
+        expected_altloc_ids = find_all_altloc_ids(stripped_atom_array)
+        assert find_all_altloc_ids(loaded) == expected_altloc_ids
+        for altloc_id in expected_altloc_ids:
+            assert np.count_nonzero(loaded.altloc_id == altloc_id) == np.count_nonzero(
+                stripped_atom_array.altloc_id == altloc_id
+            )

As per coding guidelines, tests/**/*.py should "write black-box tests that verify behavior, not implementation."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert len(loaded) == len(stripped_atom_array)
assert "altloc_id" in loaded.get_annotation_categories()
# every real (non-blank) altloc label from the source must survive the round trip,
# not just one. find_all_altloc_ids already strips blank-altloc sentinels.
assert find_all_altloc_ids(loaded) == find_all_altloc_ids(stripped_atom_array)
assert len(loaded) == len(stripped_atom_array)
assert "altloc_id" in loaded.get_annotation_categories()
expected_altloc_ids = find_all_altloc_ids(stripped_atom_array)
assert find_all_altloc_ids(loaded) == expected_altloc_ids
for altloc_id in expected_altloc_ids:
assert np.count_nonzero(loaded.altloc_id == altloc_id) == np.count_nonzero(
stripped_atom_array.altloc_id == altloc_id
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/eval/test_generate_synthetic_sf.py` around lines 135 - 139, The
round-trip test in test_generate_synthetic_sf is only checking that each
non-blank altloc label exists, so it can miss cases where some atoms lose their
altloc assignment. Update the assertion near find_all_altloc_ids/loaded to
verify multiplicity as well, either by comparing per-label counts for altloc_id
or by comparing the full altloc_id annotation when order is stable. Keep the
check black-box by asserting observable annotation behavior rather than
implementation details.

Source: Coding guidelines

Comment on lines +7 to +13
def build_scattering_indices(atom_array, device: torch.device) -> torch.Tensor:
"""Map biotite element symbols to scattering-tensor indices (production path).

Mirrors ``RewardInputs.from_atom_array``: uses ``elements_to_scattering_indices``
(so ionic forms resolve correctly) and ``dtype=torch.long``. Shared by the reward
test files (contract + structure-factor) so they build ``elements`` identically.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Use a NumPy-style docstring for this shared helper.

build_scattering_indices() is now the common input path for multiple reward tests, but its docstring is still prose-only. Please expand it to the repo's required NumPy-style format so the helper's contract is explicit where it is reused. As per coding guidelines, "Always include NumPy-style docstrings for every function and class."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/rewards/reward_input_helpers.py` around lines 7 - 13, The shared helper
build_scattering_indices() still has a prose-only docstring, so update it to the
repository’s NumPy-style format. Add the standard sections for parameters,
returns, and any relevant notes/details so the contract is explicit for the
reward tests that reuse it, while keeping the behavior unchanged.

Source: Coding guidelines

Comment on lines +39 to +44
# Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the
# whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s
# (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of
# fixed cost is one-time import + session-scoped reward construction, which is paid once per
# pytest invocation and cannot be skipped by `slow`-marking these tests.
pytestmark = pytest.mark.gpu

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Mark this GPU contract suite as slow.

This module explicitly depends on GPU-targeted reward execution, but it only applies pytest.mark.gpu. Please add slow as well so fast CI can exclude it consistently. As per coding guidelines, "Mark any test requiring a GPU or model checkpoint with @pytest.mark.slow so it is excluded from fast CI runs."

Suggested change
-pytestmark = pytest.mark.gpu
+pytestmark = [pytest.mark.gpu, pytest.mark.slow]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the
# whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s
# (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of
# fixed cost is one-time import + session-scoped reward construction, which is paid once per
# pytest invocation and cannot be skipped by `slow`-marking these tests.
pytestmark = pytest.mark.gpu
# Every test exercises CUDA-targeted reward code on the `device` fixture (try_gpu), so the
# whole module is gpu-marked. Deliberately NOT `slow`: measured warm per-test time is <2.5s
# (slowest is the SFC gradient-descent loop at ~2.4s; the rest are sub-second). The ~11s of
# fixed cost is one-time import + session-scoped reward construction, which is paid once per
# pytest invocation and cannot be skipped by `slow`-marking these tests.
pytestmark = [pytest.mark.gpu, pytest.mark.slow]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/rewards/test_reward_function_contract.py` around lines 39 - 44, The
module-level test marker in the reward contract suite should include both GPU
and slow tagging. Update the existing pytestmark assignment in the test module
so the suite remains GPU-marked via pytest.mark.gpu while also adding
pytest.mark.slow, keeping the change localized to the module-level marker used
by these reward contract tests.

Source: Coding guidelines

Comment on lines +248 to +270
def test_gradient_descent_loss_trace(self, reward_case):
"""TEMP (exploration): print the per-step loss for the 10-step descent."""
torch.manual_seed(42)
perturbation = torch.randn_like(reward_case.coords) * 0.5
coords_opt = (reward_case.coords + perturbation).unsqueeze(0).requires_grad_(True)
optimizer = torch.optim.Adam([coords_opt], lr=0.01)

def loss_fn():
return reward_case.reward_function(
coordinates=coords_opt,
elements=reward_case.elements.unsqueeze(0),
b_factors=reward_case.b_factors.unsqueeze(0),
occupancies=reward_case.occupancies.unsqueeze(0),
)

print(f"\n[{reward_case.name}] initial: {loss_fn().item():.6e}")
for i in range(10):
optimizer.zero_grad()
loss = loss_fn()
loss.backward()
optimizer.step()
with torch.no_grad():
print(f"[{reward_case.name}] after step {i + 1:2d}: {loss_fn().item():.6e}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Remove the print-only debug test before merge.

test_gradient_descent_loss_trace() never asserts anything, so it always passes while adding another 10-step GPU optimization loop and debug output to every run. If you still want this trace, keep it as a non-collected helper instead. As per coding guidelines, "No dead code and no compatibility shims for hypothetical users."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/rewards/test_reward_function_contract.py` around lines 248 - 270, The
debug-only `test_gradient_descent_loss_trace` in
`test_reward_function_contract.py` should be removed from collected tests
because it only prints loss values and never asserts anything. Delete this
`test_...` method from `TestRewardFunctionContract`, or if you want to keep the
trace, rename it to a non-test helper so pytest won’t run it; make sure no
`print`-only optimization loop remains in the test suite.

Source: Coding guidelines

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant