SuperWater

SuperWater predicts water-molecule positions on protein surfaces using a score-based diffusion model with equivariant neural networks (e3nn): random water "particles" are moved onto hydration sites by reverse diffusion, a confidence model scores them, and a clustering step produces the final waters.
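
The last two stages described above (confidence filtering, then clustering) can be pictured with a small toy sketch. This is not SuperWater's implementation: the function name `cluster_centroids`, the greedy radius-based merge, and the `radius` value are assumptions made purely for illustration.

```python
import math


def cluster_centroids(points, scores, cutoff=0.1, radius=1.0):
    """Toy stand-in for 'score, then cluster' (hypothetical, not the real code).

    points: list of (x, y, z) sampled positions
    scores: one confidence value per point
    """
    # Keep only positions at or above the cutoff, highest score first.
    ranked = sorted(
        ((s, p) for p, s in zip(points, scores) if s >= cutoff),
        key=lambda sp: sp[0],
        reverse=True,
    )
    clusters = []  # list of (seed_point, member_points)
    for _, p in ranked:
        for seed, members in clusters:
            if math.dist(seed, p) <= radius:
                members.append(p)  # join the first cluster whose seed is close
                break
        else:
            clusters.append((p, [p]))  # no seed nearby: start a new cluster
    # Report one centroid (coordinate-wise mean) per cluster.
    return [
        tuple(sum(axis) / len(members) for axis in zip(*members))
        for _, members in clusters
    ]
```

Raising `cutoff` discards more low-scoring positions before clustering, which matches the "higher = stricter" behaviour of `confidence_cutoff` in the prediction config.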

Paper: Communications Chemistry (article) · bioRxiv preprint · Contact: xiaohan.kuang@takeda.com, zhaoqian.su@takeda.com

Detailed running reference: docs/RUNNING.md documents every runnable workflow end to end — setup, data prep, score training, score inference, confidence setup/training, and full-pipeline inference — with the exact commands, flags, and outputs.

Installation

Requires an NVIDIA GPU (CPU is not supported).

Recommended: uv (reproducible, locked)

For a CUDA 12.6 driver, one command resolves the full stack from uv.lock (PyTorch 2.8 + the matching PyTorch Geometric extension wheels, e3nn, rdkit, fair-esm, …) and installs superwater itself:

uv sync --extra cu126          # add --extra dev for pytest

Then run any command through the locked environment, e.g.:

uv run superwater-predict --config examples/configs/predict_5srf.yaml
uv run python -m pytest          # add --extra dev first

The PyTorch / PyG indexes and pins live in pyproject.toml ([tool.uv]). torch-geometric is pinned to 2.6.1 on purpose — 2.7 made BaseTransform.forward abstract, which breaks the dataset's NoiseTransform.

Different CUDA version? Edit the two [[tool.uv.index]] URLs in pyproject.toml (download.pytorch.org/whl/<cuXXX> and data.pyg.org/whl/torch-<ver>+<cuXXX>.html) to a matching pair, then re-run uv lock && uv sync --extra cu126. Or use the conda path.

Alternative: conda (CUDA 11.8)

bash scripts/install.sh
conda activate superwater

This creates the superwater conda env (PyTorch 2.5.1+cu118, e3nn, rdkit), installs the PyTorch Geometric CUDA wheels, and installs the package in editable mode (pip install -e .). Equivalent manual steps:

conda env create -f environment.yml
conda activate superwater
pip install -r requirements-pyg-cu118.txt
pip install -e .

Check the GPU with python scripts/check_gpu.py (or uv run python scripts/check_gpu.py).

Quick start

Pretrained models ship under models/ and an example input folder is bundled. Predict waters for every structure in a folder with one command:

superwater-predict --config examples/configs/predict_5srf.yaml

Output is concise by default; add --verbose for detailed progress and paths, or --debug for library warnings and full tracebacks (or set runtime.verbosity: quiet|normal|verbose|debug in the config).

The first run downloads the ESM-2 model (~2.5 GB) to ~/.cache/torch; embeddings are generated in-process. Put one or many .pdb/.cif/.mmcif files in input.structure_dir to run a batch — CIF/mmCIF inputs are converted automatically and unsupported files are skipped. Outputs for each input <name> go to outputs/predictions/<name>/:

  • <name>_centroid.pdb (or .cif) — final predicted waters. With include_protein: true (the example default) it also contains the input protein, with waters added as HOH on a separate chain.
  • <name>_centroid.txt — final water coordinates (xyz).
  • <name>_filtered.txt — all sampled positions + scores, written only when save_filtered: true (off by default).
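
For downstream scripts, the predicted waters can be pulled back out of a PDB-format `<name>_centroid.pdb` with the standard fixed-column layout. The sketch below is a minimal example that assumes the waters are written as ATOM or HETATM records with residue name HOH; the helper names are hypothetical, and `format_hoh_line` exists only to build demo input.

```python
def format_hoh_line(serial, chain, resseq, x, y, z):
    """Build one fixed-column PDB HETATM record for a water oxygen (demo helper)."""
    return "".join([
        "HETATM",             # columns 1-6: record name
        f"{serial:5d}",       # columns 7-11: atom serial
        " ",                  # column 12
        " O  ",               # columns 13-16: atom name
        " ",                  # column 17: altLoc
        "HOH",                # columns 18-20: residue name
        " ",                  # column 21
        f"{chain:1.1s}",      # column 22: chain ID
        f"{resseq:4d}",       # columns 23-26: residue number
        "    ",               # column 27 (iCode) + columns 28-30
        f"{x:8.3f}{y:8.3f}{z:8.3f}",  # columns 31-54: coordinates
        f"{1.0:6.2f}{0.0:6.2f}",      # occupancy, B-factor
    ])


def read_waters(pdb_text):
    """Return [(x, y, z), ...] for every HOH atom record in PDB-format text."""
    waters = []
    for line in pdb_text.splitlines():
        is_atom = line.startswith(("ATOM  ", "HETATM"))
        if is_atom and len(line) >= 54 and line[17:20] == "HOH":
            waters.append(
                (float(line[30:38]), float(line[38:46]), float(line[46:54]))
            )
    return waters
```

With `include_protein: true`, protein atoms share the file; the residue-name check is what separates the waters from them.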

Config

examples/configs/predict_5srf.yaml:

input:
  structure_dir: examples/data/batch_structures  # folder of .pdb/.cif/.mmcif files
  name: null              # optional name override (single-file input only)

models:
  score_model_dir: models/water_score_res15
  confidence_model_dir: models/water_confidence_res15_sigmoid

output:
  output_dir: outputs/predictions
  overwrite: true         # re-run structures whose output already exists
  format: pdb             # output structure format: pdb or cif
  include_protein: true   # include the input protein with the predicted waters

runtime:
  device: cuda            # only cuda is supported (no CPU)
  seed: 42                # random seed for reproducibility
  cleanup_intermediates: false  # delete this run's per-structure work files after success
  keep_embeddings: true   # keep the reusable ESM embeddings when cleaning up
  keep_graph_cache: true  # keep the reusable PyG graph cache when cleaning up

prediction:
  water_ratio: 10         # waters sampled per residue (higher = more coverage, more memory)
  inference_steps: 20     # number of reverse-diffusion steps
  confidence_cutoff: 0.1  # keep-probability threshold ~[0.02, 0.5] (higher = stricter)
  batch_size: 1           # structures scored per forward pass
  save_structure: true    # write <name>_centroid.{pdb,cif}
  save_filtered: false    # also write <name>_filtered.txt (off by default)

The example default writes protein + predicted waters and does not write the filtered file. With cleanup_intermediates: true, per-run work files are removed after a successful prediction while the reusable ESM embeddings and graph cache are kept (unless keep_embeddings/keep_graph_cache are set to false).
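
When configs are generated by a script, a quick sanity check before launching a GPU job can save a failed run. The sketch below works on an already-parsed config (a plain dict) and only encodes constraints stated in the comments above; the function name, the fallback defaults, and the positive-integer checks are assumptions, not part of the superwater package.

```python
def validate_prediction_config(cfg):
    """Return a list of human-readable problems; an empty list means 'looks fine'."""
    problems = []

    # The README states that only CUDA is supported.
    if cfg.get("runtime", {}).get("device", "cuda") != "cuda":
        problems.append("runtime.device: only 'cuda' is supported")

    # Output structure format is documented as pdb or cif.
    if cfg.get("output", {}).get("format", "pdb") not in ("pdb", "cif"):
        problems.append("output.format: expected 'pdb' or 'cif'")

    pred = cfg.get("prediction", {})

    # confidence_cutoff is a keep-probability threshold, so it must lie in [0, 1].
    cutoff = pred.get("confidence_cutoff", 0.1)
    if not 0.0 <= cutoff <= 1.0:
        problems.append("prediction.confidence_cutoff: expected a value in [0, 1]")

    # Assumption: these counts are meant to be positive integers.
    for key in ("water_ratio", "inference_steps", "batch_size"):
        value = pred.get(key, 1)
        if not isinstance(value, int) or value < 1:
            problems.append(f"prediction.{key}: expected a positive integer")

    return problems
```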

Web app

python apps/webapp/app.py

Open http://localhost:8891/, go to Predict, upload one or more .pdb/.cif/.mmcif files, set the options (water ratio, inference steps, confidence cutoff, output format, overwrite, include protein, cleanup), and run. Results appear per structure with a water count and download links — per structure or all as a zip. Predictions run synchronously on the GPU, so the page returns once the whole batch finishes.

Retraining

Retraining is a research-grade workflow (there is no single wrapper script): generate ESM-2 embeddings, train the score model, then train the confidence model on water positions sampled from it. The two stages are python -m superwater.train and python -m superwater.confidence.train — run either with --help for the full argument list. The commands below reproduce the shipped water_score_res15 / water_confidence_res15_sigmoid checkpoints; add --wandb --wandb_entity <user> to log to Weights & Biases. For score-only evaluation (scripts/score_inference.py, benchmark_score_pr.py) and full-pipeline evaluation over a split (superwater-infer), see docs/RUNNING.md.

1. Prepare data

Download the dataset from Zenodo (waterbind, 17,092 complexes) and place it under data/<dataset>/, one folder per complex:

data/<dataset>/<PDB_ID>/
├── <PDB_ID>_protein_processed.{cif,pdb}
└── <PDB_ID>_water.{cif,pdb}

The data layer reads CIF in preference to PDB and falls back to the other format if one is missing or fails to parse — for both the raw inputs and the _protein_processed/_water files. CIF is preferred because legacy fixed-column PDB cannot represent the newer 5-character ligand CCD codes (e.g. A1ADA) without corrupting the file.
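
The CIF-first lookup with fallback can be sketched as follows. This is a minimal illustration of the behaviour described above, not the package's data layer; `resolve_structure` and its `parse` callback are hypothetical names.

```python
from pathlib import Path


def resolve_structure(complex_dir, stem, parse, order=(".cif", ".pdb")):
    """Return (path, parsed) for the first format that exists and parses.

    complex_dir: folder of one complex, e.g. data/<dataset>/<PDB_ID>
    stem:        file name without extension, e.g. "<PDB_ID>_water"
    parse:       caller-supplied callable(path) that raises on bad input
    """
    errors = []
    for ext in order:
        path = Path(complex_dir) / f"{stem}{ext}"
        if not path.is_file():
            errors.append(f"{path.name}: missing")
            continue
        try:
            return path, parse(path)
        except Exception as exc:  # parse failure: fall through to the other format
            errors.append(f"{path.name}: {exc}")
    raise FileNotFoundError("; ".join(errors))
```

Collecting the per-format errors means a failure message names both the missing file and the unparseable one, which helps when a complex is silently absent from training.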

The paper's train/val/test splits are in examples/data/splits/ (train_res15.txt, val_res15.txt, test_res15.txt) — each is a plain list of PDB IDs; supply your own to retrain on a different set.
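
To retrain on your own set, split files can be produced with a few lines of Python. The sketch assumes one PDB ID per line (consistent with "a plain list of PDB IDs"); the output file names, the fractions, and the purely random assignment are illustrative choices, not the procedure used for the paper's splits.

```python
import random
from pathlib import Path


def write_splits(ids, out_dir, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle IDs reproducibly and write train.txt / val.txt / test.txt."""
    ids = sorted(set(ids))              # de-duplicate, fixed starting order
    random.Random(seed).shuffle(ids)    # seeded, so reruns give the same split
    n_val = int(round(len(ids) * val_frac))
    n_test = int(round(len(ids) * test_frac))
    parts = {
        "val": ids[:n_val],
        "test": ids[n_val:n_val + n_test],
        "train": ids[n_val + n_test:],
    }
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, part in parts.items():
        (out / f"{name}.txt").write_text("\n".join(part) + "\n")
    return parts


def read_split(path):
    """Read a split file back into a list of IDs, skipping blank lines."""
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]
```

The resulting files can then be passed to `--split_train`, `--split_val`, and `--split_test`.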

Build a dataset from raw PDB-REDO files

To assemble the per-complex layout above from raw <id>_final.cif/<id>_final.pdb files (the PDB-REDO archive layout), use scripts/setup_custom_dataset.py. It splits each source into a protein-only _protein_processed file and a _water file (CIF by default), generates ESM embeddings, and writes normalized (lowercased, _final-stripped) split files to --split_out_dir — use those for training.

python scripts/setup_custom_dataset.py \
    --raw_data_dir <raw_dir> \
    --split_train <train.txt> --split_val <val.txt> --split_test <test.txt> \
    --out_dir data/<dataset> --split_out_dir data/<dataset>_splits \
    --embeddings_dir data/<dataset>_embeddings \
    --skip_existing --download_missing
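
The split normalization mentioned above (lowercasing and stripping `_final`) amounts to something like the sketch below. The function names are hypothetical, and dropping blank lines and duplicates is an extra assumption rather than documented behaviour of `setup_custom_dataset.py`.

```python
def normalize_id(raw):
    """Lowercase an ID and strip a trailing '_final' (PDB-REDO naming)."""
    stem = raw.strip().lower()
    if stem.endswith("_final"):
        stem = stem[: -len("_final")]
    return stem


def normalize_split(lines):
    """Normalize every line of a split file, keeping first-seen order."""
    seen, out = set(), []
    for line in lines:
        pdb_id = normalize_id(line)
        if pdb_id and pdb_id not in seen:  # assumption: skip blanks and repeats
            seen.add(pdb_id)
            out.append(pdb_id)
    return out
```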

Useful flags:

  • --skip_existing — incremental re-run: complexes that already have both output files are left untouched, and the embedding stage skips complexes that already have a _chain_0.pt. (It does not re-fix already-written-but-corrupt files.)
  • --download_missing — split ids absent from --raw_data_dir are fetched from PDB-REDO into a temp dir, processed into --out_dir, and the downloaded files are deleted.
  • --out_format {cif,pdb} — output format (default cif).
  • --skip_embeddings / --build_cache — skip the ESM stage, or also prebuild the PyG graph cache.

Per-water quality filters (ON by default) drop unreliable crystallographic waters before they become training targets — a water is removed if it fails any enabled filter:

  • --max_protein_dist (5.0 Å) — distance from the nearest protein atom; --no_filter_by_distance.
  • --min_edia (0.4) — minimum EDIAm electron-density score, read from the PDB-REDO <id>_final.json sidecar (no-op when absent); --no_filter_by_edia.
  • --max_bfactor_zscore (1.5) — per-structure water B-factor z-score; --no_filter_by_bfactor.
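
The three filters combine as "drop the water if any enabled check fails". The sketch below mirrors the documented defaults but is not the script's code: the data layout, the use of a population standard deviation, the one-sided z-score test (only unusually high B-factors are rejected), and letting waters without an EDIA entry pass are all assumptions.

```python
import math
import statistics


def filter_waters(waters, protein_atoms, edia=None,
                  max_protein_dist=5.0, min_edia=0.4, max_bfactor_zscore=1.5):
    """Return the waters that pass every enabled filter.

    waters:        list of dicts with keys "id", "xyz", "bfactor"
    protein_atoms: non-empty list of (x, y, z) protein atom coordinates
    edia:          optional {water_id: EDIAm score}; None disables the filter
    """
    if not protein_atoms:
        raise ValueError("protein_atoms must not be empty")
    bfactors = [w["bfactor"] for w in waters]
    mean = statistics.fmean(bfactors) if bfactors else 0.0
    std = statistics.pstdev(bfactors) if len(bfactors) > 1 else 0.0

    kept = []
    for w in waters:
        # 1. Distance to the nearest protein atom.
        nearest = min(math.dist(w["xyz"], atom) for atom in protein_atoms)
        if nearest > max_protein_dist:
            continue
        # 2. Electron-density support (no-op when no scores are available).
        if edia is not None and edia.get(w["id"], min_edia) < min_edia:
            continue
        # 3. Per-structure B-factor z-score.
        zscore = (w["bfactor"] - mean) / std if std > 0 else 0.0
        if zscore > max_bfactor_zscore:
            continue
        kept.append(w)
    return kept
```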

The waters_kept/total column in logs/prepared.tsv reports how many waters survived per complex. Because the filtering is baked into the written _water files, rebuild any existing data/cache/ graphs after re-running prep with different filter settings.

Run with --help for the full list (featurization flags, --num_workers, etc.).

2. Generate ESM-2 embeddings

Skip this if you already generated embeddings via setup_custom_dataset.py above.

superwater-embed --data_dir data/<dataset> --out_dir data/<dataset>_embeddings

Add --skip_existing to embed only complexes that don't yet have a _chain_0.pt in the output dir (the ESM model is loaded lazily, so a fully-cached re-run does no model load).
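
The lazy-load behaviour is a simple pattern: defer the expensive model load until the first complex that actually needs an embedding. The sketch below illustrates the idea with caller-supplied `load_model` and `embed` callables; these names and the `<name>_chain_0.pt` file naming are assumptions for the example, not the internals of `superwater-embed`.

```python
from pathlib import Path


def embed_missing(names, out_dir, load_model, embed, skip_existing=True):
    """Embed only complexes without an output file; load the model on demand.

    load_model: callable() returning the (expensive) model
    embed:      callable(model, name) returning the bytes to store
    Returns the list of names that were embedded in this call.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    model = None
    written = []
    for name in names:
        target = out / f"{name}_chain_0.pt"
        if skip_existing and target.exists():
            continue                      # already cached: nothing to do
        if model is None:
            model = load_model()          # first real work item: load once
        target.write_bytes(embed(model, name))
        written.append(name)
    return written
```

On a fully cached re-run the loop never reaches `load_model()`, so the call returns almost immediately.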

3. Train the score (diffusion) model

python -m superwater.train \
    --run_name water_score_res15_retrain \
    --data_dir data/<dataset> \
    --esm_embeddings_path data/<dataset>_embeddings \
    --split_train examples/data/splits/train_res15.txt \
    --split_val   examples/data/splits/val_res15.txt \
    --split_test  examples/data/splits/test_res15.txt \
    --log_dir models \
    --all_atoms --remove_hs --receptor_radius 15 --c_alpha_max_neighbors 24 \
    --ns 24 --nv 6 --num_conv_layers 3 \
    --distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 \
    --tr_sigma_min 0.1 --tr_sigma_max 30 --scale_by_sigma --dynamic_max_cross \
    --lr 1e-3 --batch_size 8 --n_epochs 300 \
    --scheduler plateau --scheduler_patience 30 --dropout 0.1 \
    --use_ema --cudnn_benchmark --test_sigma_intervals \
    --num_workers 10 --num_dataloader_workers 10 \
    --cache_scope dataset

Checkpoints are written to models/water_score_res15_retrain/ (best_model.pt, best_ema_model.pt, last_model.pt, model_parameters.yml), alongside losses_iter.csv (per-batch) and losses_epoch.csv (per-epoch train/val) for plotting. The dataset is preprocessed into a graph cache on the first run and reused afterwards; complexes that fail graph preprocessing are recorded in failed_complexes.txt in the cache directory (the only other trace is a missing <name>.pt).

Always pass --cache_scope dataset. It keys the graph cache by the dataset directory (one shared .pt per complex) instead of by the split-file name, so per-complex graphs are reused across any split that points at the same --data_dir — re-splitting train/val/test never triggers a rebuild, and it builds only the complexes it has not seen yet. The legacy split scope (the historical default) keys by split-file basename and silently drops complexes that are not already cached under that exact basename.
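
The difference between the two scopes comes down to what the cache key is derived from. The sketch below is a simplified model of that idea; the actual directory naming in SuperWater may differ, and `cache_dir` and `complexes_to_build` are hypothetical helpers.

```python
from pathlib import Path


def cache_dir(cache_root, data_dir, split_path, cache_scope="dataset"):
    """Pick the cache folder for a run under either keying scheme."""
    if cache_scope == "dataset":
        key = Path(data_dir).name        # same key for every split of this data
    elif cache_scope == "split":
        key = Path(split_path).stem      # legacy: one cache per split-file name
    else:
        raise ValueError(f"unknown cache_scope: {cache_scope!r}")
    return Path(cache_root) / key


def complexes_to_build(cache, split_ids):
    """IDs from the split that do not yet have a cached <name>.pt graph."""
    return [c for c in split_ids if not (Path(cache) / f"{c}.pt").exists()]
```

Under the dataset scope, train, validation, and test runs all resolve to the same folder, so a complex preprocessed for one split is found by the others.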

4. Train the confidence model

This samples water positions with the score model from step 3, caches them, and trains a classifier on each water's deviation from the true sites. The dataset/architecture flags must match the score model.

python -m superwater.confidence.train \
    --original_model_dir models/water_score_res15_retrain \
    --run_name water_confidence_res15_retrain \
    --data_dir data/<dataset> \
    --esm_embeddings_path data/<dataset>_embeddings \
    --split_train examples/data/splits/train_res15.txt \
    --split_val   examples/data/splits/val_res15.txt \
    --split_test  examples/data/splits/test_res15.txt \
    --log_dir models \
    --all_atoms --remove_hs \
    --ns 24 --nv 6 --num_conv_layers 3 --scale_by_sigma --dynamic_max_cross --dropout 0.1 \
    --inference_steps 20 --water_ratio 15 \
    --lr 1e-3 --batch_size 8 --n_epochs 50 \
    --running_mode train --mad_prediction \
    --cache_creation_id 1 --cache_ids_to_combine 1 \
    --cache_scope dataset

The first run is slow: it samples and caches positions for every training complex (under --cache_path, default data/cache_confidence); later runs reuse that cache. Lower --water_ratio (e.g. 10) and/or --batch_size if you hit GPU-memory limits.
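
As a rough mental model of the training targets in this step, each sampled position can be compared with its nearest true water site. The sketch below computes that deviation and a binary label; the 1.0 Å threshold and the binary labelling are assumptions for illustration only, and the real target construction (including what `--mad_prediction` changes) is defined by `superwater.confidence.train`.

```python
import math


def label_samples(sampled, true_sites, max_dev=1.0):
    """Return [(deviation, label), ...] for each sampled position.

    deviation: distance to the nearest true water site
    label:     1 if deviation <= max_dev (hypothetical threshold), else 0
    """
    if not true_sites:
        raise ValueError("true_sites must not be empty")
    labelled = []
    for pos in sampled:
        deviation = min(math.dist(pos, site) for site in true_sites)
        labelled.append((deviation, 1 if deviation <= max_dev else 0))
    return labelled
```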

5. Predict with the retrained models

Point a prediction config (see Quick start) at the new folders:

models:
  score_model_dir: models/water_score_res15_retrain
  confidence_model_dir: models/water_confidence_res15_retrain

Inference animation

[Animated figure: inference]

Citation

@software{kuang_2025_superwater,
  author    = {Kuang, Xiaohan and Su, Zhaoqian},
  title     = {SuperWater: Predicting Water Molecule Positions on Protein Structures by Generative AI},
  year      = {2025},
  version   = {v1.0.0},
  publisher = {Zenodo},
  doi       = {10.5281/zenodo.17465949}
}
