ContVAR learns structure-aware protein representations for single amino acid variants (SAVs). A graph neural network (GATv2) encodes each variant structure as:
- a global graph embedding (whole-protein representation), and
- a local embedding at the mutation site.
Training uses metric learning on triplets from the same protein family:
| Role | Structure |
|---|---|
| Anchor | Wild-type (WT) |
| Positive | Benign variant |
| Negative | Pathogenic variant |
The model is trained to pull benign variants toward the WT and push pathogenic variants away, both globally and at the mutation position.
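The objective described above can be sketched as a margin-based triplet loss applied once to the global embeddings and once to the mutation-site embeddings. This is a minimal illustration, not the repository's loss: the Euclidean distance, the margin of 1.0, the equal weighting of the two terms, and the function names are all assumptions made for the example.

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge loss: zero once the negative is at least `margin` farther than the positive."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)

def combined_loss(global_triplet, local_triplet, margin=1.0, local_weight=1.0):
    """Global term plus weighted mutation-site term (the weighting is an assumption)."""
    return (triplet_loss(*global_triplet, margin=margin)
            + local_weight * triplet_loss(*local_triplet, margin=margin))

# Toy 2-d embeddings: WT anchor, benign positive, pathogenic negative.
wt, benign, pathogenic = [0.0, 0.0], [0.1, 0.0], [2.0, 0.0]

satisfied = triplet_loss(wt, benign, pathogenic)   # benign close, pathogenic far
violated = triplet_loss(wt, pathogenic, benign)    # roles swapped on purpose
total = combined_loss((wt, benign, pathogenic), (wt, benign, pathogenic))
print(satisfied, violated, total)
```

A triplet contributes no gradient once the pathogenic variant is already farther from the WT than the benign variant by at least the margin, which is why negative selection matters during training.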
A downstream Phase 2 decoder trains feed-forward networks on top of frozen protein embeddings to predict Gene Ontology (GO) annotations and to score variant-induced functional changes.
Try ContVAR without installing anything:
https://huggingface.co/spaces/fd55/contvarspace
The demo supports two analysis modes:
| Mode | Input | Output |
|---|---|---|
| Protein structure only | One PDB or CIF file | Thresholded GO term predictions (MF, BP, CC) with confidence scores |
| WT + variant | Wild-type and variant PDB/CIF files | Per-term WT vs variant scores, delta, and gained / lost / stable functional changes |
Both modes accept .pdb and .cif structures and include example files for quick testing. Predictions are reported when scores reach the model threshold (≥ 0.60); the WT vs variant mode additionally classifies terms as gained or lost using a score-delta threshold.
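The WT vs variant logic can be illustrated with a small sketch. The 0.60 score threshold comes from the text above; the delta threshold value (0.20 here), the rule that a term is reported when either score reaches the threshold, and the function name are assumptions for illustration and may differ from what the demo does.

```python
SCORE_THRESHOLD = 0.60  # model threshold stated above
DELTA_THRESHOLD = 0.20  # hypothetical value; the demo's actual delta threshold is not stated

def classify_terms(wt_scores, var_scores,
                   score_threshold=SCORE_THRESHOLD,
                   delta_threshold=DELTA_THRESHOLD):
    """Label each GO term as gained, lost, or stable from WT and variant scores."""
    rows = {}
    for term in sorted(set(wt_scores) | set(var_scores)):
        wt = wt_scores.get(term, 0.0)
        var = var_scores.get(term, 0.0)
        # Assumption: skip terms that never reach the reporting threshold.
        if max(wt, var) < score_threshold:
            continue
        delta = var - wt
        if delta >= delta_threshold:
            label = "gained"
        elif delta <= -delta_threshold:
            label = "lost"
        else:
            label = "stable"
        rows[term] = {"wt": wt, "variant": var, "delta": round(delta, 4), "change": label}
    return rows

# Toy scores for three GO terms.
wt_scores = {"GO:0003824": 0.91, "GO:0005515": 0.65, "GO:0016787": 0.30}
var_scores = {"GO:0003824": 0.40, "GO:0005515": 0.70, "GO:0016787": 0.35}
report = classify_terms(wt_scores, var_scores)
for term, row in report.items():
    print(term, row)
```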
┌─────────────────────────────────────────────────────────────────────────┐
│ Phase 0 — GO semantic-similarity pretraining │
│ semantic_similarity/*.tsv + prebuilt GO graph .pt + protein split JSON │
│ → model_phase0_best_loss.pt │
└───────────────────────────────┬─────────────────────────────────────────┘
│ initialize Phase 1 from Phase 0 weights
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Phase 1 — DMS triplet metric learning │
│ protein_triplets_data/ + ESM2 H5 + dms_protein_split.json │
│ → model_best_loss.pt, model_last.pt │
└───────────────────────────────┬─────────────────────────────────────────┘
│
┌─────────────────────┼─────────────────────┐
▼ ▼ ▼
Embedding export t-SNE plots Frozen inference
(exports/*.h5) (visualizations/) (contvar.inference)
│
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Phase 2 decoder — GO term prediction & variant function scoring │
│ ESM / ContVAR H5 + GOA + UniRef split → decoder_best_*.pt │
└─────────────────────────────────────────────────────────────────────────┘
The encoder is trained in two steps: Phase 0 learns from GO semantic-similarity triplets on Swiss-Prot proteins, then Phase 1 fine-tunes on DMS variant triplets. Phase 1 always starts from Phase 0 weights, obtained either by training Phase 0 yourself or by loading the published model_phase0_best_loss.pt checkpoint.
End-to-end local entry point: python starter.py runs Phase 0 (when configured to train) and Phase 1, then exports embeddings and generates t-SNE plots.
Colab notebook: run.ipynb walks through the same encoder pipeline with Google Drive data paths.
ContVAR/
├── starter.py # Main CLI — edit STARTER_PATHS here
├── run.ipynb # Colab reproduction notebook
├── setup.py
├── local_splits/
│ └── dms_protein_split.json # Fixed protein-family train/val/test split
├── contvar/ # Encoder: model, training, graph building, inference
│ ├── training.py # train_pipeline()
│ ├── go_pretraining.py # Phase 0
│ ├── prebuild_graphs.py # Build PyG graphs from mmCIF + ESM2
│ ├── inference.py # Frozen-checkpoint embedding export
│ └── export_embeddings.py # Post-training H5 export
└── phase2_decoder/ # Phase 2: GO decoder training, evaluation, variant prediction
    ├── train.py
    ├── test_eval.py
    ├── benchmark_eval.py
    ├── predict_variants.py
    └── grid_search.py
This repository contains the training and evaluation code, together with fixed data splits. Datasets, prebuilt graphs, and trained model checkpoints are distributed with the project release (see Data and model artifacts).
- Python ≥ 3.8
- CUDA GPU recommended for training (CPU works for small smoke tests)
- Weights & Biases account (training logs experiments automatically)
git clone https://github.com/HUBioDataLab/ContVAR.git
cd ContVAR
pip install -e .
Install PyTorch and PyG for your platform first, then the remaining dependencies:
pip install torch torch-geometric graphein wandb biopython h5py \
scikit-learn matplotlib tqdm pandas numpy networkx
On Linux, Graphein may require DSSP for some graph features:
sudo apt-get install dssp
Log in to Weights & Biases before training:
wandb login
Download the release bundle and place files according to the layouts below. The release includes trained encoder and decoder weights; the exact download location will be listed on the project release page.
| Checkpoint | Stage |
|---|---|
| model_phase0_best_loss.pt | Phase 0 encoder (best validation loss) |
| model_phase0_last.pt | Phase 0 encoder (final epoch) |
| model_best_loss.pt | Phase 1 encoder (best validation loss) |
| model_last.pt | Phase 1 encoder (final epoch) |
| decoder_best_{embedding}_{aspect}.pt | Phase 2 decoder |
To reproduce Phase 1 without re-running Phase 0 training, set go_phase0_epochs to 0 and point go_phase0_init_checkpoint_path at model_phase0_best_loss.pt. This is the workflow used in run.ipynb.
| Path | Purpose |
|---|---|
| semantic_similarity/semantic_similarity_swissprot_filtered_low0.2_high0.8_mf.tsv | MF semantic-similarity triplets |
| semantic_similarity/semantic_similarity_swissprot_filtered_low0.2_high0.8_bp.tsv | BP triplets |
| semantic_similarity/semantic_similarity_swissprot_filtered_low0.2_high0.8_cc.tsv | CC triplets |
| <go_prebuilt_graph_root>/ | Directory tree of prebuilt PyG .pt graph files (one per Swiss-Prot protein) |
| local_splits/phase0_protein_split_removed_graphless.json | protein_to_split mapping for GO proteins |
protein_triplets_data/
├── originals/ # WT mmCIF files, one per protein family
│ └── <family>_wt_model.cif
├── positives/ # Benign variant structures
│ └── <family>_wt_model/
│ └── *.cif
├── negatives/ # Pathogenic variant structures
│ └── <family>_wt_model/
│ └── *.cif
└── processed/ # Auto-generated PyG graph cache (created on first run)
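The layout above can be walked with a short script to see which (WT, benign, pathogenic) combinations a family provides. This is only a sketch: pairing every positive with every negative is an assumption, the repository's own mapper may select triplets differently, and the family and variant file names below are made up for the demonstration.

```python
from pathlib import Path
import tempfile

def enumerate_triplets(data_root):
    """Yield (wt_path, benign_path, pathogenic_path) for every family under data_root."""
    root = Path(data_root)
    for wt in sorted((root / "originals").glob("*.cif")):
        stem = wt.stem  # e.g. <family>_wt_model, which also names the variant subfolders
        positives = sorted((root / "positives" / stem).glob("*.cif"))
        negatives = sorted((root / "negatives" / stem).glob("*.cif"))
        for pos in positives:
            for neg in negatives:
                yield wt, pos, neg

# Build a throwaway directory that mirrors the documented layout.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "protein_triplets_data"
    family = "demo_family_wt_model"  # hypothetical family name
    (root / "originals").mkdir(parents=True)
    (root / "positives" / family).mkdir(parents=True)
    (root / "negatives" / family).mkdir(parents=True)
    (root / "originals" / f"{family}.cif").touch()
    (root / "positives" / family / "A10G.cif").touch()  # hypothetical benign variants
    (root / "positives" / family / "L25V.cif").touch()
    (root / "negatives" / family / "G40D.cif").touch()  # hypothetical pathogenic variant
    triplets = [(w.name, p.name, n.name) for w, p, n in enumerate_triplets(root)]

print(triplets)
```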
Additional files:
| File | Purpose |
|---|---|
| embeddings_variable.h5 | Per-residue ESM-2 embeddings keyed by structure filename stem |
| local_splits/dms_protein_split.json | Protein-family → train / val / test mapping (included in repo) |
The split JSON uses family_to_split keys that match WT filenames without extension (e.g. blat_ecolx_stiffler_2015_p62593_wt_model).
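A lookup against the split file might look like the sketch below. It assumes family_to_split is a top-level object mapping family names to split names; the inline JSON and the "train" assignment are illustrative stand-ins for the real local_splits/dms_protein_split.json, and the helper name is hypothetical.

```python
import json
from pathlib import Path

def split_for_structure(split_doc, wt_path):
    """Return the split name for a WT structure path, or None if the family is unlisted."""
    family = Path(wt_path).stem  # filename without extension is the lookup key
    return split_doc["family_to_split"].get(family)

# Hypothetical inline document standing in for the real split file.
split_doc = json.loads(
    '{"family_to_split": {"blat_ecolx_stiffler_2015_p62593_wt_model": "train"}}'
)

wt_file = "protein_triplets_data/originals/blat_ecolx_stiffler_2015_p62593_wt_model.cif"
assigned = split_for_structure(split_doc, wt_file)
missing = split_for_structure(split_doc, "protein_triplets_data/originals/unknown_wt_model.cif")
print(assigned, missing)
```

Keying on the filename stem means a renamed WT file silently drops out of every split, so it is worth checking for None before training.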
Place these files in the working directory (or update paths in phase2_decoder/config.py):
| File | Purpose |
|---|---|
| goa_2025-12-04_swissprot_noiea.tsv | Swiss-Prot GO annotations |
| esm2_t33_650M_UR50D_protein_embedding.h5 | ESM-2 protein embeddings |
| go_pretraining_contvar_embeddings.h5 | ContVAR embeddings after Phase 0 |
| stage2_best_pretraining_protein_embeddings.h5 | ContVAR embeddings after Phase 1 (contvar_full) |
| protein_uniref50.tsv | Protein → UniRef50 cluster mapping |
| phase0_go_split.json | UniRef50 cluster → split assignment |
| go.obo | GO hierarchy (for optional propagation) |
| variant_specific_go_benchmark.tsv | LOF/GOF benchmark (for benchmark_eval) |
All local file paths for the encoder pipeline are centralized in STARTER_PATHS at the top of starter.py. Edit this block before running:
STARTER_PATHS = {
"data_root": "protein_triplets_data",
"embeddings_path": "embeddings_variable.h5",
"dms_protein_split_json_path": "local_splits/dms_protein_split.json",
"go_protein_split_json_path": "local_splits/phase0_protein_split_removed_graphless.json",
"go_tsv_dir": "semantic_similarity",
"go_prebuilt_graph_root": "/path/to/prebuilt_go_graphs", # directory of .pt graphs
"go_phase0_init_checkpoint_path": "model_phase0_best_loss.pt",
# ... checkpoint and export paths ...
}
Important notes:
- go_prebuilt_graph_root must be a directory of .pt graph files, not a model checkpoint.
- starter.py sets go_phase0_epochs: 200 to train Phase 0 from scratch. Set go_phase0_epochs: 0 and provide go_phase0_init_checkpoint_path to start Phase 1 from the published Phase 0 weights without retraining.
- Hyperparameters (learning rate, margin, epochs, etc.) live in contvar/config.py (ProjectConfig).
- Decoder paths and hyperparameters live in phase2_decoder/config.py (DecoderConfig).
Note: internal code still uses stage2_* variable names for Phase 1 checkpoint paths — these refer to the same files (model_best_loss.pt, etc.).
python starter.py
This will:
- Train Phase 0 on GO semantic-similarity triplets (go_phase0_epochs: 200 by default in starter.py).
- Initialize Phase 1 from the resulting Phase 0 weights.
- Build or reuse cached graphs under protein_triplets_data/processed/.
- Train Phase 1 with streaming semi-hard negative mining (300 epochs by default).
- Save model_best_loss.pt (lowest validation loss) and model_last.pt.
- Export embeddings to exports/ and write t-SNE plots to visualizations/.
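For readers unfamiliar with semi-hard negative mining, the selection rule can be sketched as follows: keep negatives that are farther from the anchor than the positive, but still inside the margin. This uses the standard definition; the streaming aspect of the repository's miner is not modelled here, and the margin value and function names are assumptions.

```python
import math

def dist(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def semi_hard_negatives(anchor, positive, candidates, margin=1.0):
    """Return candidates n with d(a, p) < d(a, n) < d(a, p) + margin."""
    d_ap = dist(anchor, positive)
    return [n for n in candidates if d_ap < dist(anchor, n) < d_ap + margin]

anchor = [0.0, 0.0]
positive = [1.0, 0.0]          # d(a, p) = 1.0
candidates = [
    [0.5, 0.0],                # closer than the positive: "hard", excluded
    [1.5, 0.0],                # inside the margin band: semi-hard, kept
    [3.0, 0.0],                # beyond d(a, p) + margin: "easy", excluded
]
selected = semi_hard_negatives(anchor, positive, candidates)
print(selected)
```

Semi-hard negatives still produce a non-zero triplet loss without being so hard that they destabilize early training, which is the usual motivation for this selection rule.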
Rebuild all processed graphs from scratch:
python starter.py --force
To skip Phase 0 training and fine-tune directly from the released weights (as in run.ipynb):
# In starter.py _build_config_overrides or via train_pipeline config dict:
"go_phase0_epochs": 0,
"go_phase0_init_checkpoint_path": "model_phase0_best_loss.pt",
Then run python starter.py or follow run.ipynb.
To prebuild PyG graphs (for GO proteins or custom structures) without training:
python -m contvar.prebuild_graphs \
--structure-dir path/to/cif_files \
--output-dir path/to/output_graphs \
--embeddings-h5 embeddings_variable.h5
Build graphs and stream them through one or more checkpoints in a single pass:
python -m contvar.prebuild_graphs \
--structure-dir path/to/cif_files \
--output-dir path/to/output_graphs \
--embeddings-h5 embeddings_variable.h5 \
--checkpoint model_best_loss.pt \
--inference-output-dir exports/prebuild_inference
Export normalized global embeddings from existing .pt graph files:
python -m contvar.inference \
--checkpoint model_best_loss.pt \
--graph-root protein_triplets_data/processed \
--out exports/inference_contvar_embeddings.h5 \
--batch-size 32
Train a GO decoder (aspect: F=MF, P=BP, C=CC):
python -m phase2_decoder.train --aspect F --embedding concat
Embedding modes: esm, contvar, contvar_full, concat, concat_full.
Evaluate the best saved checkpoint on the held-out test split:
python -m phase2_decoder.test_eval --aspect F --embedding concat
Run the variant-specific LOF/GOF benchmark:
python -m phase2_decoder.benchmark_eval --aspect F --embedding concat
python -m phase2_decoder.benchmark_eval --all
Predict functional changes for variants in an H5 file:
python -m phase2_decoder.predict_variants \
--var_h5 embeddings_variable.h5 \
--out predictions.csv
Hyperparameter search:
python -m phase2_decoder.grid_search --aspect F --embedding concat
| Artifact | Description |
|---|---|
| model_phase0_best_loss.pt | Best Phase 0 encoder weights |
| model_phase0_last.pt | Final Phase 0 encoder weights |
| model_best_loss.pt | Best Phase 1 encoder (validation loss) |
| model_last.pt | Final Phase 1 encoder |
| model_epoch_{N}.pt | Periodic Phase 1 snapshots (epoch 80, then every 100) |
| exports/phase0_contvar_embeddings.h5 | Global embeddings for GO proteins |
| exports/dms_variant_contvar_embeddings.h5 | Global embeddings for DMS variants |
| visualizations/best/, visualizations/last/ | t-SNE plots (global vs local, baseline vs projected) |
| Artifact | Description |
|---|---|
| decoder_best_{embedding}_{aspect}.pt | Best decoder checkpoint |
| go_vocab_{aspect}.json | GO term vocabulary |
Metrics logged to Weights & Biases include triplet loss, MRR, alignment, uniformity (encoder) and mAP, F1, MCC (decoder).
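As background on two of the encoder metrics, alignment and uniformity are commonly computed on L2-normalized embeddings as the mean squared distance between positive pairs and the log of the mean Gaussian potential over all pairs, respectively. The sketch below follows that common formulation with t = 2; whether the repository uses exactly these definitions and constants is an assumption.

```python
import math
from itertools import combinations

def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def sq_dist(u, v):
    """Squared Euclidean distance."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def alignment(positive_pairs):
    """Mean squared distance between normalized positive pairs (lower is better)."""
    dists = [sq_dist(normalize(a), normalize(b)) for a, b in positive_pairs]
    return sum(dists) / len(dists)

def uniformity(embeddings, t=2.0):
    """Log of the mean exp(-t * squared distance) over all pairs (lower is better)."""
    unit = [normalize(e) for e in embeddings]
    vals = [math.exp(-t * sq_dist(a, b)) for a, b in combinations(unit, 2)]
    return math.log(sum(vals) / len(vals))

# Toy check: identical directions align perfectly; antipodal points are maximally spread.
align_value = alignment([([1.0, 0.0], [2.0, 0.0])])
uniform_value = uniformity([[1.0, 0.0], [-1.0, 0.0]])
print(align_value, uniform_value)
```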
| Item | Location |
|---|---|
| DMS protein-family split | local_splits/dms_protein_split.json (version 2) |
| GO protein split | local_splits/phase0_protein_split_removed_graphless.json |
| Decoder UniRef50 split | phase0_go_split.json |
| Random seed (GO split) | go_split_seed = 42 in ProjectConfig |
| Decoder seed | seed = 42 in DecoderConfig |
Training is deterministic given fixed splits and seeds, but exact GPU numerics may vary slightly across hardware.
Run these after preparing data and downloading the release artifacts:
# 1. Package imports
python -c "from contvar import train_pipeline, ProjectConfig; print('contvar OK')"
# 2. DMS split loads (requires protein_triplets_data/)
python -c "
from contvar.data.mapper import TripletDataPathMapper
m = TripletDataPathMapper('protein_triplets_data', 'local_splits/dms_protein_split.json')
print(f'Families: {len(m.triplets)} | train={len(m.train_triplets)} val={len(m.val_triplets)} test={len(m.test_triplets)}')
"
# 3. Decoder evaluation (requires decoder data + trained checkpoint)
python -m phase2_decoder.test_eval --aspect F --embedding esm
Encoder (DeepProteinGAT):
- Input nodes: amino-acid one-hot (20-d) + ESM-2 residue embedding (1280-d)
- Edges: SALAD-style hybrid connectivity (default) or Graphein kNN
- Two GATv2 layers with residual connections and edge features
- Projection heads → 256-d global and local embeddings
- Ontology-specific heads for Phase 0 (MF / BP / CC)
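The node input described in the first bullet can be pictured as a 1300-d vector per residue (20 + 1280). The sketch below assumes plain concatenation with the one-hot block first and an alphabetical one-letter amino-acid ordering; the actual ordering and layout used by the graph builder are not stated here, so treat both as assumptions.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # assumed ordering of the 20 standard residues
ESM_DIM = 1280                        # ESM-2 residue embedding size stated above

def node_features(residue, esm_embedding):
    """Concatenate a 20-d one-hot residue encoding with a 1280-d ESM-2 embedding."""
    if residue not in AMINO_ACIDS:
        raise ValueError(f"non-standard residue: {residue!r}")
    if len(esm_embedding) != ESM_DIM:
        raise ValueError(f"expected {ESM_DIM} ESM-2 values, got {len(esm_embedding)}")
    one_hot = [0.0] * len(AMINO_ACIDS)
    one_hot[AMINO_ACIDS.index(residue)] = 1.0
    return one_hot + list(esm_embedding)

# Toy residue with a constant stand-in embedding.
features = node_features("C", [0.5] * ESM_DIM)
print(len(features))
```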
Decoder (FFNDecoder):
- Multi-layer feed-forward network with dropout
- Multi-label sigmoid output over GO vocabulary
- Class-weighted BCE loss with extra weight on rare terms
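One common way to realize a class-weighted BCE that favors rare terms is to scale the positive part of the loss by a per-term weight derived from annotation frequency. The sketch below uses the negative-to-positive count ratio with a cap; this weighting scheme, the cap of 10, and the function names are assumptions for illustration, not the decoder's actual configuration.

```python
import math

def rare_term_weights(term_counts, n_proteins, max_weight=10.0):
    """Positive-class weight per GO term: negatives/positives ratio, capped at max_weight."""
    return {term: min(max_weight, (n_proteins - count) / count)
            for term, count in term_counts.items()}

def weighted_bce(probs, labels, pos_weights, eps=1e-7):
    """Mean binary cross-entropy over terms, with the positive term scaled per class."""
    total = 0.0
    for term, p in probs.items():
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        y = labels[term]
        total += -(pos_weights[term] * y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

# Toy vocabulary: one common term and one rare term among 100 proteins.
weights = rare_term_weights({"GO:common": 50, "GO:rare": 1}, n_proteins=100)
loss = weighted_bce(
    probs={"GO:common": 0.5, "GO:rare": 0.5},
    labels={"GO:common": 1, "GO:rare": 1},
    pos_weights=weights,
)
print(weights, loss)
```

With equal predicted probabilities, a missed rare term costs ten times as much as a missed common one in this toy setup, which is the effect the extra weighting is meant to achieve.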
If you use this code, please cite the ContVAR paper (add citation when available) and acknowledge the HUBioDataLab/ContVAR repository.
See the repository license file for terms of use.