diff --git a/scripts/generate_esm_embeddings.py b/scripts/generate_esm_embeddings.py index cb58253..c1133b9 100644 --- a/scripts/generate_esm_embeddings.py +++ b/scripts/generate_esm_embeddings.py @@ -41,9 +41,14 @@ from loguru import logger from tqdm import tqdm -from src.constants import ONE_TO_THREE, THREE_TO_ONE +from src.constants import THREE_TO_ONE from src.dataset import parse_asu_with_biotite -from src.utils import normalize_ins_code, parse_split_file, setup_logging_for_tqdm +from src.utils import ( + normalize_ins_code, + parse_split_file, + sanitize_res_names_for_esm, + setup_logging_for_tqdm, +) def compute_esm_embeddings( @@ -95,13 +100,11 @@ def compute_esm_embeddings( ] num_residues = len(biotite_seq) - # Sanitize the AtomArray so ESM accepts all residues + # Sanitize the AtomArray so ESM accepts all residues. Uses the shared + # helper so residue-name canonicalization stays identical to the residue + # indexing in src/dataset.py. + protein_atoms = sanitize_res_names_for_esm(protein_atoms) protein_atoms.hetero[:] = False - for i in range(len(protein_atoms)): - orig_res = protein_atoms.res_name[i] - # Map to 1-letter code, then convert back to 3-letter - aa1 = THREE_TO_ONE.get(orig_res, "X") - protein_atoms.res_name[i] = ONE_TO_THREE.get(aa1, "UNK") # Write sanitized array to an in-memory buffer sanitized_pdb = PDBFile() diff --git a/src/dataset.py b/src/dataset.py index 742b135..109043b 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -29,10 +29,16 @@ from torch_geometric.data import Batch, HeteroData from tqdm import tqdm -from src.constants import EDGE_PP, ELEM_IDX, ELEMENT_VOCAB, NUM_RBF +from src.constants import ( + EDGE_PP, + ELEM_IDX, + ELEMENT_VOCAB, + NUM_RBF, +) from src.utils import ( compute_edge_features, normalize_ins_code, + sanitize_res_names_for_esm, ) @@ -140,7 +146,7 @@ def match_atoms_to_coords( Returns: List of indices into atoms array for matched atoms """ - if target_coords.shape[0] == 0: + if target_coords.shape[0] == 0 or len(atoms) == 0: return [] matched = [] @@ -649,7 +655,7 @@ def filter_waters_by_quality( class ProteinWaterDataset(Dataset): """ - Dataset for protein crystal contact prediction. + Dataset for predicting water positions in protein crystal structures. Returns HeteroData with: - 'protein' node type: ASU protein atoms + optionally symmetry mates @@ -977,20 +983,22 @@ def _preprocess_one(self, entry: dict, cache_path: Path): protein_elements = [str(e).upper() for e in protein_atoms.element] protein_x = element_onehot(protein_elements) - # compute residue indices (including ins_code to match ESM/SLAE residue counting) - res_id = protein_atoms.res_id - chain_id_arr = protein_atoms.chain_id - ins_code_arr = np.array( - [normalize_ins_code(x) for x in protein_atoms.ins_code], dtype=object - ) - residue_keys = list(zip(chain_id_arr, res_id, ins_code_arr)) - unique_res = {k: i for i, k in enumerate(dict.fromkeys(residue_keys))} - protein_res_idx = torch.tensor( - [unique_res[k] for k in residue_keys], dtype=torch.long + # Residue indices must match the ESM embedding script's residue counting. + # get_residue_starts splits on res_name and ins_code, so normalize both + # the same way the ESM script does (sanitize_res_names_for_esm for names, + # normalize_ins_code for insertion codes) to stay aligned with the stored + # embeddings. + sanitized_for_idx = sanitize_res_names_for_esm(protein_atoms) + for i in range(len(sanitized_for_idx)): + sanitized_for_idx.ins_code[i] = normalize_ins_code( + sanitized_for_idx.ins_code[i] + ) + res_starts = bts.get_residue_starts(sanitized_for_idx) + num_residues = len(res_starts) + atom_res_idx = ( + np.searchsorted(res_starts, np.arange(len(protein_atoms)), side="right") - 1 ) - - # check water/residue ratio - num_residues = len(unique_res) + protein_res_idx = torch.from_numpy(atom_res_idx.astype(np.int64)) num_waters = len(water_atoms) ratio_valid, ratio_reason = check_water_residue_ratio( num_waters, diff --git a/src/flow.py b/src/flow.py index ad5f7e4..af44314 100644 --- a/src/flow.py +++ b/src/flow.py @@ -34,12 +34,37 @@ def build_knn_edges( batch_dst: torch.Tensor | None = None, ) -> torch.Tensor: """ - KNN edges from src -> dst (source indices in row 0, dest in row 1). + Build KNN edges from src -> dst (source indices in row 0, dest in row 1). + + The KNN query is performed *per destination*: for each point in ``dst_pos`` + we look up its ``k`` nearest neighbors in ``src_pos`` (``knn(x=src_pos, + y=dst_pos, ...)``) and emit them as incoming edges. As a consequence every + destination node is guaranteed to have up to ``k`` incoming edges (and so + appears in row 1), whereas a source node that is no destination's nearest + neighbor may not appear in row 0 at all. Coverage checks ("every node has an + edge") must therefore be made against the destination row (row 1). + + For a homogeneous graph (``src_pos is dst_pos``) self-edges are dropped. + + Args: + src_pos: (N_src, 3) source node positions. + dst_pos: (N_dst, 3) destination node positions. + k: Number of nearest source neighbors to find per destination node. + batch_src: (N_src,) batch assignment for source nodes, or None. + batch_dst: (N_dst,) batch assignment for destination nodes, or None. + + Returns: + (2, E) edge index tensor with source indices in row 0, destination in + row 1. """ if src_pos.numel() == 0 or dst_pos.numel() == 0: return torch.empty(2, 0, dtype=torch.long, device=src_pos.device) - idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src) + # knn(x=src_pos, y=dst_pos) returns row 0 = dst (query) indices, row 1 = src + # (neighbor) indices; swap so the result follows the src(row 0)->dst(row 1) + # edge_index convention. + idx = knn(x=src_pos, y=dst_pos, k=k, batch_x=batch_src, batch_y=batch_dst) + idx = torch.stack((idx[1], idx[0]), dim=0) # remove self-edges if homogeneous if src_pos.data_ptr() == dst_pos.data_ptr(): diff --git a/src/utils.py b/src/utils.py index e67c30e..3717a9c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -4,7 +4,7 @@ """ Utility functions organized by category: -1. Feature encoding (rbf, atom37_to_atoms, normalize_ins_code) +1. Feature encoding (rbf, normalize_ins_code) 2. Optimal transport (ot_coupling) 3. Metrics (recall_precision, compute_rmsd, compute_placement_metrics) 4. Visualization (plot_3d_frame, create_trajectory_gif, save_protein_plot) @@ -24,44 +24,9 @@ from PIL import Image from scipy.optimize import linear_sum_assignment from torch import Tensor -from torch_geometric.nn import knn from tqdm import tqdm -from src.constants import NUM_RBF, RBF_CUTOFF - - -def build_knn_edges( - src_pos: torch.Tensor, - dst_pos: torch.Tensor, - k: int, - batch_src: torch.Tensor | None = None, - batch_dst: torch.Tensor | None = None, -) -> torch.Tensor: - """ - Build KNN edges from source to destination nodes. - - Args: - src_pos: (N_src, 3) source node positions - dst_pos: (N_dst, 3) destination node positions - k: Number of nearest neighbors per source node - batch_src: (N_src,) batch indices for source nodes, or None if single graph - batch_dst: (N_dst,) batch indices for destination nodes, or None if single graph - - Returns: - (2, E) edge index tensor with source indices in row 0, destination in row 1. - Self-edges are removed for homogeneous graphs (src_pos is dst_pos). - """ - if src_pos.numel() == 0 or dst_pos.numel() == 0: - return torch.empty(2, 0, dtype=torch.long, device=src_pos.device) - - idx = knn(x=dst_pos, y=src_pos, k=k, batch_x=batch_dst, batch_y=batch_src) - - # remove self-edges if homogeneous - if src_pos.data_ptr() == dst_pos.data_ptr(): - mask = idx[0] != idx[1] - idx = idx[:, mask] - - return idx.unique(dim=1) +from src.constants import NUM_RBF, ONE_TO_THREE, RBF_CUTOFF, THREE_TO_ONE def setup_logging_for_tqdm( @@ -117,6 +82,36 @@ def normalize_ins_code(value) -> str: return ins +def sanitize_res_names_for_esm(atoms): + """ + Return a copy of an AtomArray with residue names canonicalized to match the + ESM embedding pipeline. + + Each residue name is mapped to its one-letter code and back + (``THREE_TO_ONE`` -> ``ONE_TO_THREE``), with anything unrecognized collapsed + to ``"UNK"``. This merges non-canonical names that share a residue position + (e.g. modified residues -> their canonical parent, unknowns -> ``UNK``) so + that biotite's ``get_residue_starts`` does not split them apart. + + This is the single source of truth for residue-name sanitization shared by + ``scripts/generate_esm_embeddings.py`` (which feeds the sanitized structure + to ESM3) and ``src/dataset.py`` (which derives residue indices that must line + up with the stored ESM embeddings). Insertion codes are normalized + separately via :func:`normalize_ins_code`. + + Args: + atoms: A biotite ``AtomArray`` with a ``res_name`` annotation. + + Returns: + A copy of ``atoms`` with ``res_name`` canonicalized. + """ + sanitized = atoms.copy() + for i in range(len(sanitized)): + aa1 = THREE_TO_ONE.get(sanitized.res_name[i], "X") + sanitized.res_name[i] = ONE_TO_THREE.get(aa1, "UNK") + return sanitized + + def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]: """ Parse split file and construct entries with paths. @@ -164,9 +159,6 @@ def parse_split_file(split_file: Path, base_pdb_dir: Path) -> list[dict]: return entries -ATOM37_FILL = 1e-5 - - def rbf(r: Tensor, num_gaussians: int = NUM_RBF, cutoff: float = RBF_CUTOFF) -> Tensor: """ Compute radial basis function encoding of distances. @@ -264,32 +256,6 @@ def compute_edge_features( return unit_vectors, rbf_features -def atom37_to_atoms( - atom_tensor: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Convert atom37 representation to flat atom list. - - Args: - atom_tensor: (N_res, 37, 3) atom37 coordinates - - Returns: - coords: (N_atoms, 3) coordinates of present atoms - residue_index: (N_atoms,) which residue each atom belongs to - atom_type: (N_atoms,) atom type index (0-36) - """ - present = (atom_tensor != ATOM37_FILL).any(dim=-1) # (N_res, 37) - nz = present.nonzero(as_tuple=False) # (N_atoms, 2) - residue_index = nz[:, 0] - atom_type = nz[:, 1].long() - - flat = atom_tensor.reshape(-1, 3) - flat_mask = present.reshape(-1) - coords = flat[flat_mask] - - return coords, residue_index, atom_type - - @torch.no_grad() def ot_coupling( x1: torch.Tensor, diff --git a/tests/test_flow.py b/tests/test_flow.py index a3a519c..0aa3ff1 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -573,8 +573,11 @@ def test_all_waters_have_water_edges(self, simple_hetero_data): n_water = simple_hetero_data["water"].num_nodes if n_water > 1: - # Check that all water nodes appear in the water-water edges - water_nodes_with_edges = torch.unique(ww_edges[0]) + # WW edges are built per destination (knn query per water), so every + # water is guaranteed to appear as a destination (row 1); a water that + # is no other water's nearest neighbor would be missing from the source + # row (row 0). Assert coverage on the destination/query row. + water_nodes_with_edges = torch.unique(ww_edges[1]) assert len(water_nodes_with_edges) == n_water, ( f"Only {len(water_nodes_with_edges)}/{n_water} waters have water-water edges" ) @@ -595,9 +598,11 @@ def test_batched_waters_have_edges(self, batched_hetero_data): f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data" ) - # Check water-water edges + # Check water-water edges. WW edges are built per destination, so every + # water appears as a destination (row 1); assert coverage on the + # destination/query row rather than the source row. if n_water > 1: - water_nodes_with_ww_edges = torch.unique(ww_edges[0]) + water_nodes_with_ww_edges = torch.unique(ww_edges[1]) assert len(water_nodes_with_ww_edges) == n_water, ( f"Only {len(water_nodes_with_ww_edges)}/{n_water} waters have water-water edges in batched data" ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0239779..6e5ca72 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,7 @@ Tests for src/utils.py utility functions. Organized by category to match utils.py structure: -1. Feature encoding (rbf, atom37_to_atoms) +1. Feature encoding (rbf) 2. Optimal transport (ot_coupling) 3. Metrics (recall_precision, compute_rmsd, compute_placement_metrics) 4. Visualization (plot_3d_frame, save_protein_plot, create_trajectory_gif) @@ -12,18 +12,18 @@ All test cases created with assistance from Claude Code and refined. """ +import biotite.structure as bts import matplotlib import numpy as np import pytest import torch +from biotite.structure import array, Atom matplotlib.use("Agg") from src.utils import ( - ATOM37_FILL, - atom37_to_atoms, compute_edge_features, compute_edge_geometry, compute_placement_metrics, @@ -36,6 +36,7 @@ rbf, # Metrics recall_precision, + sanitize_res_names_for_esm, save_protein_plot, ) @@ -132,80 +133,85 @@ def test_normalize_valid_code(self): @pytest.mark.unit -class TestAtom37ToAtoms: - """Tests for atom37 representation conversion.""" - - def test_basic_conversion(self): - """Basic conversion from atom37 to flat atoms.""" - # Create atom37 tensor with some present atoms - atom_tensor = torch.full((3, 37, 3), ATOM37_FILL) - # Place atoms at specific positions - atom_tensor[0, 0, :] = torch.tensor([1.0, 2.0, 3.0]) # CA of residue 0 - atom_tensor[0, 1, :] = torch.tensor([1.5, 2.5, 3.5]) # C of residue 0 - atom_tensor[1, 0, :] = torch.tensor([4.0, 5.0, 6.0]) # CA of residue 1 - - coords, residue_idx, atom_type = atom37_to_atoms(atom_tensor) - - assert coords.shape == (3, 3) - assert residue_idx.shape == (3,) - assert atom_type.shape == (3,) - - def test_residue_indices_correct(self): - """Residue indices should match the original residue.""" - atom_tensor = torch.full((2, 37, 3), ATOM37_FILL) - atom_tensor[0, 0, :] = torch.tensor([1.0, 0.0, 0.0]) - atom_tensor[0, 1, :] = torch.tensor([2.0, 0.0, 0.0]) - atom_tensor[1, 5, :] = torch.tensor([3.0, 0.0, 0.0]) - - _, residue_idx, atom_type = atom37_to_atoms(atom_tensor) - - assert residue_idx[0] == 0 - assert residue_idx[1] == 0 - assert residue_idx[2] == 1 - - def test_atom_types_correct(self): - """Atom types should match the slot index.""" - atom_tensor = torch.full((1, 37, 3), ATOM37_FILL) - atom_tensor[0, 0, :] = torch.tensor([1.0, 0.0, 0.0]) # slot 0 - atom_tensor[0, 5, :] = torch.tensor([2.0, 0.0, 0.0]) # slot 5 - atom_tensor[0, 10, :] = torch.tensor([3.0, 0.0, 0.0]) # slot 10 - - _, _, atom_type = atom37_to_atoms(atom_tensor) - - assert atom_type[0] == 0 - assert atom_type[1] == 5 - assert atom_type[2] == 10 - - def test_empty_residues(self): - """Empty residues should not contribute atoms.""" - atom_tensor = torch.full((3, 37, 3), ATOM37_FILL) - # Only residue 0 has atoms - atom_tensor[0, 0, :] = torch.tensor([1.0, 0.0, 0.0]) - - coords, residue_idx, _ = atom37_to_atoms(atom_tensor) - - assert coords.shape == (1, 3) - assert residue_idx[0] == 0 - - def test_all_empty(self): - """All-empty tensor should return empty outputs.""" - atom_tensor = torch.full((5, 37, 3), ATOM37_FILL) - - coords, residue_idx, atom_type = atom37_to_atoms(atom_tensor) - - assert coords.shape == (0, 3) - assert residue_idx.shape == (0,) - assert atom_type.shape == (0,) - - def test_coordinates_preserved(self): - """Coordinates should be preserved exactly.""" - atom_tensor = torch.full((1, 37, 3), ATOM37_FILL) - expected_coord = torch.tensor([1.234, 5.678, 9.012]) - atom_tensor[0, 0, :] = expected_coord - - coords, _, _ = atom37_to_atoms(atom_tensor) +class TestSanitizeResNamesForEsm: + """Tests for ESM residue-name sanitization and residue-count alignment. + + These guard the contract that src/dataset.py's residue counting (via + biotite.get_residue_starts on a sanitized array) stays in lockstep with + scripts/generate_esm_embeddings.py's residue keys, which are built with + THREE_TO_ONE/normalize_ins_code. A drift here desyncs protein_res_idx from + the cached ESM embeddings. + """ + + @staticmethod + def _make_atoms(residues): + """Build a single-atom-per-row AtomArray from (res_id, res_name, ins) tuples.""" + return array( + [ + Atom( + [0.0, 0.0, 0.0], + chain_id="A", + res_id=res_id, + res_name=res_name, + ins_code=ins, + atom_name="CA", + element="C", + ) + for (res_id, res_name, ins) in residues + ] + ) - assert torch.allclose(coords[0], expected_coord) + @staticmethod + def _esm_key_count(atoms): + """Replicate the ESM script's residue-key counting.""" + keys = [] + for i in range(len(atoms)): + key = ( + atoms.chain_id[i], + atoms.res_id[i], + normalize_ins_code(atoms.ins_code[i]), + ) + if key not in keys: + keys.append(key) + return len(keys) + + def test_sanitize_canonicalizes_modified_and_unknown(self): + atoms = self._make_atoms([(1, "MSE", ""), (2, "ALA", ""), (3, "Q2K", "")]) + sanitized = sanitize_res_names_for_esm(atoms) + # MSE -> MET (canonical parent), ALA unchanged, Q2K -> UNK (unknown) + assert list(sanitized.res_name) == ["MET", "ALA", "UNK"] + # original array must be untouched (helper returns a copy) + assert list(atoms.res_name) == ["MSE", "ALA", "Q2K"] + + def test_placeholder_ins_code_desync_is_fixed(self): + # Two atoms that share (chain, res_id, res_name) and differ only by a + # placeholder insertion code ('' vs '.'). ESM keys treat them as one + # residue; raw get_residue_starts would split them into two. + atoms = self._make_atoms([(5, "GLY", ""), (5, "GLY", ".")]) + assert self._esm_key_count(atoms) == 1 + + sanitized = sanitize_res_names_for_esm(atoms) + for i in range(len(sanitized)): + sanitized.ins_code[i] = normalize_ins_code(sanitized.ins_code[i]) + assert len(bts.get_residue_starts(sanitized)) == 1 + + def test_residue_count_matches_esm_keys(self): + # Mix of canonical, modified, unknown residues with placeholder and real + # insertion codes; the dataset count must equal the ESM key count. + atoms = self._make_atoms( + [ + (1, "ALA", ""), + (1, "ALA", ""), + (2, "MSE", "."), + (3, "GLY", "?"), + (3, "GLY", "A"), # real insertion code -> distinct residue + (4, "Q2K", ""), + ] + ) + sanitized = sanitize_res_names_for_esm(atoms) + for i in range(len(sanitized)): + sanitized.ins_code[i] = normalize_ins_code(sanitized.ins_code[i]) + assert len(bts.get_residue_starts(sanitized)) == self._esm_key_count(atoms) @pytest.mark.unit diff --git a/uv.lock b/uv.lock index 2adbd1f..1714044 100644 --- a/uv.lock +++ b/uv.lock @@ -642,6 +642,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl", hash = "sha256:ecaca67aed704a338f88f67b1181b58f821ab5dc89c1f0f5ef99db43c1c2921e", size = 139808, upload-time = "2025-11-01T21:18:10.956Z" }, ] +[[package]] +name = "jaxtyping" +version = "0.3.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wadler-lindig" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/c1/091b8852bd7cbf50bd655543c8506033cf4029300c67f8c176c1286879a9/jaxtyping-0.3.11.tar.gz", hash = "sha256:b09c14acf6686feb9e0df5b0d8c6e7c5b6f8d36bf059ee54cd522a186c2ef050", size = 46489, upload-time = "2026-06-13T18:35:23.167Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/38/c66bbdc5047f4776c2bd3e47e5295a350e3fa44d5b8942105e71c2a876a0/jaxtyping-0.3.11-py3-none-any.whl", hash = "sha256:8a4bedc4e3f963fa82df41bd13c7ebc2bad925601eb48614c65798f21329d4e3", size = 56593, upload-time = "2026-06-13T18:35:22.01Z" }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -714,25 +726,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/92/5f3068cf15ee5cb624a0c7596e67e2a0bb2adee33f71c379054a491d07da/kiwisolver-1.4.9-cp312-cp312-win_arm64.whl", hash = "sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60", size = 64992, upload-time = "2025-08-10T21:26:25.732Z" }, ] -[[package]] -name = "librt" -version = "0.7.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/93/e4/b59bdf1197fdf9888452ea4d2048cdad61aef85eb83e99dc52551d7fdc04/librt-0.7.4.tar.gz", hash = "sha256:3871af56c59864d5fd21d1ac001eb2fb3b140d52ba0454720f2e4a19812404ba", size = 145862, upload-time = "2025-12-15T16:52:43.862Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/e7/b805d868d21f425b7e76a0ea71a2700290f2266a4f3c8357fcf73efc36aa/librt-0.7.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7dd3b5c37e0fb6666c27cf4e2c88ae43da904f2155c4cfc1e5a2fdce3b9fcf92", size = 55688, upload-time = "2025-12-15T16:51:31.571Z" }, - { url = "https://files.pythonhosted.org/packages/59/5e/69a2b02e62a14cfd5bfd9f1e9adea294d5bcfeea219c7555730e5d068ee4/librt-0.7.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a9c5de1928c486201b23ed0cc4ac92e6e07be5cd7f3abc57c88a9cf4f0f32108", size = 57141, upload-time = "2025-12-15T16:51:32.714Z" }, - { url = "https://files.pythonhosted.org/packages/6e/6b/05dba608aae1272b8ea5ff8ef12c47a4a099a04d1e00e28a94687261d403/librt-0.7.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:078ae52ffb3f036396cc4aed558e5b61faedd504a3c1f62b8ae34bf95ae39d94", size = 165322, upload-time = "2025-12-15T16:51:33.986Z" }, - { url = "https://files.pythonhosted.org/packages/8f/bc/199533d3fc04a4cda8d7776ee0d79955ab0c64c79ca079366fbc2617e680/librt-0.7.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce58420e25097b2fc201aef9b9f6d65df1eb8438e51154e1a7feb8847e4a55ab", size = 174216, upload-time = "2025-12-15T16:51:35.384Z" }, - { url = "https://files.pythonhosted.org/packages/62/ec/09239b912a45a8ed117cb4a6616d9ff508f5d3131bd84329bf2f8d6564f1/librt-0.7.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b719c8730c02a606dc0e8413287e8e94ac2d32a51153b300baf1f62347858fba", size = 189005, upload-time = "2025-12-15T16:51:36.687Z" }, - { url = "https://files.pythonhosted.org/packages/46/2e/e188313d54c02f5b0580dd31476bb4b0177514ff8d2be9f58d4a6dc3a7ba/librt-0.7.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3749ef74c170809e6dee68addec9d2458700a8de703de081c888e92a8b015cf9", size = 183960, upload-time = "2025-12-15T16:51:37.977Z" }, - { url = "https://files.pythonhosted.org/packages/eb/84/f1d568d254518463d879161d3737b784137d236075215e56c7c9be191cee/librt-0.7.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b35c63f557653c05b5b1b6559a074dbabe0afee28ee2a05b6c9ba21ad0d16a74", size = 177609, upload-time = "2025-12-15T16:51:40.584Z" }, - { url = "https://files.pythonhosted.org/packages/5d/43/060bbc1c002f0d757c33a1afe6bf6a565f947a04841139508fc7cef6c08b/librt-0.7.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1ef704e01cb6ad39ad7af668d51677557ca7e5d377663286f0ee1b6b27c28e5f", size = 199269, upload-time = "2025-12-15T16:51:41.879Z" }, - { url = "https://files.pythonhosted.org/packages/ff/7f/708f8f02d8012ee9f366c07ea6a92882f48bd06cc1ff16a35e13d0fbfb08/librt-0.7.4-cp312-cp312-win32.whl", hash = "sha256:c66c2b245926ec15188aead25d395091cb5c9df008d3b3207268cd65557d6286", size = 43186, upload-time = "2025-12-15T16:51:43.149Z" }, - { url = "https://files.pythonhosted.org/packages/f1/a5/4e051b061c8b2509be31b2c7ad4682090502c0a8b6406edcf8c6b4fe1ef7/librt-0.7.4-cp312-cp312-win_amd64.whl", hash = "sha256:71a56f4671f7ff723451f26a6131754d7c1809e04e22ebfbac1db8c9e6767a20", size = 49455, upload-time = "2025-12-15T16:51:44.336Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d2/90d84e9f919224a3c1f393af1636d8638f54925fdc6cd5ee47f1548461e5/librt-0.7.4-cp312-cp312-win_arm64.whl", hash = "sha256:419eea245e7ec0fe664eb7e85e7ff97dcdb2513ca4f6b45a8ec4a3346904f95a", size = 42828, upload-time = "2025-12-15T16:51:45.498Z" }, -] - [[package]] name = "loguru" version = "0.7.3" @@ -869,36 +862,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" }, ] -[[package]] -name = "mypy" -version = "1.19.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, - { name = "mypy-extensions" }, - { name = "pathspec" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053, upload-time = "2025-12-15T05:03:46.622Z" }, - { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134, upload-time = "2025-12-15T05:03:24.367Z" }, - { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616, upload-time = "2025-12-15T05:02:44.725Z" }, - { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847, upload-time = "2025-12-15T05:03:39.633Z" }, - { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976, upload-time = "2025-12-15T05:03:08.786Z" }, - { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104, upload-time = "2025-12-15T05:03:10.834Z" }, - { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, -] - -[[package]] -name = "mypy-extensions" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, -] - [[package]] name = "networkx" version = "3.6.1" @@ -1122,15 +1085,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, ] -[[package]] -name = "pathspec" -version = "0.12.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, -] - [[package]] name = "pexpect" version = "4.9.0" @@ -1748,8 +1702,8 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce6e6a1f4803ad62d1fe51cec3fe5ca14bcd8bc7cace7b09d5590f8147fa16ad" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:f6c79eac0018f9d131479ee1b7a68edb030619a316bfbc69275043aa4f338e4c" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce6e6a1f4803ad62d1fe51cec3fe5ca14bcd8bc7cace7b09d5590f8147fa16ad", upload-time = "2025-10-01T23:40:02Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:f6c79eac0018f9d131479ee1b7a68edb030619a316bfbc69275043aa4f338e4c", upload-time = "2025-10-01T23:40:33Z" }, ] [[package]] @@ -1987,8 +1941,10 @@ dependencies = [ { name = "biotite" }, { name = "e3nn" }, { name = "esm" }, + { name = "jaxtyping" }, { name = "loguru" }, { name = "matplotlib" }, + { name = "numpy" }, { name = "pandas" }, { name = "pillow" }, { name = "pyg-lib" }, @@ -2004,7 +1960,6 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "mypy" }, { name = "prek" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -2017,8 +1972,10 @@ requires-dist = [ { name = "biotite" }, { name = "e3nn" }, { name = "esm" }, + { name = "jaxtyping" }, { name = "loguru" }, { name = "matplotlib" }, + { name = "numpy" }, { name = "pandas" }, { name = "pillow" }, { name = "pyg-lib", index = "https://data.pyg.org/whl/torch-2.8.0+cu126.html" }, @@ -2034,7 +1991,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "mypy" }, { name = "prek" }, { name = "pytest" }, { name = "pytest-cov" },