A bio-aware 1D U-Net for SNP genotype imputation. BiU-Net reconstructs masked genotypes from a segmented haplotype context, optionally conditioned on a per-variant biological prior (normalized position / genetic-map gap / MAF). It is benchmarked against Beagle and an SCDA autoencoder baseline.
Reference implementation for the paper BiU-Net: a Biological-informed U-Net for Genotype Imputation (preprint).
BiU-Net targets Python 3.10 and PyTorch 2.5 (CUDA 12.1).
conda create -n biunet python=3.10 -y
conda activate biunet
# Install the CUDA build of torch that matches your driver, e.g.:
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
All private/cluster settings live in one file. Copy the template and fill it in:
cp configs/credentials.sh.example configs/credentials.sh
$EDITOR configs/credentials.sh
configs/credentials.sh is git-ignored (it carries the Weights & Biases key) and
exports everything the batch scripts and config/modelconfig.py read at runtime:
the remote host/account, conda interpreter (PYBIN), SLURM partitions, DDP settings
(MASTER_PORT, NCCL_SOCKET_IFNAME), WANDB_API_KEY, and BIO_FILE. No private
path or credential is hard-coded anywhere else in the tree.
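As an illustration of how a script might consume these exports, here is a minimal sketch that fails early when a variable is missing. The variable names are the ones listed above; the function `load_runtime_settings` and its error message are hypothetical and are not taken from config/modelconfig.py.

```python
import os

# Variable names come from the list above; the loader itself is a hypothetical helper.
REQUIRED = ("PYBIN", "MASTER_PORT", "NCCL_SOCKET_IFNAME", "WANDB_API_KEY", "BIO_FILE")


def load_runtime_settings(env=None):
    """Return the required settings, raising if any is unset or empty."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED if not env.get(name)]
    if missing:
        raise RuntimeError(
            "Missing settings (was configs/credentials.sh sourced?): "
            + ", ".join(missing)
        )
    return {name: env[name] for name in REQUIRED}
```

Passing a plain dict as `env` makes the check easy to exercise without touching the real environment.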
The dataclass defaults in config/modelconfig.py are the production recipe; a bare
config inherits them:
| Component | Default |
|---|---|
| Architecture | 1D U-Net, depth 6, 48 channels, kernel 7 |
| Bio prior | bioAware (normalized position by default) |
| Loss | hybridFocalLoss (focal + dosage), dosageLossLambda=0.1 |
| Regularization | dropout 0.05 |
| Curriculum | missing-rate annealing toward the target rate |
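
To make the "focal + dosage" row concrete, the sketch below computes one plausible form of such a loss for a single site, in plain Python. It is an assumption, not the repository's hybridFocalLoss: it assumes three genotype classes (0, 1, 2 alternate alleles), a focal exponent `gamma=2.0`, and a squared-error dosage term. Only the weight 0.1 comes from the table.

```python
import math


def focal_term(probs, target, gamma=2.0):
    """Focal loss for one site: -(1 - p_t)^gamma * log(p_t)."""
    p_t = max(probs[target], 1e-12)  # clamp to avoid log(0)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def dosage_term(probs, target):
    """Squared error between the expected dosage sum_k k * p_k and the true dosage."""
    expected = sum(k * p for k, p in enumerate(probs))
    return (expected - target) ** 2


def hybrid_focal_loss(probs, target, gamma=2.0, dosage_lambda=0.1):
    """Assumed combination: focal term plus dosage_lambda times the dosage term."""
    return focal_term(probs, target, gamma) + dosage_lambda * dosage_term(probs, target)
```

Under this form, the focal term down-weights sites the model already predicts confidently, while the dosage term penalizes predictions whose expected allele count is far from the truth.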
Per-run settings (dataset, epochs, missing rates, eval) come from the YAML config
passed via --configFile.
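The defaults-plus-overrides pattern described here can be sketched with a dataclass. This is an illustrative stand-in, not the actual ModelConfig: `dosageLossLambda` and the default values come from the table above, while the other field names, the `epochs` field, and `apply_overrides` are hypothetical. A plain dict stands in for the parsed YAML.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class SketchConfig:
    # Defaults mirror the recipe table; field names other than
    # dosageLossLambda are hypothetical.
    depth: int = 6
    channels: int = 48
    kernelSize: int = 7
    dropout: float = 0.05
    dosageLossLambda: float = 0.1
    epochs: int = 1  # hypothetical per-run setting


def apply_overrides(base, overrides):
    """Return a copy of base with overrides applied, rejecting unknown keys."""
    known = {f.name for f in fields(base)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return replace(base, **overrides)
```

Rejecting unknown keys is a design choice of this sketch: it catches typos in a per-run config instead of silently falling back to the default.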
Seven dataset/population configurations are provided. Each has a segmentation
config, a training config, and an evaluation config (used by both test.py and
benchmark.py).
| Dataset | Segmentation | Train | Eval |
|---|---|---|---|
| 1KGP chr22 ALL | 1KGP_chr22_ALL_seg128_overlap16 | train_1KGP_chr22_ALL_seg128 | test_seg128_1KGP_chr22_ALL |
| LOS chr22 ALL | LOS_chr22_ALL_seg128_overlap16 | train_LOS_chr22_ALL_seg128 | test_seg128_LOS_chr22_ALL |
| SGDP chr22 ALL | SGDP_chr22_ALL_seg128_overlap16 | train_SGDP_chr22_ALL_seg128 | test_seg128_SGDP_chr22_ALL |
| SGDP HLA chr6 ALL | HLA_chr6_ALL_seg128_overlap64 | train_HLA_chr6_ALL_seg128 | test_seg128_HLA_chr6_ALL |
| SGDP chr19 ALL | SGDP_chr19_ALL_seg128_overlap16 | train_SGDP_chr19_ALL_seg128 | test_seg128_SGDP_chr19_ALL |
| LOS chr22 AA | LOS_chr22_AA_seg128_overlap16 | train_LOS_chr22_AA_seg128 | test_seg128_LOS_chr22_AA |
| LOS chr22 CA | LOS_chr22_CA_seg128_overlap16 | train_LOS_chr22_CA_seg128 | test_seg128_LOS_chr22_CA |
The SCDA baseline (full-length, no segmentation) is configured for SGDP chr19:
SGDP_chr19_ALL_seg-1_overlap0_scda (train) and test_scda_SGDP_chr19_ALL (eval).
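The seg128/overlap16 naming suggests windows of 128 variants in which consecutive windows share 16 sites. The sketch below shows one way such windows could be laid out. It is an assumption about the scheme, not the logic in segmenting.py: in particular, shifting the final window back to cover the tail (rather than padding) is a choice made here for illustration, and both function names are hypothetical.

```python
def window_starts(n_sites, seg_len=128, overlap=16):
    """Start indices of fixed-length windows that share `overlap` sites."""
    if not 0 <= overlap < seg_len:
        raise ValueError("overlap must satisfy 0 <= overlap < seg_len")
    if n_sites <= seg_len:
        return [0]  # a single (possibly short) window; no padding in this sketch
    stride = seg_len - overlap
    starts = list(range(0, n_sites - seg_len + 1, stride))
    if starts[-1] + seg_len < n_sites:
        # Assumed tail handling: shift the last window back so it ends at the final site.
        starts.append(n_sites - seg_len)
    return starts


def segment(haplotype, seg_len=128, overlap=16):
    """Slice one haplotype (a sequence of alleles) into overlapping windows."""
    return [haplotype[s:s + seg_len] for s in window_starts(len(haplotype), seg_len, overlap)]
```

With the default values the stride is 112 sites, so each interior site near a window boundary appears in two windows.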
Each stage is a plain Python entry point. Substitute any config from the table above.
# 1. Segment haplotypes into fixed-length windows
python segmenting.py --configFile configs/SGDP_chr19_ALL_seg128_overlap16.yaml
# 2. Train (DDP-ready; launch under srun/torchrun for multi-GPU)
python train.py --configFile configs/train_SGDP_chr19_ALL_seg128.yaml
# 3. Test (reconstruct masked genotypes, write imputed CSVs)
python test.py --configFile configs/test_seg128_SGDP_chr19_ALL.yaml
# 4. Benchmark one (missing-level, random-state); add --impMethod beagle for Beagle
python benchmark.py --configFile configs/test_seg128_SGDP_chr19_ALL.yaml \
--randState 42 --missingLevelIdx 0
# 5. Report: aggregate the per-bin benchmark CSVs into a comparison table
python benchmark.py --report --configFile configs/test_seg128_SGDP_chr19_ALL.yaml \
--reportScda configs/test_scda_SGDP_chr19_ALL.yaml
The report mode is a pure view over the CSVs that step 4 wrote: it averages over random states and pivots into a tab-separated table (metric × method by MAF bin, one block per missing rate), ready to paste into a results document.
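The average-then-pivot step can be sketched with the standard library alone. This is a simplified illustration, not the code behind `--report`: the row keys (`method`, `mafBin`, `missingRate`), the metric name `r2`, and `build_report` are all hypothetical, and only a single metric is handled.

```python
from collections import defaultdict
from statistics import mean


def build_report(rows, metric="r2"):
    """Average `metric` over random states, then pivot to method columns per MAF bin.

    rows: iterable of dicts with (hypothetical) keys
    method, mafBin, missingRate, and the metric column.
    """
    grouped = defaultdict(list)
    for row in rows:
        key = (row["missingRate"], str(row["mafBin"]), row["method"])
        grouped[key].append(float(row[metric]))

    methods = sorted({m for _, _, m in grouped})
    lines = []
    for rate in sorted({r for r, _, _ in grouped}):
        # One block per missing rate, with a header row naming the methods.
        lines.append(f"missing rate\t{rate}")
        lines.append("\t".join(["mafBin"] + methods))
        for maf_bin in sorted({b for r, b, _ in grouped if r == rate}):
            cells = [
                f"{mean(grouped[(rate, maf_bin, m)]):.4f}" if (rate, maf_bin, m) in grouped else ""
                for m in methods
            ]
            lines.append("\t".join([maf_bin] + cells))
    return "\n".join(lines)
```

Rows that differ only in random state fall into the same group, so the mean is taken over random states; a method with no result for a bin gets an empty cell.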
job/ holds a batch script for every (stage × dataset) pair, e.g. job/train_SGDP_chr19_ALL.batch and
job/report_SGDP_chr19_ALL.batch. Each sources configs/credentials.sh, writes
logs to logs/, and omits the private --account/--mail-user directives (those
are injected at submit time). Submit from the project root on the cluster:
./submit.sh job/train_SGDP_chr19_ALL.batch # injects --account / --mail-user from credentials.sh
Regenerate the whole batch set after editing the recipe with python scripts/gen_job_batches.py.
config/ ModelConfig dataclass (defaults = production recipe) + YAML loader
configs/ per-dataset segmentation / train / eval YAMLs + credentials template
data/ dataset, segmentation, metrics, masking utilities
model/ U-Net and SCDA architectures
job/ one SLURM batch per (stage × dataset)
scripts/ batch/config generators and helpers
*.py segmenting / train / test / benchmark entry points
submit.sh cluster-side sbatch wrapper (injects private SLURM directives)
If you use BiU-Net, please cite the preprint:
BiU-Net: a Biological-informed U-Net for Genotype Imputation. Research Square (2025). https://doi.org/10.21203/rs.3.rs-6797863/v1
@article{biunet2025,
title = {BiU-Net: a Biological-informed U-Net for Genotype Imputation},
journal = {Research Square (preprint)},
year = {2025},
doi = {10.21203/rs.3.rs-6797863/v1},
url = {https://doi.org/10.21203/rs.3.rs-6797863/v1}
}
Released under the MIT License.