From 981b4adebf250c13c568f0b027d307b75a72c0f9 Mon Sep 17 00:00:00 2001 From: vratins Date: Wed, 17 Jun 2026 06:03:39 +0000 Subject: [PATCH 1/6] water sampling refactor --- src/flow.py | 339 ++++++++++++++++++++++++++++++++++++++------- tests/test_flow.py | 121 ++++++++++++++++ 2 files changed, 410 insertions(+), 50 deletions(-) diff --git a/src/flow.py b/src/flow.py index ad5f7e4..7c6511b 100644 --- a/src/flow.py +++ b/src/flow.py @@ -17,7 +17,7 @@ from torch import nn, Tensor from torch_geometric.data import Batch, HeteroData from torch_geometric.nn import knn -from torch_scatter import scatter_mean +from torch_scatter import scatter, scatter_mean from tqdm.auto import tqdm from src.constants import ALL_EDGE_TYPES, EDGE_PP, EDGE_PW, EDGE_WP, EDGE_WW, NUM_RBF @@ -26,6 +26,115 @@ from src.utils import ot_coupling + +def sample_waters_uniform_ball( + protein_pos: Tensor, + batch_p: Tensor, + num_waters: Tensor, + cutoff: float = 8.0, + device: torch.device | None = None, +) -> tuple[Tensor, Tensor]: + """ + Sample water positions uniformly inside balls of radius *cutoff* centred + on randomly chosen protein atoms. + + Every sample is guaranteed within *cutoff* of at least one protein atom. + No rejection sampling — runs in O(1) rounds, fully vectorised. + + Args: + protein_pos: (N_protein, 3) protein coordinates for all graphs + batch_p: (N_protein,) graph indices for protein atoms + num_waters: (num_graphs,) target water count per graph + cutoff: Ball radius in Angstroms + device: Optional output device (defaults to protein_pos.device) + + Returns: + water_pos: (sum(num_waters), 3) sampled positions + batch_w: (sum(num_waters),) graph indices + """ + if device is None: + device = protein_pos.device + + num_graphs = num_waters.numel() + total_waters = int(num_waters.sum().item()) + + if total_waters == 0: + return ( + torch.empty(0, 3, dtype=protein_pos.dtype, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) + + # batch indices for output waters + batch_w = torch.repeat_interleave( + torch.arange(num_graphs, device=device), num_waters.to(device) + ) + + # per-graph protein atom counts and cumulative offsets + num_p_per_graph = scatter( + torch.ones(batch_p.size(0), device=device, dtype=torch.long), + batch_p.to(device), + dim=0, + dim_size=num_graphs, + reduce="sum", + ) + offsets = torch.zeros(num_graphs + 1, dtype=torch.long, device=device) + offsets[1:] = num_p_per_graph.cumsum(dim=0) + + # pick a random protein atom per water (uniform with replacement) + graph_sizes = num_p_per_graph[batch_w] + graph_offsets = offsets[batch_w] + local_idx = (torch.rand(total_waters, device=device) * graph_sizes.float()).long() + anchors = protein_pos.to(device)[graph_offsets + local_idx] + + # uniform direction on the unit sphere + direction = torch.randn(total_waters, 3, device=device, dtype=protein_pos.dtype) + direction = direction / direction.norm(dim=-1, keepdim=True).clamp(min=1e-12) + + # uniform radius inside the ball: r = R * U^(1/3) + r = cutoff * torch.rand(total_waters, 1, device=device, dtype=protein_pos.dtype).pow( + 1.0 / 3.0 + ) + + return anchors + r * direction, batch_w + + +def sample_waters_scaled_gaussian( + num_waters: Tensor, + sigma_per_graph: Tensor, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> tuple[Tensor, Tensor]: + """ + Sample water positions from N(0, sigma^2 * I) with no rejection. + + Args: + num_waters: (num_graphs,) target water count per graph + sigma_per_graph: (num_graphs,) Gaussian scale per graph + device: Output device + dtype: Output dtype + + Returns: + water_pos: (sum(num_waters), 3) sampled positions + batch_w: (sum(num_waters),) graph indices + """ + num_graphs = num_waters.numel() + total_waters = int(num_waters.sum().item()) + + if total_waters == 0: + return ( + torch.empty(0, 3, dtype=dtype, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) + + batch_w = torch.repeat_interleave( + torch.arange(num_graphs, device=device), num_waters.to(device) + ) + sigma = sigma_per_graph.to(device=device, dtype=dtype)[batch_w].unsqueeze(-1) + water_pos = torch.randn(total_waters, 3, device=device, dtype=dtype) * sigma + + return water_pos, batch_w + + def build_knn_edges( src_pos: torch.Tensor, dst_pos: torch.Tensor, @@ -477,6 +586,9 @@ class FlowMatcher: High level class for flow matching training, validation, and numerical integration """ + SAMPLING_STRATEGIES = {"uniform_ball", "scaled_gaussian"} + DYNAMIC_EDGE_POLICIES = {"auto", "radius", "knn_if_isolated"} + def __init__( self, model, @@ -486,6 +598,8 @@ def __init__( t_distort: float = 0.5, sigma_distort: float = 0.5, loss_eps: float = 1e-3, + sampling_strategy: str = "uniform_ball", + dynamic_edge_policy: str = "auto", ): """ Initialize flow matcher for training and inference. @@ -498,7 +612,21 @@ def __init__( t_distort: Time threshold after which distortion may be applied sigma_distort: Standard deviation of distortion noise loss_eps: Small constant for numerical stability in loss weighting + sampling_strategy: Source distribution for flow matching noise. + "uniform_ball" samples uniformly in balls around protein atoms. + "scaled_gaussian" samples from N(0, sigma^2*I). + dynamic_edge_policy: Runtime policy for dynamic water-edge building. """ + if sampling_strategy not in self.SAMPLING_STRATEGIES: + raise ValueError( + f"sampling_strategy must be one of {self.SAMPLING_STRATEGIES}, " + f"got '{sampling_strategy}'" + ) + if dynamic_edge_policy not in self.DYNAMIC_EDGE_POLICIES: + raise ValueError( + f"dynamic_edge_policy must be one of {self.DYNAMIC_EDGE_POLICIES}, " + f"got '{dynamic_edge_policy}'" + ) self.model = model self.p_self_cond = p_self_cond self.use_distortion = use_distortion @@ -506,6 +634,52 @@ def __init__( self.t_distort = t_distort self.sigma_distort = sigma_distort self.loss_eps = loss_eps + self.graph_cutoff = getattr(model, "cutoff", 8.0) + self.sampling_strategy = sampling_strategy + self.dynamic_edge_policy = dynamic_edge_policy + + @staticmethod + def _num_graphs(data: HeteroData | Batch) -> int: + """Infer graph count without forcing a device sync when batch metadata exists.""" + num_graphs = getattr(data, "num_graphs", None) + if num_graphs is not None: + return int(num_graphs) + batch_p = data["protein"].batch + if batch_p.numel() == 0: + return 0 + return int(batch_p.max().item()) + 1 + + def _sample_waters( + self, + batch_data: HeteroData | Batch, + num_waters: Tensor, + device: torch.device, + ) -> tuple[Tensor, Tensor]: + """Dispatch to the configured sampling strategy.""" + if self.sampling_strategy == "uniform_ball": + return sample_waters_uniform_ball( + protein_pos=batch_data["protein"].pos, + batch_p=batch_data["protein"].batch, + num_waters=num_waters, + cutoff=self.graph_cutoff, + device=device, + ) + # scaled_gaussian + sigma_per_graph = self.compute_sigma_per_graph(batch_data, device) + return sample_waters_scaled_gaussian( + num_waters=num_waters, + sigma_per_graph=sigma_per_graph, + device=device, + dtype=batch_data["protein"].pos.dtype, + ) + + def _effective_dynamic_edge_policy(self) -> str: + """Resolve the dynamic edge policy for the current sampling strategy.""" + if self.dynamic_edge_policy == "auto": + if self.sampling_strategy == "scaled_gaussian": + return "knn_if_isolated" + return "radius" + return self.dynamic_edge_policy @staticmethod def compute_sigma(data: HeteroData) -> float: @@ -547,7 +721,7 @@ def training_step( batch: HeteroData, use_self_conditioning: bool = True, accumulation_steps: int = 1, - ) -> dict[str, float | int | None | dict]: + ) -> dict[str, float | int | None]: """ Single flow matching training step (forward + backward only). @@ -579,15 +753,22 @@ def training_step( self.model.train() device = batch["protein"].pos.device + batch.dynamic_edge_policy = self._effective_dynamic_edge_policy() x1 = batch["water"].pos batch_w = batch["water"].batch batch_p = batch["protein"].batch - num_graphs = int(batch_p.max().item()) + 1 - - sigma = self.compute_sigma(batch) - - x0 = torch.randn_like(x1) * sigma + num_graphs = self._num_graphs(batch) + + sigma_per_graph = self.compute_sigma_per_graph(batch, device) + num_w_per_graph = scatter( + torch.ones(batch_w.size(0), device=device, dtype=torch.long), + batch_w, + dim=0, + dim_size=num_graphs, + reduce="sum", + ) + x0, _ = self._sample_waters(batch, num_w_per_graph, device) x0_star, x1_star = ot_coupling(x1=x1, batch=batch_w, x0=x0) t = torch.rand(num_graphs, device=device) @@ -624,6 +805,12 @@ def training_step( per_atom_mse = (v_pred - v_target).pow(2).mean(dim=-1, keepdim=True) loss = (w * per_atom_mse).sum() / w.sum() + # training RMSD + with torch.no_grad(): + x1_hat = x_t + (1.0 - t_per_atom) * v_pred + diff2 = ((x1_hat - x1_star) ** 2).sum(-1) # (Nw,) + rmsd = torch.sqrt(scatter_mean(diff2, batch_w, dim=0)).mean() + # check for high loss and compute per-sample losses for debugging per_sample_info = None if loss.item() > 100.0: @@ -631,28 +818,17 @@ def training_step( from torch_scatter import scatter_add weighted_mse = (w * per_atom_mse).squeeze(-1) - # compute per-graph loss: sum(weighted_mse) / sum(w) for each graph numerator = scatter_add(weighted_mse, batch_w, dim=0) denominator = scatter_add(w.squeeze(-1), batch_w, dim=0) per_sample_loss = numerator / (denominator + 1e-8) per_sample_info = {"losses": per_sample_loss, "num_graphs": num_graphs} - # backward (scale loss for gradient accumulation) (loss / accumulation_steps).backward() - # training RMSD - with torch.no_grad(): - x1_hat = x_t + (1.0 - t_per_atom) * v_pred - # rmsd = compute_rmsd(x1_hat, x1_star) - - # on-gpu version of rmsd - diff2 = ((x1_hat - x1_star) ** 2).sum(-1) # (Nw,) - rmsd = torch.sqrt(scatter_mean(diff2, batch_w, dim=0)).mean().item() - return { "loss": loss.item(), - "rmsd": rmsd, - "sigma": sigma, + "rmsd": rmsd.item(), + "sigma": sigma_per_graph, "per_sample_info": per_sample_info, } @@ -674,14 +850,21 @@ def validation_step(self, batch: HeteroData) -> dict[str, float]: """ self.model.eval() device = batch["protein"].pos.device + batch.dynamic_edge_policy = self._effective_dynamic_edge_policy() x1 = batch["water"].pos batch_w = batch["water"].batch batch_p = batch["protein"].batch - num_graphs = int(batch_p.max().item()) + 1 - - sigma = self.compute_sigma(batch) - x0 = torch.randn_like(x1) * sigma + num_graphs = self._num_graphs(batch) + + num_w_per_graph = scatter( + torch.ones(batch_w.size(0), device=device, dtype=torch.long), + batch_w, + dim=0, + dim_size=num_graphs, + reduce="sum", + ) + x0, _ = self._sample_waters(batch, num_w_per_graph, device) x0_star, x1_star = ot_coupling(x1=x1, batch=batch_w, x0=x0) t = torch.rand(num_graphs, device=device) @@ -690,6 +873,7 @@ def validation_step(self, batch: HeteroData) -> dict[str, float]: batch["water"].pos = x_t v_pred = self.model(batch, t, self_cond=None) + v_target = x1_star - x0_star w = 1.0 / (self.loss_eps + (1.0 - t_per_atom)) @@ -699,9 +883,12 @@ def validation_step(self, batch: HeteroData) -> dict[str, float]: # GPU RMSD x1_hat = x_t + (1.0 - t_per_atom) * v_pred diff2 = ((x1_hat - x1_star) ** 2).sum(-1) # (Nw,) - rmsd = torch.sqrt(scatter_mean(diff2, batch_w, dim=0)).mean().item() + rmsd = torch.sqrt(scatter_mean(diff2, batch_w, dim=0)).mean() - return {"loss": loss.item(), "rmsd": rmsd} + return { + "loss": loss.item(), + "rmsd": rmsd.item(), + } def _setup_water_nodes_from_ratio( self, @@ -722,23 +909,55 @@ def _setup_water_nodes_from_ratio( batch_w: (N_water_total,) batch indices """ num_residues = g["protein"].num_residues # (num_graphs,) - num_graphs = num_residues.size(0) # compute waters per graph: num_residues * ratio, minimum 1 num_waters = (num_residues.float() * water_ratio).long().clamp(min=1) - # create batch indices (vectorized) - batch_w = torch.repeat_interleave( - torch.arange(num_graphs, device=device), num_waters - ) + x, batch_w = self._sample_waters(g, num_waters, device) total_waters = batch_w.size(0) - # compute sigma per graph and expand to per-water - sigma_per_graph = self.compute_sigma_per_graph(g, device) - sigma_per_water = sigma_per_graph[batch_w] + # create water features (oxygen one-hot, index 2 for 'O' in ELEMENT_VOCAB) + water_x = torch.zeros(total_waters, 16, device=device) + water_x[:, 2] = 1.0 # oxygen is index 2 in ELEMENT_VOCAB - # sample noise - x = torch.randn(total_waters, 3, device=device) * sigma_per_water.unsqueeze(-1) + # update graph with new water nodes + g["water"].pos = x + g["water"].x = water_x + g["water"].batch = batch_w + g["water"].num_nodes = total_waters + + return x, batch_w + + def _setup_water_nodes_from_count( + self, + g: Batch, + water_count: int, + device: torch.device, + ) -> tuple[Tensor, Tensor]: + """ + Create water node positions and batch indices using a fixed count per protein. + + Args: + g: Batched HeteroData graph (modified in-place) + water_count: Exact number of waters to sample per protein + device: Device to create tensors on + + Returns: + x: (N_water_total, 3) initial noise positions + batch_w: (N_water_total,) batch indices + """ + num_residues = g["protein"].num_residues # (num_graphs,) + num_graphs = num_residues.size(0) + + num_waters = torch.full( + (num_graphs,), + water_count, + dtype=torch.long, + device=device, + ) + + x, batch_w = self._sample_waters(g, num_waters, device) + total_waters = batch_w.size(0) # create water features (oxygen one-hot, index 2 for 'O' in ELEMENT_VOCAB) water_x = torch.zeros(total_waters, 16, device=device) @@ -761,6 +980,7 @@ def euler_integrate( sc_ema_alpha: float = 0.2, device: str | torch.device = "cuda", water_ratio: float | None = None, + water_count: int | None = None, ) -> list[dict[str, np.ndarray]]: """ Euler integration from noise to final positions. @@ -773,11 +993,12 @@ def euler_integrate( device: Device to run on water_ratio: If provided, sample num_residues * water_ratio waters instead of using ground truth water count + water_count: If provided, sample exactly this many waters per protein Returns: List of dicts, one per input graph, each with keys: 'protein_pos': (Np, 3) - includes both ASU and mate atoms - 'water_true': (Nw, 3) - None if water_ratio is used + 'water_true': (Nw, 3) - None if water_ratio/water_count is used 'water_pred': (Nw, 3) final prediction 'pdb_id': PDB identifier """ @@ -800,19 +1021,26 @@ def euler_integrate( x1_true = g["water"].pos.clone() batch_w_true = g["water"].batch.clone() - if water_ratio is not None: + if water_count is not None: + # sample fixed number of waters per protein + x, batch_w = self._setup_water_nodes_from_count(g, water_count, device) + num_graphs = g["protein"].num_residues.size(0) + elif water_ratio is not None: # sample waters based on residue count x, batch_w = self._setup_water_nodes_from_ratio(g, water_ratio, device) num_graphs = g["protein"].num_residues.size(0) else: # use existing water nodes batch_w = g["water"].batch - num_graphs = int(batch_w.max().item()) + 1 - sigma_per_graph = self.compute_sigma_per_graph(g, device) - sigma_per_water = sigma_per_graph[batch_w] - x = torch.randn( - g["water"].num_nodes, 3, device=device - ) * sigma_per_water.unsqueeze(-1) + num_graphs = len(graphs) + num_waters = scatter( + torch.ones(batch_w.size(0), device=device, dtype=torch.long), + batch_w, + dim=0, + dim_size=num_graphs, + reduce="sum", + ) + x, batch_w = self._sample_waters(g, num_waters, device) x1_pred_ema = x.clone() @@ -872,6 +1100,7 @@ def rk4_integrate( device: str | torch.device = "cuda", return_trajectory: bool = True, water_ratio: float | None = None, + water_count: int | None = None, ) -> list[dict[str, np.ndarray]]: """ RK4 integration from noise to final positions. @@ -885,11 +1114,12 @@ def rk4_integrate( return_trajectory: Whether to return full trajectory and metrics water_ratio: If provided, sample num_residues * water_ratio waters instead of using ground truth water count + water_count: If provided, sample exactly this many waters per protein Returns: List of dicts, one per input graph, each with keys: 'protein_pos': (Np, 3) - includes both ASU and mate atoms - 'water_true': (Nw, 3) - None if water_ratio is used + 'water_true': (Nw, 3) - None if water_ratio/water_count is used 'water_pred': (Nw, 3) final prediction 'trajectory': list of (Nw, 3) at each step (if return_trajectory=True) """ @@ -912,17 +1142,26 @@ def rk4_integrate( x1_true = g["water"].pos.clone() batch_w_true = g["water"].batch.clone() - if water_ratio is not None: + if water_count is not None: + # sample fixed number of waters per protein + x, batch_w = self._setup_water_nodes_from_count(g, water_count, device) + num_graphs = g["protein"].num_residues.size(0) + elif water_ratio is not None: # sample waters based on residue count x, batch_w = self._setup_water_nodes_from_ratio(g, water_ratio, device) num_graphs = g["protein"].num_residues.size(0) else: # use existing water nodes batch_w = g["water"].batch - num_graphs = int(batch_w.max().item()) + 1 - sigma_per_graph = self.compute_sigma_per_graph(g, device) - sigma_per_water = sigma_per_graph[batch_w] - x = torch.randn_like(x1_true) * sigma_per_water.unsqueeze(-1) + num_graphs = len(graphs) + num_waters = scatter( + torch.ones(batch_w.size(0), device=device, dtype=torch.long), + batch_w, + dim=0, + dim_size=num_graphs, + reduce="sum", + ) + x, batch_w = self._sample_waters(g, num_waters, device) x1_pred_ema = x.clone() diff --git a/tests/test_flow.py b/tests/test_flow.py index a3a519c..312870a 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -16,6 +16,8 @@ FlowMatcher, FlowWaterGVP, ProteinWaterUpdate, + sample_waters_scaled_gaussian, + sample_waters_uniform_ball, ) from src.gvp_encoder import GVPEncoder, make_gvp_encoder_data, ProteinGVPEncoder @@ -400,6 +402,23 @@ def test_validation_step(self, flow_matcher, simple_hetero_data): assert "rmsd" in result assert result["loss"] >= 0 + def test_scaled_gaussian_auto_policy_enables_knn_fallback( + self, device, gvp_encoder + ): + model = FlowWaterGVP( + encoder=gvp_encoder, + hidden_dims=(64, 8), + layers=1, + ).to(device) + + flow_matcher = FlowMatcher( + model, + sampling_strategy="scaled_gaussian", + dynamic_edge_policy="auto", + ) + + assert flow_matcher._effective_dynamic_edge_policy() == "knn_if_isolated" + @pytest.mark.slow def test_euler_integrate(self, flow_matcher, simple_hetero_data, device): results = flow_matcher.euler_integrate( @@ -452,6 +471,108 @@ def test_sample_rk4(self, flow_matcher, simple_hetero_data, device): assert water_pred.shape == (n_water, 3) +# ============== Tests for water sampling strategies ============== + + +@pytest.mark.unit +class TestUniformBallSampling: + def test_shapes_and_counts(self, device): + torch.manual_seed(0) + protein_pos = torch.tensor( + [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], + [0.0, 0.0, 0.0], [0.5, 0.0, 0.0]], + device=device, + ) + batch_p = torch.tensor([0, 0, 1, 1], dtype=torch.long, device=device) + num_waters = torch.tensor([4, 3], dtype=torch.long, device=device) + + pos, batch_w = sample_waters_uniform_ball( + protein_pos=protein_pos, batch_p=batch_p, + num_waters=num_waters, cutoff=2.0, device=device, + ) + + assert pos.shape == (7, 3) + assert batch_w.shape == (7,) + assert (batch_w == 0).sum().item() == 4 + assert (batch_w == 1).sum().item() == 3 + + def test_all_within_cutoff(self, device): + torch.manual_seed(42) + protein_pos = torch.randn(20, 3, device=device) * 50 + batch_p = torch.cat([torch.zeros(10), torch.ones(10)]).long().to(device) + num_waters = torch.tensor([50, 50], dtype=torch.long, device=device) + cutoff = 8.0 + + pos, batch_w = sample_waters_uniform_ball( + protein_pos=protein_pos, batch_p=batch_p, + num_waters=num_waters, cutoff=cutoff, device=device, + ) + + for g in range(2): + g_waters = pos[batch_w == g] + g_protein = protein_pos[batch_p == g] + dists = torch.cdist(g_waters, g_protein) + assert dists.min(dim=1).values.max().item() <= cutoff + 1e-5 + + def test_empty_waters(self, device): + protein_pos = torch.randn(5, 3, device=device) + batch_p = torch.zeros(5, dtype=torch.long, device=device) + num_waters = torch.tensor([0], dtype=torch.long, device=device) + + pos, batch_w = sample_waters_uniform_ball( + protein_pos=protein_pos, batch_p=batch_p, + num_waters=num_waters, cutoff=8.0, device=device, + ) + + assert pos.shape == (0, 3) + assert batch_w.shape == (0,) + + def test_large_spread_protein_succeeds(self, device): + """The scenario that crashes truncated Gaussian (sigma~50) works here.""" + torch.manual_seed(0) + protein_pos = torch.randn(500, 3, device=device) * 50 + batch_p = torch.zeros(500, dtype=torch.long, device=device) + num_waters = torch.tensor([301], dtype=torch.long, device=device) + + pos, batch_w = sample_waters_uniform_ball( + protein_pos=protein_pos, batch_p=batch_p, + num_waters=num_waters, cutoff=8.0, device=device, + ) + + assert pos.shape == (301, 3) + assert batch_w.shape == (301,) + + +@pytest.mark.unit +class TestScaledGaussianSampling: + def test_shapes_and_counts(self, device): + torch.manual_seed(0) + num_waters = torch.tensor([4, 3], dtype=torch.long, device=device) + sigma = torch.tensor([1.0, 2.0], device=device) + + pos, batch_w = sample_waters_scaled_gaussian( + num_waters=num_waters, sigma_per_graph=sigma, + device=device, dtype=torch.float32, + ) + + assert pos.shape == (7, 3) + assert batch_w.shape == (7,) + assert (batch_w == 0).sum().item() == 4 + assert (batch_w == 1).sum().item() == 3 + + def test_empty_waters(self, device): + num_waters = torch.tensor([0], dtype=torch.long, device=device) + sigma = torch.tensor([1.0], device=device) + + pos, batch_w = sample_waters_scaled_gaussian( + num_waters=num_waters, sigma_per_graph=sigma, + device=device, dtype=torch.float32, + ) + + assert pos.shape == (0, 3) + assert batch_w.shape == (0,) + + # ============== Tests for distortion ============== From 6fff7113a17ad0c73543ecac4cef2ba962fe6575 Mon Sep 17 00:00:00 2001 From: vratins Date: Wed, 17 Jun 2026 06:15:15 +0000 Subject: [PATCH 2/6] more tests --- tests/test_flow.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_flow.py b/tests/test_flow.py index 312870a..4ca4ea7 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -542,6 +542,55 @@ def test_large_spread_protein_succeeds(self, device): assert pos.shape == (301, 3) assert batch_w.shape == (301,) + def test_real_structure_cutoff_and_batch(self, device, pdb_6eey): + """Cutoff guarantee holds on real protein geometry; batch indexing is correct + when two structures with different water counts are packed into one call.""" + import biotite.structure as bts + from biotite.structure.io.pdb import PDBFile, get_structure + + pdb_file = PDBFile.read(pdb_6eey) + atoms = get_structure(pdb_file, model=1, altloc="occupancy") + atoms = atoms[atoms.element != "H"] + protein_atoms = atoms[bts.filter_amino_acids(atoms)] + + protein_pos_np = protein_atoms.coord # (N, 3) float64 + n_atoms = len(protein_pos_np) + + # batch two copies: graph 0 gets 50 waters, graph 1 gets 30 + protein_pos = torch.tensor(protein_pos_np, dtype=torch.float32, device=device) + protein_pos_both = torch.cat([protein_pos, protein_pos], dim=0) + batch_p = torch.cat([ + torch.zeros(n_atoms, dtype=torch.long, device=device), + torch.ones(n_atoms, dtype=torch.long, device=device), + ]) + num_waters = torch.tensor([50, 30], dtype=torch.long, device=device) + cutoff = 8.0 + + pos, batch_w = sample_waters_uniform_ball( + protein_pos=protein_pos_both, + batch_p=batch_p, + num_waters=num_waters, + cutoff=cutoff, + device=device, + ) + + # correct total count and per-graph split + assert pos.shape == (80, 3) + assert batch_w.shape == (80,) + assert (batch_w == 0).sum().item() == 50 + assert (batch_w == 1).sum().item() == 30 + + # every water must be within cutoff of at least one protein atom in its graph + for g, n_w in enumerate(num_waters.tolist()): + g_waters = pos[batch_w == g] # (n_w, 3) + g_protein = protein_pos_both[batch_p == g] # (n_atoms, 3) + dists = torch.cdist(g_waters, g_protein) # (n_w, n_atoms) + min_dists = dists.min(dim=1).values # (n_w,) + assert min_dists.max().item() <= cutoff + 1e-4, ( + f"Graph {g}: water too far from protein " + f"(max dist {min_dists.max().item():.4f} > {cutoff})" + ) + @pytest.mark.unit class TestScaledGaussianSampling: From 056283ce218645b62af486bc40ff8390bd498c93 Mon Sep 17 00:00:00 2001 From: vratins Date: Wed, 17 Jun 2026 07:11:54 +0000 Subject: [PATCH 3/6] linting fixes --- src/flow.py | 5 +---- tests/test_flow.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/flow.py b/src/flow.py index 7c6511b..f4fa9b6 100644 --- a/src/flow.py +++ b/src/flow.py @@ -26,7 +26,6 @@ from src.utils import ot_coupling - def sample_waters_uniform_ball( protein_pos: Tensor, batch_p: Tensor, @@ -721,7 +720,7 @@ def training_step( batch: HeteroData, use_self_conditioning: bool = True, accumulation_steps: int = 1, - ) -> dict[str, float | int | None]: + ) -> dict[str, object]: """ Single flow matching training step (forward + backward only). @@ -757,7 +756,6 @@ def training_step( x1 = batch["water"].pos batch_w = batch["water"].batch - batch_p = batch["protein"].batch num_graphs = self._num_graphs(batch) sigma_per_graph = self.compute_sigma_per_graph(batch, device) @@ -854,7 +852,6 @@ def validation_step(self, batch: HeteroData) -> dict[str, float]: x1 = batch["water"].pos batch_w = batch["water"].batch - batch_p = batch["protein"].batch num_graphs = self._num_graphs(batch) num_w_per_graph = scatter( diff --git a/tests/test_flow.py b/tests/test_flow.py index 4ca4ea7..9f4f1cb 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -546,7 +546,7 @@ def test_real_structure_cutoff_and_batch(self, device, pdb_6eey): """Cutoff guarantee holds on real protein geometry; batch indexing is correct when two structures with different water counts are packed into one call.""" import biotite.structure as bts - from biotite.structure.io.pdb import PDBFile, get_structure + from biotite.structure.io.pdb import get_structure, PDBFile pdb_file = PDBFile.read(pdb_6eey) atoms = get_structure(pdb_file, model=1, altloc="occupancy") From 6317df5b5fd9d4bee93dce90f17a650da6d2a9d1 Mon Sep 17 00:00:00 2001 From: vratins <114123331+vratins@users.noreply.github.com> Date: Wed, 17 Jun 2026 07:13:00 +0000 Subject: [PATCH 4/6] Auto-commit ruff fixes [skip ci] --- src/flow.py | 6 ++--- tests/test_flow.py | 61 +++++++++++++++++++++++++++++----------------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/flow.py b/src/flow.py index f4fa9b6..0802028 100644 --- a/src/flow.py +++ b/src/flow.py @@ -90,9 +90,9 @@ def sample_waters_uniform_ball( direction = direction / direction.norm(dim=-1, keepdim=True).clamp(min=1e-12) # uniform radius inside the ball: r = R * U^(1/3) - r = cutoff * torch.rand(total_waters, 1, device=device, dtype=protein_pos.dtype).pow( - 1.0 / 3.0 - ) + r = cutoff * torch.rand( + total_waters, 1, device=device, dtype=protein_pos.dtype + ).pow(1.0 / 3.0) return anchors + r * direction, batch_w diff --git a/tests/test_flow.py b/tests/test_flow.py index 9f4f1cb..913f829 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -479,16 +479,18 @@ class TestUniformBallSampling: def test_shapes_and_counts(self, device): torch.manual_seed(0) protein_pos = torch.tensor( - [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], - [0.0, 0.0, 0.0], [0.5, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.0, 0.0], [0.5, 0.0, 0.0]], device=device, ) batch_p = torch.tensor([0, 0, 1, 1], dtype=torch.long, device=device) num_waters = torch.tensor([4, 3], dtype=torch.long, device=device) pos, batch_w = sample_waters_uniform_ball( - protein_pos=protein_pos, batch_p=batch_p, - num_waters=num_waters, cutoff=2.0, device=device, + protein_pos=protein_pos, + batch_p=batch_p, + num_waters=num_waters, + cutoff=2.0, + device=device, ) assert pos.shape == (7, 3) @@ -504,8 +506,11 @@ def test_all_within_cutoff(self, device): cutoff = 8.0 pos, batch_w = sample_waters_uniform_ball( - protein_pos=protein_pos, batch_p=batch_p, - num_waters=num_waters, cutoff=cutoff, device=device, + protein_pos=protein_pos, + batch_p=batch_p, + num_waters=num_waters, + cutoff=cutoff, + device=device, ) for g in range(2): @@ -520,8 +525,11 @@ def test_empty_waters(self, device): num_waters = torch.tensor([0], dtype=torch.long, device=device) pos, batch_w = sample_waters_uniform_ball( - protein_pos=protein_pos, batch_p=batch_p, - num_waters=num_waters, cutoff=8.0, device=device, + protein_pos=protein_pos, + batch_p=batch_p, + num_waters=num_waters, + cutoff=8.0, + device=device, ) assert pos.shape == (0, 3) @@ -535,8 +543,11 @@ def test_large_spread_protein_succeeds(self, device): num_waters = torch.tensor([301], dtype=torch.long, device=device) pos, batch_w = sample_waters_uniform_ball( - protein_pos=protein_pos, batch_p=batch_p, - num_waters=num_waters, cutoff=8.0, device=device, + protein_pos=protein_pos, + batch_p=batch_p, + num_waters=num_waters, + cutoff=8.0, + device=device, ) assert pos.shape == (301, 3) @@ -559,10 +570,12 @@ def test_real_structure_cutoff_and_batch(self, device, pdb_6eey): # batch two copies: graph 0 gets 50 waters, graph 1 gets 30 protein_pos = torch.tensor(protein_pos_np, dtype=torch.float32, device=device) protein_pos_both = torch.cat([protein_pos, protein_pos], dim=0) - batch_p = torch.cat([ - torch.zeros(n_atoms, dtype=torch.long, device=device), - torch.ones(n_atoms, dtype=torch.long, device=device), - ]) + batch_p = torch.cat( + [ + torch.zeros(n_atoms, dtype=torch.long, device=device), + torch.ones(n_atoms, dtype=torch.long, device=device), + ] + ) num_waters = torch.tensor([50, 30], dtype=torch.long, device=device) cutoff = 8.0 @@ -582,10 +595,10 @@ def test_real_structure_cutoff_and_batch(self, device, pdb_6eey): # every water must be within cutoff of at least one protein atom in its graph for g, n_w in enumerate(num_waters.tolist()): - g_waters = pos[batch_w == g] # (n_w, 3) - g_protein = protein_pos_both[batch_p == g] # (n_atoms, 3) - dists = torch.cdist(g_waters, g_protein) # (n_w, n_atoms) - min_dists = dists.min(dim=1).values # (n_w,) + g_waters = pos[batch_w == g] # (n_w, 3) + g_protein = protein_pos_both[batch_p == g] # (n_atoms, 3) + dists = torch.cdist(g_waters, g_protein) # (n_w, n_atoms) + min_dists = dists.min(dim=1).values # (n_w,) assert min_dists.max().item() <= cutoff + 1e-4, ( f"Graph {g}: water too far from protein " f"(max dist {min_dists.max().item():.4f} > {cutoff})" @@ -600,8 +613,10 @@ def test_shapes_and_counts(self, device): sigma = torch.tensor([1.0, 2.0], device=device) pos, batch_w = sample_waters_scaled_gaussian( - num_waters=num_waters, sigma_per_graph=sigma, - device=device, dtype=torch.float32, + num_waters=num_waters, + sigma_per_graph=sigma, + device=device, + dtype=torch.float32, ) assert pos.shape == (7, 3) @@ -614,8 +629,10 @@ def test_empty_waters(self, device): sigma = torch.tensor([1.0], device=device) pos, batch_w = sample_waters_scaled_gaussian( - num_waters=num_waters, sigma_per_graph=sigma, - device=device, dtype=torch.float32, + num_waters=num_waters, + sigma_per_graph=sigma, + device=device, + dtype=torch.float32, ) assert pos.shape == (0, 3) From 15dbee847eb71c993f0be9ed5baec13c33b25657 Mon Sep 17 00:00:00 2001 From: vratins Date: Wed, 24 Jun 2026 23:22:01 +0000 Subject: [PATCH 5/6] addressing comments --- src/flow.py | 32 ++++++++++++++++++++++++++++---- tests/test_flow.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/src/flow.py b/src/flow.py index 0802028..02ca153 100644 --- a/src/flow.py +++ b/src/flow.py @@ -76,6 +76,19 @@ def sample_waters_uniform_ball( dim_size=num_graphs, reduce="sum", ) + + # fail fast: a graph that requests waters must have at least one protein atom. + # Otherwise graph_sizes is 0 and graph_offsets + local_idx would index into a + # neighbouring graph's atoms (or out of bounds) when picking anchors below. + num_waters_dev = num_waters.to(device) + empty_protein = (num_waters_dev > 0) & (num_p_per_graph == 0) + if empty_protein.any(): + bad = torch.nonzero(empty_protein, as_tuple=False).flatten().tolist() + raise ValueError( + f"Cannot sample waters for graph(s) {bad}: requested " + f"{num_waters_dev[bad].tolist()} water(s) but they have zero protein atoms." + ) + offsets = torch.zeros(num_graphs + 1, dtype=torch.long, device=device) offsets[1:] = num_p_per_graph.cumsum(dim=0) @@ -585,8 +598,8 @@ class FlowMatcher: High level class for flow matching training, validation, and numerical integration """ - SAMPLING_STRATEGIES = {"uniform_ball", "scaled_gaussian"} - DYNAMIC_EDGE_POLICIES = {"auto", "radius", "knn_if_isolated"} + SAMPLING_STRATEGIES = ("uniform_ball", "scaled_gaussian") + DYNAMIC_EDGE_POLICIES = ("auto", "radius", "knn_if_isolated") def __init__( self, @@ -943,6 +956,9 @@ def _setup_water_nodes_from_count( x: (N_water_total, 3) initial noise positions batch_w: (N_water_total,) batch indices """ + if water_count < 0: + raise ValueError(f"water_count must be >= 0, got {water_count}") + num_residues = g["protein"].num_residues # (num_graphs,) num_graphs = num_residues.size(0) @@ -995,7 +1011,8 @@ def euler_integrate( Returns: List of dicts, one per input graph, each with keys: 'protein_pos': (Np, 3) - includes both ASU and mate atoms - 'water_true': (Nw, 3) - None if water_ratio/water_count is used + 'water_true': (Nw, 3) ground-truth waters (always returned; when + water_ratio/water_count is set its count may differ from water_pred) 'water_pred': (Nw, 3) final prediction 'pdb_id': PDB identifier """ @@ -1038,6 +1055,9 @@ def euler_integrate( reduce="sum", ) x, batch_w = self._sample_waters(g, num_waters, device) + # keep the graph's water batch in sync with the resampled layout so + # the model expands t against the correct per-water graph indices + g["water"].batch = batch_w x1_pred_ema = x.clone() @@ -1116,7 +1136,8 @@ def rk4_integrate( Returns: List of dicts, one per input graph, each with keys: 'protein_pos': (Np, 3) - includes both ASU and mate atoms - 'water_true': (Nw, 3) - None if water_ratio/water_count is used + 'water_true': (Nw, 3) ground-truth waters (always returned; when + water_ratio/water_count is set its count may differ from water_pred) 'water_pred': (Nw, 3) final prediction 'trajectory': list of (Nw, 3) at each step (if return_trajectory=True) """ @@ -1159,6 +1180,9 @@ def rk4_integrate( reduce="sum", ) x, batch_w = self._sample_waters(g, num_waters, device) + # keep the graph's water batch in sync with the resampled layout so + # the model expands t against the correct per-water graph indices + g["water"].batch = batch_w x1_pred_ema = x.clone() diff --git a/tests/test_flow.py b/tests/test_flow.py index 913f829..1a84666 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -535,6 +535,22 @@ def test_empty_waters(self, device): assert pos.shape == (0, 3) assert batch_w.shape == (0,) + def test_zero_protein_graph_raises(self, device): + """Requesting waters for a graph with no protein atoms fails fast.""" + # graph 0 has protein atoms, graph 1 has none but requests waters + protein_pos = torch.randn(5, 3, device=device) + batch_p = torch.zeros(5, dtype=torch.long, device=device) + num_waters = torch.tensor([3, 4], dtype=torch.long, device=device) + + with pytest.raises(ValueError, match="zero protein atoms"): + sample_waters_uniform_ball( + protein_pos=protein_pos, + batch_p=batch_p, + num_waters=num_waters, + cutoff=8.0, + device=device, + ) + def test_large_spread_protein_succeeds(self, device): """The scenario that crashes truncated Gaussian (sigma~50) works here.""" torch.manual_seed(0) @@ -553,6 +569,7 @@ def test_large_spread_protein_succeeds(self, device): assert pos.shape == (301, 3) assert batch_w.shape == (301,) + @pytest.mark.slow def test_real_structure_cutoff_and_batch(self, device, pdb_6eey): """Cutoff guarantee holds on real protein geometry; batch indexing is correct when two structures with different water counts are packed into one call.""" @@ -639,6 +656,17 @@ def test_empty_waters(self, device): assert batch_w.shape == (0,) +@pytest.mark.unit +class TestWaterCountValidation: + def test_negative_water_count_raises(self, device): + """A negative water_count is rejected before any sampling work.""" + fm = FlowMatcher(model=Mock(cutoff=8.0)) + g = HeteroData() # guard fires before touching graph contents + + with pytest.raises(ValueError, match="water_count must be >= 0"): + fm._setup_water_nodes_from_count(g, -1, device) + + # ============== Tests for distortion ============== From 350bba9dc1bd66979a90afe5866d1a3b6b4b1e94 Mon Sep 17 00:00:00 2001 From: vratins Date: Wed, 24 Jun 2026 23:29:05 +0000 Subject: [PATCH 6/6] addressing comment for batch index offset --- src/flow.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/flow.py b/src/flow.py index 02ca153..d1ae024 100644 --- a/src/flow.py +++ b/src/flow.py @@ -68,10 +68,16 @@ def sample_waters_uniform_ball( torch.arange(num_graphs, device=device), num_waters.to(device) ) + # offsets below assume protein atoms are grouped contiguously by graph; + # interleaved batch_p would pick anchors from the wrong graph. + batch_p = batch_p.to(device) + if batch_p.numel() > 1 and (batch_p[1:] < batch_p[:-1]).any(): + raise ValueError("batch_p must be sorted (non-decreasing) by graph index.") + # per-graph protein atom counts and cumulative offsets num_p_per_graph = scatter( torch.ones(batch_p.size(0), device=device, dtype=torch.long), - batch_p.to(device), + batch_p, dim=0, dim_size=num_graphs, reduce="sum",