Skip to content
351 changes: 351 additions & 0 deletions src/sampleworks/core/rewards/structure_factor.py

Large diffs are not rendered by default.

144 changes: 68 additions & 76 deletions src/sampleworks/eval/generate_synthetic_sf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Generate synthetic structure factor amplitudes via SFcalculator-torch.

Produces an MTZ file of |Fmodel| (or |Fprotein| if not simulate solvent and
scale) for each input PDB/mmCIF structure. The MTZ file has dummy values for
SIGFP and optionally R-free flag column. Each structure can be optionally
overridden with unit cell, space group, atom selection, and occupancy.
For each input PDB/mmCIF structure, produces an MTZ file of protein structure factors
only, or when ``--simulate-solvent-and-scale`` is set, both the protein and total sets
(Fprotein/SIGFprotein/PHIFprotein and Ftotal/SIGFtotal/PHIFtotal) in the same
MTZ. The MTZ file has dummy values for the SIGF column(s) and optionally an R-free flag
column. Each structure can be optionally overridden with unit cell, space group, atom
selection, and occupancy.
"""

import argparse
Expand All @@ -15,20 +17,18 @@
from typing import Any, ClassVar

import gemmi
import numpy as np
import reciprocalspaceship as rs
import reciprocalspaceship.utils
import torch
from biotite.structure import AtomArray
from loguru import logger
from sampleworks.eval.synthetic_utils import (
atomarray_to_gemmi,
load_structure_for_synthetic_reward,
validate_occupancy_values,
)
from sampleworks.utils.atom_array_utils import BLANK_ALTLOC_IDS
from sampleworks.utils.torch_utils import try_gpu
from SFC_Torch import SFcalculator
from SFC_Torch.io import array2hier, PDBParser
from SFC_Torch.io import PDBParser


@dataclass
Expand Down Expand Up @@ -127,62 +127,41 @@ def from_dict(cls, row: dict[str, Any]) -> "BatchRowForMTZ":
)


def atomarray_to_gemmi(
atom_array: AtomArray,
unit_cell: gemmi.UnitCell | None = None,
space_group: str | None = None,
) -> gemmi.Structure:
"""Convert a biotite AtomArray to a gemmi.Structure for SFcalculator.

Anisotropic B-factors are set to zero since biotite does not store them.
Blank altloc labels are converted from biotite's '' to gemmi's '\\x00'.

Parameters
----------
atom_array
Input structure with occupancy and b_factor annotations
unit_cell
Crystallographic unit cell for the structure. If None, gemmi defaults
to (1.0, 1.0, 1.0, 90.0, 90.0, 90.0) in units of Angstroms and degrees.
space_group
Space group (in Hermann-Mauguin string format) for the structure. If
empty or invalid, SFcalculator defaults to P1.

Returns
-------
gemmi.Structure
Structure ready to be wrapped by SFC_Torch.io.PDBParser
def _amplitude_phase_columns(
sfc: SFcalculator,
label: str,
structure_factor_column: str,
miller_index_column: str,
sigma_f_scale: float,
) -> rs.DataSet:
"""Build a one-amplitude rs.DataSet with labelled F / SIGF / PHIF columns.

``sfc.prepare_dataset`` returns an amplitude column and a phase column (degrees)
for the given ``structure_factor_column`` attribute. We auto-detect those by MTZ
dtype (rather than assuming the unexposed ``FMODEL`` / ``PHIFMODEL`` names),
rename them to ``F{label}`` / ``PHIF{label}``, and synthesize a ``SIGF{label}``
column so several structure-factor sets (e.g. protein and total) can coexist in
one MTZ.
"""
n = len(atom_array)
cra_names = [
f"{atom_array.chain_id[i]}-0-{atom_array.res_name[i]}-{atom_array.atom_name[i]}"
for i in range(n)
]
# gemmi uses '\x00' for blank altloc
atom_altloc = ["\x00" if a in BLANK_ALTLOC_IDS else a for a in atom_array.altloc_id]
structure: gemmi.Structure = array2hier(
atom_pos=atom_array.coord,
atom_b_aniso=np.zeros((n, 3, 3), dtype=np.float64),
atom_b_iso=atom_array.b_factor,
atom_occ=atom_array.occupancy,
atom_name=atom_array.element,
cra_name=cra_names,
atom_altloc=atom_altloc,
res_id=atom_array.res_id,
dataset: rs.DataSet = sfc.prepare_dataset(miller_index_column, structure_factor_column)
amplitude_column = dataset.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns[0]
phase_column = dataset.select_mtzdtype(rs.PhaseDtype()).columns[0]
logger.debug(
f"Auto-detected amplitude column: {amplitude_column}, "
f"phase column: {phase_column} for {label}"
)
if unit_cell is not None:
structure.cell = unit_cell
if space_group is not None:
structure.spacegroup_hm = space_group
return structure
f_col, phi_col, sig_col = f"F{label}", f"PHIF{label}", f"SIGF{label}"
dataset = dataset.rename(columns={amplitude_column: f_col, phase_column: phi_col})
dataset[sig_col] = (dataset[f_col] * sigma_f_scale).astype(rs.StandardDeviationDtype())
return dataset[[f_col, sig_col, phi_col]]


def process_amplitudes_to_dataset(
sfc: SFcalculator,
structure_factor_columns: dict[str, str],
test_fraction: float = 0.05,
seed: int | None = None,
miller_index_column: str = "Hasu_array",
structure_factor_column: str = "Ftotal_asu",
ccp4_convention: bool = False,
sigma_f_scale: float = 0.2,
output_path: Path | None = None,
Expand All @@ -193,14 +172,20 @@ def process_amplitudes_to_dataset(
----------
sfc: SFcalculator
SFcalculator instance
structure_factor_columns: dict[str, str]
Mapping of ``label -> SFcalculator attribute``. One structure-factor set
(``F{label}``/``SIGF{label}``/``PHIF{label}``) is emitted per entry, and
multiple entries are merged into one MTZ sharing the same HKL list
(``miller_index_column``) and a single R-free column, e.g.
``{"protein": "Fprotein_asu", "total": "Ftotal_asu"}`` produces
``Fprotein``/``SIGFprotein``/``PHIFprotein`` and
``Ftotal``/``SIGFtotal``/``PHIFtotal``.
test_fraction: float
Fraction of reflections to mark as R-free test set (0 disables)
seed: int | None
Optional seed for reproducible R-free flag assignment
miller_index_column: str
Attribute name in SFcalculator for hkl indices
structure_factor_column: str
Attribute name in SFcalculator for structure factors
ccp4_convention: bool
If True, use CCP4 convention for R-free flag assignment. Default
is False, which uses Phenix convention (1 = test, 0 = working).
Expand All @@ -214,18 +199,18 @@ def process_amplitudes_to_dataset(
Returns
-------
rs.DataSet
Dataset with structure factor amplitudes, fake sigma column, and optionally
R-free flags.
Dataset with structure factor amplitudes, dummy sigma column(s), phases,
and optionally R-free flags.
"""
dataset: rs.DataSet = sfc.prepare_dataset(miller_index_column, structure_factor_column)
# assumes the first detected column of dtype F is the structure factor amplitude column
# avoids hardcoding unexposed column name "FMODEL" from sfc.prepare_dataset().
structure_factor_amplitude_column = dataset.select_mtzdtype(
rs.StructureFactorAmplitudeDtype()
).columns[0]
sigma_f_column = f"SIG{structure_factor_amplitude_column}"
dataset[sigma_f_column] = dataset[structure_factor_amplitude_column] * sigma_f_scale
dataset[sigma_f_column] = dataset[sigma_f_column].astype(rs.StandardDeviationDtype())
if not structure_factor_columns:
raise ValueError("structure_factor_columns must contain at least one entry.")
column_items = iter(structure_factor_columns.items())
label, attribute = next(column_items)
dataset = _amplitude_phase_columns(sfc, label, attribute, miller_index_column, sigma_f_scale)
for label, attribute in column_items:
ds = _amplitude_phase_columns(sfc, label, attribute, miller_index_column, sigma_f_scale)
for col in ds.columns:
dataset[col] = ds[col]
if test_fraction > 0:
dataset = rs.utils.add_rfree(
dataset,
Expand Down Expand Up @@ -288,8 +273,9 @@ def _process_single_row(
If True, remove ligand molecules (non-water heteroatoms) before computing structure
factors. Default is False.
simulate_solvent_and_scale
If True, compute bulk solvent and scale factors for Ftotal instead of Fprotein.
Default is False.
If True, compute bulk solvent and scale factors and write a single MTZ containing
both the protein and total structure factor sets. If False (default), only the
protein set is written. Each set consists of F{label}/SIGF{label}/PHIF{label} columns.
save_structure
If True, save the processed structure (after selection and occupancy assignment)
as mmCIF to output_dir. Unit cell and space group are preserved. Default is False.
Expand Down Expand Up @@ -354,15 +340,17 @@ def _process_single_row(
f"n_atoms: {len(sfc.atom_pos_orth)}"
)
sfc.calc_fprotein()
structure_factor_columns = {"protein": "Fprotein_asu"}
if simulate_solvent_and_scale:
sfc.inspect_data()
sfc.calc_fsolvent()
sfc.init_scales(requires_grad=False)
sfc.calc_ftotal()
F_attribute = "Ftotal_asu"
else:
F_attribute = "Fprotein_asu"
logger.debug(f"Computed {F_attribute} for {row.filename} on {device}")
structure_factor_columns.update({"total": "Ftotal_asu"})
logger.debug(
f"Computed {'Fprotein + Ftotal' if simulate_solvent_and_scale else 'Fprotein'} "
f"for {row.filename} on {device}"
)
except Exception as e:
logger.error(
f"Failed to compute for {row.filename} ({type(e).__name__}): {e}\n"
Expand All @@ -375,7 +363,7 @@ def _process_single_row(
try:
process_amplitudes_to_dataset(
sfc,
structure_factor_column=F_attribute,
structure_factor_columns=structure_factor_columns,
test_fraction=test_fraction,
seed=seed,
output_path=output_path,
Expand Down Expand Up @@ -552,7 +540,11 @@ def parse_args() -> argparse.Namespace:
sf_group.add_argument(
"--simulate-solvent-and-scale",
action="store_true",
help="Compute bulk solvent and overall scale factors (outputs Ftotal instead of Fprotein)",
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
),
Comment on lines +543 to +547

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Remove the escaped braces from the help text.

This is not an f-string, so F\\{label\\} will render to users as F\{label\} in --help.

Proposed fix
         help=(
             "Compute bulk solvent and overall scale factors and write both protein and "
             "total structure factor in one MTZ. Without this flag, protein only. Each "
-            "set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
+            "set contains F{label}/SIGF{label}/PHIF{label}."
         ),
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
),
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F{label}/SIGF{label}/PHIF{label}."
),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 543 - 547, The
help text in generate_synthetic_sf should not escape the braces around label
because it is a plain string, not an f-string. Update the help text in the
argument definition near the existing bulk solvent option so users see
F{label}/SIGF{label}/PHIF{label} in --help, and verify the same wording is used
consistently wherever that option text is defined.

)
sf_group.add_argument(
"--remove-hydrogens",
Expand Down
75 changes: 75 additions & 0 deletions src/sampleworks/eval/synthetic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import traceback
from pathlib import Path

import gemmi
import numpy as np
from atomworks.io.transforms.atom_array import remove_waters
from biotite.structure import AtomArray
from loguru import logger
from sampleworks.eval.structure_utils import apply_selection
from sampleworks.utils.atom_array_utils import (
AltlocInfo,
BLANK_ALTLOC_IDS,
detect_altlocs,
keep_amino_acids,
keep_polymer,
Expand Down Expand Up @@ -194,3 +197,75 @@ def load_structure_for_synthetic_reward(
raise ValueError(f"Invalid occupancy mode '{occupancy_mode}'")

return atom_array


def atomarray_to_gemmi(
    atom_array: AtomArray,
    unit_cell: gemmi.UnitCell | None = None,
    space_group: str | None = None,
) -> gemmi.Structure:
    """Build a gemmi.Structure from a biotite AtomArray for use with SFcalculator.

    Every atom is given an all-zero anisotropic B-factor tensor, because biotite
    does not store anisotropic B-factors. Altloc labels found in
    ``BLANK_ALTLOC_IDS`` are replaced by gemmi's blank marker '\\x00'; when the
    array has no ``altloc_id`` annotation at all (e.g. arrays reconstructed by a
    model wrapper), every atom is assigned '\\x00'.

    Parameters
    ----------
    atom_array
        Source structure. Its ``coord``, ``b_factor``, ``occupancy``, ``element``,
        ``chain_id``, ``res_name``, ``atom_name`` and ``res_id`` annotations are read.
    unit_cell
        Crystallographic unit cell to assign to the result. If None, the cell is
        not set and gemmi's default (1.0, 1.0, 1.0, 90.0, 90.0, 90.0), in
        Angstroms and degrees, remains.
    space_group
        Space group as a Hermann-Mauguin string to assign to the result. If None,
        it is not set. If empty or invalid, SFcalculator defaults to P1.

    Returns
    -------
    gemmi.Structure
        Structure ready to be wrapped by SFC_Torch.io.PDBParser
    """
    # Imported inside the function so that code paths which never call it
    # (e.g. synthetic density generation) can import this module without SFC_Torch.
    from SFC_Torch.io import array2hier

    n_atoms = len(atom_array)

    # One "chain-0-resname-atomname" label per atom, in atom order.
    cra_labels = [
        f"{chain}-0-{res_name}-{atom_name}"
        for chain, res_name, atom_name in zip(
            atom_array.chain_id, atom_array.res_name, atom_array.atom_name
        )
    ]

    # altloc_id is an optional biotite annotation; gemmi marks a blank altloc with '\x00'.
    if "altloc_id" not in atom_array.get_annotation_categories():
        altloc_labels = ["\x00"] * n_atoms
    else:
        altloc_labels = [
            "\x00" if label in BLANK_ALTLOC_IDS else label for label in atom_array.altloc_id
        ]

    zero_aniso = np.zeros((n_atoms, 3, 3), dtype=np.float64)
    structure: gemmi.Structure = array2hier(
        atom_pos=atom_array.coord,
        atom_b_aniso=zero_aniso,
        atom_b_iso=atom_array.b_factor,
        atom_occ=atom_array.occupancy,
        atom_name=atom_array.element,
        cra_name=cra_labels,
        atom_altloc=altloc_labels,
        res_id=atom_array.res_id,
    )

    # array2hier calls its single model "SFC" and lets setup_entities() invent subchain ids
    # (label_asym_id, e.g. "Axp"). Written to mmCIF, the non-integer model name breaks
    # readers that parse pdbx_PDB_model_num as int (biotite/atomworks), and the multi-char
    # label_asym_id is read back as the chain id, which SFcalculator's PDB-header step
    # rejects (it needs a <=1-char chain). Give each model a 1-based numeric name and set
    # each residue's subchain to its chain name so that saved structures
    # (generate_synthetic_sf --save-structure) round-trip.
    for model_number, model in enumerate(structure, start=1):
        model.name = str(model_number)
        for chain in model:
            chain_name = chain.name
            for residue in chain:
                residue.subchain = chain_name

    if unit_cell is not None:
        structure.cell = unit_cell
    if space_group is not None:
        structure.spacegroup_hm = space_group
    return structure
Loading
Loading