SuperWater predicts water-molecule positions on protein surfaces using a score-based diffusion model with equivariant neural networks (e3nn): random water "particles" are moved onto hydration sites by reverse diffusion, a confidence model scores them, and a clustering step produces the final waters.
Paper: Communications Chemistry (article) · bioRxiv preprint · Contact: xiaohan.kuang@takeda.com, zhaoqian.su@takeda.com
Detailed running reference: docs/RUNNING.md documents every runnable workflow end to end (setup, data prep, score training, score inference, confidence setup/training, and full-pipeline inference) with the exact commands, flags, and outputs.
Requires an NVIDIA GPU (CPU is not supported).
For a CUDA 12.6 driver, one command resolves the full stack from uv.lock (PyTorch
2.8 + the matching PyTorch Geometric extension wheels, e3nn, rdkit, fair-esm, …) and
installs superwater itself:
uv sync --extra cu126 # add --extra dev for pytest
Then run any command through the locked environment, e.g.:
uv run superwater-predict --config examples/configs/predict_5srf.yaml
uv run python -m pytest # add --extra dev first
The PyTorch / PyG indexes and pins live in pyproject.toml ([tool.uv]). torch-geometric
is pinned to 2.6.1 on purpose: 2.7 made BaseTransform.forward abstract, which breaks
the dataset's NoiseTransform.
Different CUDA version? Edit the two [[tool.uv.index]] URLs in pyproject.toml (download.pytorch.org/whl/<cuXXX> and data.pyg.org/whl/torch-<ver>+<cuXXX>.html) to a matching pair, then re-run uv lock && uv sync --extra cu126. Or use the conda path.
bash scripts/install.sh
conda activate superwater
This creates the superwater conda env (PyTorch 2.5.1+cu118, e3nn, rdkit),
installs the PyTorch Geometric CUDA wheels, and installs the package in editable mode
(pip install -e .). Equivalent manual steps:
conda env create -f environment.yml
conda activate superwater
pip install -r requirements-pyg-cu118.txt
pip install -e .
Check the GPU with python scripts/check_gpu.py (or uv run python scripts/check_gpu.py).
Pretrained models ship under models/ and an example input folder is bundled. Predict
waters for every structure in a folder with one command:
superwater-predict --config examples/configs/predict_5srf.yaml
Output is concise by default; add --verbose for detailed progress and paths, or
--debug for library warnings and full tracebacks (or set runtime.verbosity: quiet|normal|verbose|debug in the config).
The first run downloads the ESM-2 model (~2.5 GB) to ~/.cache/torch; embeddings are
generated in-process. Put one or many .pdb/.cif/.mmcif files in input.structure_dir
to run a batch — CIF/mmCIF inputs are converted automatically and unsupported files are
skipped. Outputs for each input <name> go to outputs/predictions/<name>/:
- <name>_centroid.pdb (or .cif): final predicted waters. With include_protein: true (the example default) it also contains the input protein, with waters added as HOH on a separate chain.
- <name>_centroid.txt: final water coordinates (xyz).
- <name>_filtered.txt: all sampled positions + scores, written only when save_filtered: true (off by default).
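To work with the predicted waters downstream, the HOH records can be pulled out of a <name>_centroid.pdb file with a few lines of Python. The following is a minimal sketch, not part of SuperWater: the function name read_predicted_waters is hypothetical, it relies only on the standard fixed-column PDB layout, and it accepts both ATOM and HETATM records because the source does not say which record type the waters are written with. It does not handle the cif output format.

```python
from typing import List, Tuple

Coord = Tuple[float, float, float]


def read_predicted_waters(pdb_text: str) -> List[Coord]:
    """Return (x, y, z) for every HOH atom record in PDB-format text.

    Hypothetical helper for illustration; relies on standard PDB columns only.
    """
    waters: List[Coord] = []
    for line in pdb_text.splitlines():
        # Record name occupies columns 1-6 of a PDB line.
        if line[:6].strip() not in ("ATOM", "HETATM"):
            continue
        # Residue name occupies columns 18-20.
        if line[17:20] != "HOH":
            continue
        # Orthogonal coordinates occupy columns 31-38, 39-46 and 47-54.
        x = float(line[30:38])
        y = float(line[38:46])
        z = float(line[46:54])
        waters.append((x, y, z))
    return waters
```

Calling it on the text of a centroid file (for example, the result of Path(...).read_text()) gives a list whose length is the predicted water count when include_protein is true or false alike, since protein residues are skipped by the HOH check.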
examples/configs/predict_5srf.yaml:
input:
  structure_dir: examples/data/batch_structures # folder of .pdb/.cif/.mmcif files
  name: null # optional name override (single-file input only)
models:
  score_model_dir: models/water_score_res15
  confidence_model_dir: models/water_confidence_res15_sigmoid
output:
  output_dir: outputs/predictions
  overwrite: true # re-run structures whose output already exists
  format: pdb # output structure format: pdb or cif
  include_protein: true # include the input protein with the predicted waters
runtime:
  device: cuda # only cuda is supported (no CPU)
  seed: 42 # random seed for reproducibility
  cleanup_intermediates: false # delete this run's per-structure work files after success
  keep_embeddings: true # keep the reusable ESM embeddings when cleaning up
  keep_graph_cache: true # keep the reusable PyG graph cache when cleaning up
prediction:
  water_ratio: 10 # waters sampled per residue (higher = more coverage, more memory)
  inference_steps: 20 # number of reverse-diffusion steps
  confidence_cutoff: 0.1 # keep-probability threshold ~[0.02, 0.5] (higher = stricter)
  batch_size: 1 # structures scored per forward pass
  save_structure: true # write <name>_centroid.{pdb,cif}
  save_filtered: false # also write <name>_filtered.txt (off by default)
The example default writes protein + predicted waters and does not write the
filtered file. With cleanup_intermediates: true, per-run work files are removed after a
successful prediction while the reusable ESM embeddings and graph cache are kept (unless
keep_embeddings/keep_graph_cache are set to false).
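Before launching a long batch it can help to sanity-check an edited config. The sketch below is not part of SuperWater: check_predict_config and sampled_waters are hypothetical helpers, and the rules they enforce are this sketch's reading of the comments in the example config (device must be cuda, format is pdb or cif, the cutoff is a probability, water_ratio is waters sampled per residue). It takes an already-parsed config as a plain dict, so it does not depend on any particular YAML loader.

```python
from typing import Any, Dict, List


def check_predict_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a parsed prediction config.

    Hypothetical checks, derived from the comments in the example config.
    """
    problems: List[str] = []
    runtime = cfg.get("runtime") or {}
    output = cfg.get("output") or {}
    pred = cfg.get("prediction") or {}

    if runtime.get("device") != "cuda":
        problems.append("runtime.device should be 'cuda' (CPU is not supported)")
    if output.get("format") not in ("pdb", "cif"):
        problems.append("output.format should be 'pdb' or 'cif'")

    cutoff = pred.get("confidence_cutoff")
    if not isinstance(cutoff, (int, float)) or not 0.0 < cutoff < 1.0:
        problems.append("prediction.confidence_cutoff should lie in (0, 1)")

    # Assumption: these three settings are positive integers.
    for key in ("water_ratio", "inference_steps", "batch_size"):
        value = pred.get(key)
        if not isinstance(value, int) or value < 1:
            problems.append(f"prediction.{key} should be a positive integer")
    return problems


def sampled_waters(n_residues: int, water_ratio: int) -> int:
    """Waters sampled for one structure: water_ratio waters per residue."""
    return n_residues * water_ratio
```

For a 300-residue protein with the example water_ratio of 10, sampled_waters gives 3000 candidate positions, which is why raising water_ratio increases memory use.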
python apps/webapp/app.py
Open http://localhost:8891/, go to Predict, upload one or more .pdb/.cif/.mmcif
files, set the options (water ratio, inference steps, confidence cutoff, output format,
overwrite, include protein, cleanup), and run. Results appear per structure with a water
count and download links, per structure or all as a zip. Predictions run synchronously on
the GPU, so the page returns once the whole batch finishes.
Retraining workflow (data prep, then score- and confidence-model training)
Retraining is a research-grade workflow (there is no single wrapper script): generate
ESM-2 embeddings, train the score model, then train the confidence model on water
positions sampled from it. The two stages are python -m superwater.train and
python -m superwater.confidence.train — run either with --help for the full argument
list. The commands below reproduce the shipped water_score_res15 /
water_confidence_res15_sigmoid checkpoints; add --wandb --wandb_entity <user> to log to
Weights & Biases. For score-only evaluation (scripts/score_inference.py,
benchmark_score_pr.py) and full-pipeline evaluation over a split (superwater-infer),
see docs/RUNNING.md.
Download the dataset from Zenodo (waterbind,
17,092 complexes) and place it under data/<dataset>/, one folder per complex:
data/<dataset>/<PDB_ID>/
├── <PDB_ID>_protein_processed.{cif,pdb}
└── <PDB_ID>_water.{cif,pdb}
The data layer reads CIF in preference to PDB and falls back to the other format if one is
missing or fails to parse — for both the raw inputs and the _protein_processed/_water
files. CIF is preferred because legacy fixed-column PDB cannot represent the newer
5-character ligand CCD codes (e.g. A1ADA) without corrupting the file.
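The CIF-first, fall-back-to-PDB behaviour described above can be sketched as follows. This is an illustration, not the project's data layer: load_with_fallback is a hypothetical name, and the parse argument stands in for whatever structure parser is used.

```python
from pathlib import Path
from typing import Callable, List, TypeVar

T = TypeVar("T")


def load_with_fallback(complex_dir: Path, stem: str,
                       parse: Callable[[Path], T]) -> T:
    """Try <stem>.cif first, then <stem>.pdb; skip missing or unparsable files.

    Hypothetical sketch of a CIF-preferred loader.
    """
    errors: List[str] = []
    for ext in (".cif", ".pdb"):  # CIF is preferred, PDB is the fallback
        path = complex_dir / f"{stem}{ext}"
        if not path.is_file():
            continue
        try:
            return parse(path)
        except Exception as exc:  # a parse failure triggers the fallback
            errors.append(f"{path.name}: {exc}")
    detail = "; ".join(errors) if errors else "no .cif or .pdb file found"
    raise FileNotFoundError(f"no readable structure for {stem}: {detail}")
```

The same call pattern would apply to both files of a complex, e.g. a stem of <PDB_ID>_protein_processed or <PDB_ID>_water.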
The paper's train/val/test splits are in examples/data/splits/ (train_res15.txt,
val_res15.txt, test_res15.txt) — each is a plain list of PDB IDs; supply your own to
retrain on a different set.
To assemble the per-complex layout above from raw <id>_final.cif/<id>_final.pdb files
(the PDB-REDO archive layout), use scripts/setup_custom_dataset.py. It splits each source
into a protein-only _protein_processed file and a _water file (CIF by default),
generates ESM embeddings, and writes normalized (lowercased, _final-stripped) split files
to --split_out_dir — use those for training.
python scripts/setup_custom_dataset.py \
--raw_data_dir <raw_dir> \
--split_train <train.txt> --split_val <val.txt> --split_test <test.txt> \
--out_dir data/<dataset> --split_out_dir data/<dataset>_splits \
--embeddings_dir data/<dataset>_embeddings \
--skip_existing --download_missing
Useful flags:
- --skip_existing: incremental re-run. Complexes that already have both output files are left untouched, and the embedding stage skips complexes that already have a _chain_0.pt. (It does not re-fix already-written-but-corrupt files.)
- --download_missing: split ids absent from --raw_data_dir are fetched from PDB-REDO into a temp dir, processed into --out_dir, and the downloaded files are deleted.
- --out_format {cif,pdb}: output format (default cif).
- --skip_embeddings / --build_cache: skip the ESM stage, or also prebuild the PyG graph cache.
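The split normalization mentioned above (lowercased, _final-stripped) amounts to a small string transform. The sketch below is an assumption about what that normalization looks like, not the script's actual code; normalize_split_id and normalize_split are hypothetical names, and any further rules the real script applies (for example deduplication) are not covered.

```python
from typing import Iterable, List


def normalize_split_id(raw: str) -> str:
    """Lowercase a split entry and drop a trailing '_final' suffix."""
    entry = raw.strip().lower()
    if entry.endswith("_final"):
        entry = entry[: -len("_final")]
    return entry


def normalize_split(lines: Iterable[str]) -> List[str]:
    """Normalize every non-blank line of a split file, preserving order."""
    return [normalize_split_id(line) for line in lines if line.strip()]
```

Applied to the lines of a split file that lists 1ABC_final and 2xyz, this yields 1abc and 2xyz, matching the per-complex folder names the training commands expect.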
Per-water quality filters (ON by default) drop unreliable crystallographic waters before they become training targets — a water is removed if it fails any enabled filter:
- --max_protein_dist (5.0 Å): distance from the nearest protein atom; disable with --no_filter_by_distance.
- --min_edia (0.4): minimum EDIAm electron-density score, read from the PDB-REDO <id>_final.json sidecar (no-op when absent); disable with --no_filter_by_edia.
- --max_bfactor_zscore (1.5): per-structure water B-factor z-score; disable with --no_filter_by_bfactor.
The logs/prepared.tsv waters_kept/total column reports how many survived per complex.
Because filtering bakes into the written _water files, rebuild any existing data/cache/
graphs after re-running prep with different filter settings.
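The three per-water filters can be pictured as one pass over a structure's waters, using the default thresholds quoted above. This is a sketch under stated assumptions rather than the script's implementation: the Water record and filter_waters are hypothetical, the z-score uses the population standard deviation over that structure's waters, only high B-factors are rejected, and thresholds are treated as inclusive.

```python
from dataclasses import dataclass
from statistics import mean, pstdev
from typing import List, Optional


@dataclass
class Water:
    protein_dist: float            # distance to the nearest protein atom (Å)
    bfactor: float                 # crystallographic B-factor
    edia: Optional[float] = None   # EDIAm score; None when no sidecar exists


def filter_waters(waters: List[Water],
                  max_protein_dist: float = 5.0,
                  min_edia: float = 0.4,
                  max_bfactor_zscore: float = 1.5) -> List[Water]:
    """Keep waters that pass every filter (hypothetical sketch)."""
    if not waters:
        return []
    bfactors = [w.bfactor for w in waters]
    b_mean = mean(bfactors)
    b_std = pstdev(bfactors)  # assumption: population std per structure

    kept: List[Water] = []
    for w in waters:
        if w.protein_dist > max_protein_dist:
            continue  # too far from the protein
        if w.edia is not None and w.edia < min_edia:
            continue  # weak electron-density support; skipped when EDIA absent
        z = (w.bfactor - b_mean) / b_std if b_std > 0 else 0.0
        if z > max_bfactor_zscore:
            continue  # unusually mobile or poorly ordered water
        kept.append(w)
    return kept
```

The ratio len(kept) / len(waters) corresponds to the waters_kept/total figure reported per complex.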
Run with --help for the full list (featurization flags, --num_workers, etc.).
Skip this if you already generated embeddings via setup_custom_dataset.py above.
superwater-embed --data_dir data/<dataset> --out_dir data/<dataset>_embeddings
Add --skip_existing to embed only complexes that don't yet have a _chain_0.pt in the
output dir (the ESM model is loaded lazily, so a fully-cached re-run does no model load).
python -m superwater.train \
--run_name water_score_res15_retrain \
--data_dir data/<dataset> \
--esm_embeddings_path data/<dataset>_embeddings \
--split_train examples/data/splits/train_res15.txt \
--split_val examples/data/splits/val_res15.txt \
--split_test examples/data/splits/test_res15.txt \
--log_dir models \
--all_atoms --remove_hs --receptor_radius 15 --c_alpha_max_neighbors 24 \
--ns 24 --nv 6 --num_conv_layers 3 \
--distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 \
--tr_sigma_min 0.1 --tr_sigma_max 30 --scale_by_sigma --dynamic_max_cross \
--lr 1e-3 --batch_size 8 --n_epochs 300 \
--scheduler plateau --scheduler_patience 30 --dropout 0.1 \
--use_ema --cudnn_benchmark --test_sigma_intervals \
--num_workers 10 --num_dataloader_workers 10 \
--cache_scope dataset
Checkpoints are written to models/water_score_res15_retrain/ (best_model.pt,
best_ema_model.pt, last_model.pt, model_parameters.yml), alongside losses_iter.csv
(per-batch) and losses_epoch.csv (per-epoch train/val) for plotting. The dataset is
preprocessed into a graph cache on the first run and reused afterwards; complexes that fail
graph preprocessing are recorded in failed_complexes.txt in the cache directory (the only
other trace is a missing <name>.pt).
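The --tr_sigma_min 0.1 and --tr_sigma_max 30 flags in the command above bound the translational noise scale. The source does not say how the scale is interpolated between the two, so the sketch below shows one common choice in score-based diffusion models, a geometric interpolation, purely as an illustration; sigma_schedule is a hypothetical helper, and the real schedule may differ.

```python
from typing import List


def sigma_schedule(sigma_min: float = 0.1, sigma_max: float = 30.0,
                   steps: int = 20) -> List[float]:
    """Noise scales visited when stepping t from 1 down to 0.

    Assumption: sigma(t) = sigma_min**(1 - t) * sigma_max**t (geometric).
    """
    schedule = []
    for i in range(steps + 1):
        t = 1.0 - i / steps  # t runs from 1.0 (pure noise) to 0.0
        schedule.append(sigma_min ** (1.0 - t) * sigma_max ** t)
    return schedule
```

Under this assumption a 20-step run starts at a noise scale of 30 and ends at 0.1, shrinking by the same factor at every step.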
Always pass --cache_scope dataset. It keys the graph cache by the dataset directory (one
shared .pt per complex) instead of by the split-file name, so per-complex graphs are
reused across any split that points at the same --data_dir — re-splitting train/val/test
never triggers a rebuild, and it builds only the complexes it has not seen yet. The legacy
split scope (the historical default) keys by split-file basename and silently drops
complexes that are not already cached under that exact basename.
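The difference between the two cache scopes comes down to what the cache key is derived from. The toy function below illustrates that idea only; cache_group is a hypothetical name, and the real cache paths and key format are not specified in the source.

```python
from pathlib import Path


def cache_group(scope: str, data_dir: str, split_file: str) -> str:
    """Return the name a graph cache would be grouped under (illustrative)."""
    if scope == "dataset":
        # One shared cache per dataset directory, whatever the split file is.
        return Path(data_dir).name
    if scope == "split":
        # Legacy behaviour: the cache follows the split file's basename.
        return Path(split_file).name
    raise ValueError(f"unknown cache scope: {scope!r}")
```

With the dataset scope, train_res15.txt and a re-split train_v2.txt map to the same group as long as --data_dir is unchanged; with the split scope they map to two different groups, which is why a renamed split file cannot see previously cached graphs.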
This samples water positions with the score model from step 3, caches them, and trains a classifier on each water's deviation from the true sites. The dataset/architecture flags must match the score model.
python -m superwater.confidence.train \
--original_model_dir models/water_score_res15_retrain \
--run_name water_confidence_res15_retrain \
--data_dir data/<dataset> \
--esm_embeddings_path data/<dataset>_embeddings \
--split_train examples/data/splits/train_res15.txt \
--split_val examples/data/splits/val_res15.txt \
--split_test examples/data/splits/test_res15.txt \
--log_dir models \
--all_atoms --remove_hs \
--ns 24 --nv 6 --num_conv_layers 3 --scale_by_sigma --dynamic_max_cross --dropout 0.1 \
--inference_steps 20 --water_ratio 15 \
--lr 1e-3 --batch_size 8 --n_epochs 50 \
--running_mode train --mad_prediction \
--cache_creation_id 1 --cache_ids_to_combine 1 \
--cache_scope dataset
The first run is slow: it samples and caches positions for every training complex (under
--cache_path, default data/cache_confidence); later runs reuse that cache. Lower
--water_ratio (e.g. 10) and/or --batch_size if you hit GPU-memory limits.
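The confidence model learns from how far each sampled water lands from the true sites. One simple way to turn that deviation into a training label is a distance threshold, sketched below. This is an assumption for illustration: nearest_site_distance and label_samples are hypothetical helpers, the 1.0 Å cutoff is made up, and the actual target used with --mad_prediction may be defined differently.

```python
import math
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def nearest_site_distance(sample: Coord, true_sites: Sequence[Coord]) -> float:
    """Distance (Å) from one sampled water to its closest true water site."""
    return min(math.dist(sample, site) for site in true_sites)


def label_samples(samples: Sequence[Coord], true_sites: Sequence[Coord],
                  cutoff: float = 1.0) -> List[int]:
    """Label 1 when a sample is within `cutoff` Å of a true site, else 0.

    The cutoff value is hypothetical, chosen only for this example.
    """
    return [
        1 if nearest_site_distance(s, true_sites) <= cutoff else 0
        for s in samples
    ]
```

A sample 0.5 Å from a crystallographic water would be a positive example under this rule, while one 8 Å from every true site would be a negative.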
Point a prediction config (see Quick start) at the new folders:
models:
  score_model_dir: models/water_score_res15_retrain
  confidence_model_dir: models/water_confidence_res15_retrain

@software{kuang_2025_superwater,
author = {Kuang, Xiaohan and Su, Zhaoqian},
title = {SuperWater: Predicting Water Molecule Positions on Protein Structures by Generative AI},
year = {2025},
version = {v1.0.0},
publisher = {Zenodo},
doi = {10.5281/zenodo.17465949}
}
