Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions scripts/generate_esm_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@
from loguru import logger
from tqdm import tqdm

from src.constants import ONE_TO_THREE, THREE_TO_ONE
from src.constants import THREE_TO_ONE
from src.dataset import parse_asu_with_biotite
from src.utils import normalize_ins_code, parse_split_file, setup_logging_for_tqdm
from src.utils import (
normalize_ins_code,
parse_split_file,
sanitize_res_names_for_esm,
setup_logging_for_tqdm,
)


def compute_esm_embeddings(
Expand Down Expand Up @@ -95,13 +100,11 @@ def compute_esm_embeddings(
]
num_residues = len(biotite_seq)

# Sanitize the AtomArray so ESM accepts all residues
# Sanitize the AtomArray so ESM accepts all residues. Uses the shared
# helper so residue-name canonicalization stays identical to the residue
# indexing in src/dataset.py.
protein_atoms = sanitize_res_names_for_esm(protein_atoms)
protein_atoms.hetero[:] = False
for i in range(len(protein_atoms)):
orig_res = protein_atoms.res_name[i]
# Map to 1-letter code, then convert back to 3-letter
aa1 = THREE_TO_ONE.get(orig_res, "X")
protein_atoms.res_name[i] = ONE_TO_THREE.get(aa1, "UNK")

# Write sanitized array to an in-memory buffer
sanitized_pdb = PDBFile()
Expand Down
40 changes: 24 additions & 16 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@
from torch_geometric.data import Batch, HeteroData
from tqdm import tqdm

from src.constants import EDGE_PP, ELEM_IDX, ELEMENT_VOCAB, NUM_RBF
from src.constants import (
EDGE_PP,
ELEM_IDX,
ELEMENT_VOCAB,
NUM_RBF,
)
from src.utils import (
compute_edge_features,
normalize_ins_code,
sanitize_res_names_for_esm,
)


Expand Down Expand Up @@ -140,7 +146,7 @@ def match_atoms_to_coords(
Returns:
List of indices into atoms array for matched atoms
"""
if target_coords.shape[0] == 0:
if target_coords.shape[0] == 0 or len(atoms) == 0:
return []

matched = []
Expand Down Expand Up @@ -649,7 +655,7 @@ def filter_waters_by_quality(

class ProteinWaterDataset(Dataset):
"""
Dataset for protein crystal contact prediction.
Dataset for predicting water positions in protein crystal structures.

Returns HeteroData with:
- 'protein' node type: ASU protein atoms + optionally symmetry mates
Expand Down Expand Up @@ -977,20 +983,22 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
protein_elements = [str(e).upper() for e in protein_atoms.element]
protein_x = element_onehot(protein_elements)

# compute residue indices (including ins_code to match ESM/SLAE residue counting)
res_id = protein_atoms.res_id
chain_id_arr = protein_atoms.chain_id
ins_code_arr = np.array(
[normalize_ins_code(x) for x in protein_atoms.ins_code], dtype=object
)
residue_keys = list(zip(chain_id_arr, res_id, ins_code_arr))
unique_res = {k: i for i, k in enumerate(dict.fromkeys(residue_keys))}
protein_res_idx = torch.tensor(
[unique_res[k] for k in residue_keys], dtype=torch.long
# Residue indices must match the ESM embedding script's residue counting.
# get_residue_starts splits on res_name and ins_code, so normalize both
# the same way the ESM script does (sanitize_res_names_for_esm for names,
# normalize_ins_code for insertion codes) to stay aligned with the stored
# embeddings.
sanitized_for_idx = sanitize_res_names_for_esm(protein_atoms)
for i in range(len(sanitized_for_idx)):
sanitized_for_idx.ins_code[i] = normalize_ins_code(
sanitized_for_idx.ins_code[i]
)
res_starts = bts.get_residue_starts(sanitized_for_idx)
num_residues = len(res_starts)
atom_res_idx = (
np.searchsorted(res_starts, np.arange(len(protein_atoms)), side="right") - 1
)

# check water/residue ratio
num_residues = len(unique_res)
protein_res_idx = torch.from_numpy(atom_res_idx.astype(np.int64))
num_waters = len(water_atoms)
ratio_valid, ratio_reason = check_water_residue_ratio(
num_waters,
Expand Down
29 changes: 27 additions & 2 deletions src/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,37 @@ def build_knn_edges(
batch_dst: torch.Tensor | None = None,
) -> torch.Tensor:
"""
KNN edges from src -> dst (source indices in row 0, dest in row 1).
Build KNN edges from src -> dst (source indices in row 0, dest in row 1).

The KNN query is performed *per destination*: for each point in ``dst_pos``
we look up its ``k`` nearest neighbors in ``src_pos`` (``knn(x=src_pos,
y=dst_pos, ...)``) and emit them as incoming edges. As a consequence every
destination node is guaranteed to have up to ``k`` incoming edges (and so
appears in row 1), whereas a source node that is no destination's nearest
neighbor may not appear in row 0 at all. Coverage checks ("every node has an
edge") must therefore be made against the destination row (row 1).

For a homogeneous graph (``src_pos is dst_pos``) self-edges are dropped.

Args:
src_pos: (N_src, 3) source node positions.
dst_pos: (N_dst, 3) destination node positions.
k: Number of nearest source neighbors to find per destination node.
batch_src: (N_src,) batch assignment for source nodes, or None.
batch_dst: (N_dst,) batch assignment for destination nodes, or None.

Returns:
(2, E) edge index tensor with source indices in row 0, destination in
row 1.
"""
if src_pos.numel() == 0 or dst_pos.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=src_pos.device)

idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src)
# knn(x=src_pos, y=dst_pos) returns row 0 = dst (query) indices, row 1 = src
# (neighbor) indices; swap so the result follows the src(row 0)->dst(row 1)
# edge_index convention.
idx = knn(x=src_pos, y=dst_pos, k=k, batch_x=batch_src, batch_y=batch_dst)
idx = torch.stack((idx[1], idx[0]), dim=0)

# remove self-edges if homogeneous
if src_pos.data_ptr() == dst_pos.data_ptr():
Expand Down
98 changes: 32 additions & 66 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

"""
Utility functions organized by category:
1. Feature encoding (rbf, atom37_to_atoms, normalize_ins_code)
1. Feature encoding (rbf, normalize_ins_code)
2. Optimal transport (ot_coupling)
3. Metrics (recall_precision, compute_rmsd, compute_placement_metrics)
4. Visualization (plot_3d_frame, create_trajectory_gif, save_protein_plot)
Expand All @@ -24,44 +24,9 @@
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch import Tensor
from torch_geometric.nn import knn
from tqdm import tqdm

from src.constants import NUM_RBF, RBF_CUTOFF


def build_knn_edges(
src_pos: torch.Tensor,
dst_pos: torch.Tensor,
k: int,
batch_src: torch.Tensor | None = None,
batch_dst: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Build KNN edges from source to destination nodes.

Args:
src_pos: (N_src, 3) source node positions
dst_pos: (N_dst, 3) destination node positions
k: Number of nearest neighbors per source node
batch_src: (N_src,) batch indices for source nodes, or None if single graph
batch_dst: (N_dst,) batch indices for destination nodes, or None if single graph

Returns:
(2, E) edge index tensor with source indices in row 0, destination in row 1.
Self-edges are removed for homogeneous graphs (src_pos is dst_pos).
"""
if src_pos.numel() == 0 or dst_pos.numel() == 0:
return torch.empty(2, 0, dtype=torch.long, device=src_pos.device)

idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src)

# remove self-edges if homogeneous
if src_pos.data_ptr() == dst_pos.data_ptr():
mask = idx[0] != idx[1]
idx = idx[:, mask]

return idx.unique(dim=1)
from src.constants import NUM_RBF, ONE_TO_THREE, RBF_CUTOFF, THREE_TO_ONE


def setup_logging_for_tqdm(
Expand Down Expand Up @@ -117,6 +82,36 @@ def normalize_ins_code(value) -> str:
return ins


def sanitize_res_names_for_esm(atoms):
    """
    Produce a copy of an AtomArray whose residue names are canonicalized the
    same way the ESM embedding pipeline expects.

    Every residue name is round-tripped through the one-letter alphabet:
    ``THREE_TO_ONE`` first (falling back to ``"X"`` for names it does not
    contain), then ``ONE_TO_THREE`` (falling back to ``"UNK"``). Names that
    map to the same one-letter code therefore end up with the same
    three-letter name, which keeps biotite's ``get_residue_starts`` from
    splitting a single residue position on differing names.

    This helper is intended as the single source of truth for residue-name
    sanitization, shared by ``scripts/generate_esm_embeddings.py`` and
    ``src/dataset.py`` so that residue indices line up with the stored ESM
    embeddings. Insertion codes are not touched here; they are normalized
    separately via :func:`normalize_ins_code`.

    Args:
        atoms: A biotite ``AtomArray`` with a ``res_name`` annotation.

    Returns:
        A copy of ``atoms`` with ``res_name`` canonicalized; the input is not
        modified.
    """
    sanitized = atoms.copy()
    # Snapshot the current names so the reads are independent of the
    # in-place writes performed below.
    for position, raw_name in enumerate(list(sanitized.res_name)):
        one_letter = THREE_TO_ONE.get(raw_name, "X")
        sanitized.res_name[position] = ONE_TO_THREE.get(one_letter, "UNK")
    return sanitized


def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]:
"""
Parse split file and construct entries with paths.
Expand Down Expand Up @@ -164,9 +159,6 @@ def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]:
return entries


ATOM37_FILL = 1e-5


def rbf(r: Tensor, num_gaussians: int = NUM_RBF, cutoff: float = RBF_CUTOFF) -> Tensor:
"""
Compute radial basis function encoding of distances.
Expand Down Expand Up @@ -264,32 +256,6 @@ def compute_edge_features(
return unit_vectors, rbf_features


def atom37_to_atoms(
    atom_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Flatten an atom37 coordinate tensor into a list of the atoms present.

    A slot counts as present when at least one of its coordinates differs
    from the ``ATOM37_FILL`` placeholder value.

    Args:
        atom_tensor: (N_res, 37, 3) atom37 coordinates

    Returns:
        coords: (N_atoms, 3) coordinates of present atoms
        residue_index: (N_atoms,) which residue each atom belongs to
        atom_type: (N_atoms,) atom type index (0-36)
    """
    # (N_res, 37) boolean mask of slots that hold a real atom.
    is_present = (atom_tensor != ATOM37_FILL).any(dim=-1)

    # One (residue, slot) index pair per present atom, in row-major order.
    index_pairs = is_present.nonzero(as_tuple=False)
    residue_index = index_pairs[:, 0]
    atom_type = index_pairs[:, 1].long()

    # Boolean-mask indexing picks the same rows, in the same row-major order,
    # as flattening both the tensor and the mask first.
    coords = atom_tensor[is_present]

    return coords, residue_index, atom_type


@torch.no_grad()
def ot_coupling(
x1: torch.Tensor,
Expand Down
13 changes: 9 additions & 4 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,11 @@ def test_all_waters_have_water_edges(self, simple_hetero_data):
n_water = simple_hetero_data["water"].num_nodes

if n_water > 1:
# Check that all water nodes appear in the water-water edges
water_nodes_with_edges = torch.unique(ww_edges[0])
# WW edges are built per destination (knn query per water), so every
# water is guaranteed to appear as a destination (row 1); a water that
# is no other water's nearest neighbor would be missing from the source
# row (row 0). Assert coverage on the destination/query row.
water_nodes_with_edges = torch.unique(ww_edges[1])
assert len(water_nodes_with_edges) == n_water, (
f"Only {len(water_nodes_with_edges)}/{n_water} waters have water-water edges"
)
Expand All @@ -595,9 +598,11 @@ def test_batched_waters_have_edges(self, batched_hetero_data):
f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data"
)

# Check water-water edges
# Check water-water edges. WW edges are built per destination, so every
# water appears as a destination (row 1); assert coverage on the
# destination/query row rather than the source row.
if n_water > 1:
water_nodes_with_ww_edges = torch.unique(ww_edges[0])
water_nodes_with_ww_edges = torch.unique(ww_edges[1])
Comment thread
vratins marked this conversation as resolved.
assert len(water_nodes_with_ww_edges) == n_water, (
f"Only {len(water_nodes_with_ww_edges)}/{n_water} waters have water-water edges in batched data"
)
Expand Down
Loading
Loading