
learnslowly/BiU-Net


BiU-Net

A bio-aware 1D U-Net for SNP genotype imputation. BiU-Net reconstructs masked genotypes from a segmented haplotype context, optionally conditioned on a per-variant biological prior (normalized position / genetic-map gap / MAF). It is benchmarked against Beagle and an SCDA autoencoder baseline.
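The prior is described only by name here (normalized position, genetic-map gap, or MAF). The sketch below shows one plausible way to compute each quantity. It assumes base-pair positions, a genetic map in cM, and phased 0/1 haplotypes. All function names are hypothetical and do not come from the repository.

```python
"""Sketch: per-variant biological prior features (hypothetical helpers, not BiU-Net's code)."""


def normalized_positions(positions):
    """Map base-pair positions to [0, 1] within the region."""
    lo, hi = min(positions), max(positions)
    span = hi - lo
    if span == 0:
        return [0.0 for _ in positions]
    return [(p - lo) / span for p in positions]


def genetic_map_gaps(cm_positions):
    """Gap in cM to the previous variant; the first variant gets 0."""
    return [0.0] + [b - a for a, b in zip(cm_positions, cm_positions[1:])]


def minor_allele_frequencies(haplotypes):
    """MAF per variant from 0/1 haplotypes (rows = haplotypes, columns = variants)."""
    n_hap = len(haplotypes)
    mafs = []
    for j in range(len(haplotypes[0])):
        alt_freq = sum(h[j] for h in haplotypes) / n_hap
        mafs.append(min(alt_freq, 1.0 - alt_freq))
    return mafs
```

Any one of these lists could serve as the per-variant channel fed alongside the genotypes. Which one BiU-Net actually uses is selected by its own configuration (normalized position by default).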

Reference implementation for the paper "BiU-Net: a Biological-informed U-Net for Genotype Imputation" (preprint).


Installation

BiU-Net targets Python 3.10 and PyTorch 2.5 (CUDA 12.1).

conda create -n biunet python=3.10 -y
conda activate biunet
# Install the CUDA build of torch that matches your driver, e.g.:
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt

Configuration

All private/cluster settings live in one file. Copy the template and fill it in:

cp configs/credentials.sh.example configs/credentials.sh
$EDITOR configs/credentials.sh

configs/credentials.sh is git-ignored because it carries the Weights & Biases key. It exports every setting that the batch scripts and config/modelconfig.py read at runtime: the remote host/account, the conda interpreter (PYBIN), SLURM partitions, DDP settings (MASTER_PORT, NCCL_SOCKET_IFNAME), WANDB_API_KEY, and BIO_FILE. No private path or credential is hard-coded anywhere else in the tree.
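Because every private setting arrives through the environment, a missing `source configs/credentials.sh` usually shows up as a confusing failure deep inside a run. The sketch below shows a fail-fast check under that assumption. The variable names come from the list above. `REQUIRED_VARS` and `read_runtime_settings` are hypothetical and not part of config/modelconfig.py.

```python
import os

# Hypothetical subset of the variables exported by configs/credentials.sh.
REQUIRED_VARS = ("PYBIN", "MASTER_PORT", "WANDB_API_KEY", "BIO_FILE")


def read_runtime_settings(env=None):
    """Collect settings exported by credentials.sh, failing loudly if any are missing."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError(
            "missing settings " + ", ".join(missing)
            + "; did you `source configs/credentials.sh`?"
        )
    settings = {name: env[name] for name in REQUIRED_VARS}
    settings["MASTER_PORT"] = int(settings["MASTER_PORT"])  # ports arrive as strings
    return settings
```

Passing an explicit dict instead of `os.environ` keeps the check easy to test without touching the real environment.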

Model recipe (default)

The dataclass defaults in config/modelconfig.py are the production recipe; a bare config inherits them:

| Component | Setting |
|---|---|
| Architecture | 1D U-Net, depth 6, 48 channels, kernel 7 |
| Bio prior | bioAware (normalized position by default) |
| Loss | hybridFocalLoss (focal + dosage), dosageLossLambda=0.1 |
| Regularization | dropout 0.05 |
| Curriculum | missing-rate annealing toward the target rate |
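The recipe names the loss only as "focal + dosage" weighted by dosageLossLambda. The following single-site sketch shows one common way such a hybrid is built, and it rests on several assumptions. It assumes class index k means k alternate alleles, that the dosage term is the squared error between expected and true dosage, and that the focal exponent is gamma=2 (the usual focal-loss default). hybridFocalLoss itself may differ in any of these details.

```python
import math


def hybrid_focal_dosage_loss(probs, target, gamma=2.0, dosage_lambda=0.1):
    """Focal loss at one masked site plus a weighted squared dosage error.

    probs  -- predicted class probabilities; index k = k alternate alleles (assumed)
    target -- true class index, i.e. the true allele count
    """
    p_t = max(probs[target], 1e-12)  # clamp so log() never sees 0
    focal = -((1.0 - p_t) ** gamma) * math.log(p_t)  # down-weights easy, confident sites
    expected_dosage = sum(k * p for k, p in enumerate(probs))
    dosage = (expected_dosage - target) ** 2  # penalizes predictions far from the true count
    return focal + dosage_lambda * dosage
```

The dosage term adds ordinal information that plain classification ignores. Under this formulation, confusing 0 with 2 costs more than confusing 0 with 1.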

Per-run settings (dataset, epochs, missing rates, eval) come from the YAML config passed via --configFile.
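The "defaults in a dataclass, per-run keys from YAML" pattern can be sketched with the standard library alone. `RecipeSketch` and `apply_run_config` are hypothetical stand-ins for the real ModelConfig and loader, and the field names are assumptions. The default values are the ones listed in the recipe table.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class RecipeSketch:
    """Hypothetical stand-in for ModelConfig; defaults mirror the recipe table."""
    depth: int = 6
    channels: int = 48
    kernel_size: int = 7
    bio_prior: str = "position"
    dosage_loss_lambda: float = 0.1
    dropout: float = 0.05


def apply_run_config(base, overrides):
    """Return a copy of `base` with per-run keys applied; unknown keys are rejected."""
    known = {f.name for f in fields(base)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return replace(base, **overrides)
```

In practice `overrides` would be the mapping parsed from the --configFile YAML, for example with PyYAML's `yaml.safe_load`. An empty mapping reproduces the production recipe unchanged, which matches "a bare config inherits them". Rejecting unknown keys catches typos that would otherwise silently fall back to a default.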

Datasets

Seven dataset/population configurations are provided. Each has a segmentation config, a training config, and an evaluation config (used by both test.py and benchmark.py).

| Dataset | Segmentation | Train | Eval |
|---|---|---|---|
| 1KGP chr22 ALL | 1KGP_chr22_ALL_seg128_overlap16 | train_1KGP_chr22_ALL_seg128 | test_seg128_1KGP_chr22_ALL |
| LOS chr22 ALL | LOS_chr22_ALL_seg128_overlap16 | train_LOS_chr22_ALL_seg128 | test_seg128_LOS_chr22_ALL |
| SGDP chr22 ALL | SGDP_chr22_ALL_seg128_overlap16 | train_SGDP_chr22_ALL_seg128 | test_seg128_SGDP_chr22_ALL |
| SGDP HLA chr6 ALL | HLA_chr6_ALL_seg128_overlap64 | train_HLA_chr6_ALL_seg128 | test_seg128_HLA_chr6_ALL |
| SGDP chr19 ALL | SGDP_chr19_ALL_seg128_overlap16 | train_SGDP_chr19_ALL_seg128 | test_seg128_SGDP_chr19_ALL |
| LOS chr22 AA | LOS_chr22_AA_seg128_overlap16 | train_LOS_chr22_AA_seg128 | test_seg128_LOS_chr22_AA |
| LOS chr22 CA | LOS_chr22_CA_seg128_overlap16 | train_LOS_chr22_CA_seg128 | test_seg128_LOS_chr22_CA |

The SCDA baseline (full-length, no segmentation) is configured for SGDP chr19: SGDP_chr19_ALL_seg-1_overlap0_scda (train) and test_scda_SGDP_chr19_ALL (eval).
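The config names suggest that seg128_overlap16 splits each haplotype into windows of 128 variants, with neighbouring windows overlapping by 16, and that seg-1 means no segmentation. This reading is inferred from the naming and is not documented behavior. Under that reading, window boundaries could be computed as follows. `window_bounds` is hypothetical. Shifting the last window back so it ends exactly at the final variant is an assumption, since the repository may pad instead.

```python
def window_bounds(n_variants, seg_len, overlap):
    """Half-open [start, end) windows of seg_len variants overlapping by `overlap`.

    seg_len == -1 returns one full-length window (as in the SCDA config name).
    The last window is shifted left to end at n_variants (assumption: no padding).
    """
    if seg_len == -1 or n_variants <= seg_len:
        return [(0, n_variants)]
    if not 0 <= overlap < seg_len:
        raise ValueError("overlap must be in [0, seg_len)")
    stride = seg_len - overlap
    bounds = []
    start = 0
    while start + seg_len < n_variants:
        bounds.append((start, start + seg_len))
        start += stride
    bounds.append((n_variants - seg_len, n_variants))  # final window ends at the last variant
    return bounds
```

With 300 variants, seg_len=128 and overlap=16, this yields (0, 128), (112, 240) and (172, 300). The last pair overlaps by more than 16 because the final window is pulled back to cover the tail.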

Pipeline

Each stage is a plain Python entry point. Substitute any config from the table above.

# 1. Segment haplotypes into fixed-length windows
python segmenting.py --configFile configs/SGDP_chr19_ALL_seg128_overlap16.yaml

# 2. Train (DDP-ready; launch under srun/torchrun for multi-GPU)
python train.py --configFile configs/train_SGDP_chr19_ALL_seg128.yaml

# 3. Test (reconstruct masked genotypes, write imputed CSVs)
python test.py --configFile configs/test_seg128_SGDP_chr19_ALL.yaml

# 4. Benchmark a single (missing-level, random-state) pair; add --impMethod beagle to benchmark Beagle
python benchmark.py --configFile configs/test_seg128_SGDP_chr19_ALL.yaml \
                    --randState 42 --missingLevelIdx 0

# 5. Report: aggregate the per-bin benchmark CSVs into a comparison table
python benchmark.py --report --configFile configs/test_seg128_SGDP_chr19_ALL.yaml \
                    --reportScda configs/test_scda_SGDP_chr19_ALL.yaml
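Step 4 evaluates one (missing-level, random-state) pair, so each result should be reproducible from its random state. The idea can be sketched as seeded masking followed by scoring only the hidden sites. Both helpers are hypothetical. Concordance is used purely as an example metric, since the metrics that benchmark.py reports are not listed here.

```python
import random


def mask_sites(n_sites, missing_rate, rand_state):
    """Pick a reproducible, sorted set of site indices to hide at the given rate."""
    rng = random.Random(rand_state)  # same rand_state -> same mask
    n_masked = round(n_sites * missing_rate)
    return sorted(rng.sample(range(n_sites), n_masked))


def masked_concordance(truth, imputed, masked):
    """Fraction of masked sites whose imputed genotype equals the true genotype."""
    if not masked:
        raise ValueError("no masked sites to score")
    hits = sum(truth[i] == imputed[i] for i in masked)
    return hits / len(masked)
```

Scoring only the masked indices keeps the metric from being inflated by the observed sites, which any method trivially "imputes" correctly.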

Report mode is a pure view over the CSVs written in step 4 and runs no imputation of its own. It averages each metric over random states and pivots the result into a tab-separated table (metric × method by MAF bin, one block per missing rate), ready to paste into a results document.
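The averaging and pivoting described above can be sketched with the standard library. The sketch assumes long-format rows with keys missing_rate, rand_state, method, metric, maf_bin and value. That layout is hypothetical, and the real per-bin CSVs may be shaped differently.

```python
import csv
import io
from collections import defaultdict
from statistics import mean


def pivot_report(rows):
    """Average over random states, then emit one TSV block per missing rate.

    Rows are (metric/method) pairs and columns are MAF bins.
    """
    groups = defaultdict(list)
    for r in rows:
        # rand_state is left out of the key, so its replicates are averaged together
        key = (r["missing_rate"], r["metric"], r["method"], r["maf_bin"])
        groups[key].append(float(r["value"]))
    averaged = {k: mean(v) for k, v in groups.items()}

    out = io.StringIO()
    writer = csv.writer(out, delimiter="\t", lineterminator="\n")
    for rate in sorted({k[0] for k in averaged}):
        bins = sorted({k[3] for k in averaged if k[0] == rate})
        writer.writerow([f"missing_rate={rate}", *bins])
        for metric, method in sorted({(k[1], k[2]) for k in averaged if k[0] == rate}):
            cells = [averaged.get((rate, metric, method, b), float("nan")) for b in bins]
            writer.writerow([f"{metric}/{method}", *(f"{c:.4f}" for c in cells)])
        writer.writerow([])  # blank line between missing-rate blocks
    return out.getvalue()
```

Because the function only reads existing rows, re-running it is cheap and cannot change the underlying benchmark numbers.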

Running on SLURM

job/ holds one batch script per (stage × dataset) combination, e.g. job/train_SGDP_chr19_ALL.batch and job/report_SGDP_chr19_ALL.batch. Each script sources configs/credentials.sh, writes its logs to logs/, and omits the private --account/--mail-user directives, which are injected at submit time. Submit from the project root on the cluster:

./submit.sh job/train_SGDP_chr19_ALL.batch   # injects --account / --mail-user from credentials.sh

After editing the recipe, regenerate the full batch set with python scripts/gen_job_batches.py.
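Generating one script per (stage × dataset) is essentially a templated Cartesian product. The sketch below is not scripts/gen_job_batches.py. The stage and dataset lists, the template body and `write_batches` are all hypothetical, and only the train/report names and the `{stage}_{dataset}.batch` pattern appear in the examples above. Like the real scripts, the template sources credentials.sh, logs to logs/, and leaves out --account/--mail-user.

```python
from itertools import product
from pathlib import Path

# Hypothetical lists; extend with the stages and datasets actually in use.
STAGES = ("segment", "train", "test", "benchmark", "report")
DATASETS = ("SGDP_chr19_ALL", "LOS_chr22_AA")

TEMPLATE = """#!/bin/bash
#SBATCH --job-name={stage}_{dataset}
#SBATCH --output=logs/{stage}_{dataset}_%j.out
source configs/credentials.sh
# placeholder: stage-specific python command goes here
"""


def write_batches(out_dir):
    """Write one batch file per (stage, dataset) pair and return their paths."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    written = []
    for stage, dataset in product(STAGES, DATASETS):
        path = out / f"{stage}_{dataset}.batch"
        path.write_text(TEMPLATE.format(stage=stage, dataset=dataset))
        written.append(path)
    return written
```

Keeping private directives out of the template means the generated files can be committed, while submit.sh adds the account-specific lines only on the cluster.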

Repository layout

config/        ModelConfig dataclass (defaults = production recipe) + YAML loader
configs/       per-dataset segmentation / train / eval YAMLs + credentials template
data/          dataset, segmentation, metrics, masking utilities
model/         U-Net and SCDA architectures
job/           one SLURM batch per (stage × dataset)
scripts/       batch/config generators and helpers
*.py           segmenting / train / test / benchmark entry points
submit.sh      cluster-side sbatch wrapper (injects private SLURM directives)

Citation

If you use BiU-Net, please cite the preprint:

BiU-Net: a Biological-informed U-Net for Genotype Imputation. Research Square (2025). https://doi.org/10.21203/rs.3.rs-6797863/v1

@article{biunet2025,
  title   = {BiU-Net: a Biological-informed U-Net for Genotype Imputation},
  journal = {Research Square (preprint)},
  year    = {2025},
  doi     = {10.21203/rs.3.rs-6797863/v1},
  url     = {https://doi.org/10.21203/rs.3.rs-6797863/v1}
}

License

Released under the MIT License.
