
mbn312/VisionTransformer


VisionTransformer

Open In Colab

Medium

VisionTransformer is a small PyTorch project for image classification with a Vision Transformer (ViT). The repository began as a from-scratch MNIST tutorial and now includes a configurable training script, dataset presets for MNIST, FashionMNIST, and CIFAR-10, and a modular implementation split across models/ and data/.

The current source code is the authoritative reference for how this repo works today. The linked Colab notebook and article are still useful background material, but they describe an earlier, MNIST-only tutorial version of the project.

What This Repo Includes

  • A ViT implementation built from patch embeddings, positional encodings, transformer encoder blocks, and a classification head.
  • Dataset-specific defaults in data/configs.py for MNIST, FashionMNIST, and CIFAR-10.
  • A CLI training entrypoint in training.py with overrides for model width, patch size, augmentation, optimization, checkpoint path, and more.
  • Optional train/validation splits, cosine learning rate scheduling, warmup, and AdamW weight decay.
  • A Colab notebook in notebooks/VisionTransformer.ipynb that mirrors the original educational walkthrough.
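The feature list mentions warmup and cosine learning rate scheduling. The sketch below shows one common form of that schedule: linear warmup to the base rate, then cosine decay to a floor. It is an illustration under that assumption, not a copy of training.py. The function name `lr_at_step` and its parameters are hypothetical, and the repo's actual warmup shape and floor may differ.

```python
import math


def lr_at_step(step: int, total_steps: int, warmup_steps: int,
               base_lr: float, min_lr: float = 0.0) -> float:
    """Hypothetical linear-warmup + cosine-decay schedule (assumed form)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup: ramps from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr to min_lr over the steps after warmup.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in [0, 4, 5, 50, 100]:
        print(s, lr_at_step(s, total_steps=100, warmup_steps=5, base_lr=1e-3))
```

In PyTorch, a function like this could be wrapped in `torch.optim.lr_scheduler.LambdaLR`. That scheduler expects a multiplier of the optimizer's initial learning rate, so you would pass `lambda s: lr_at_step(s, ...) / base_lr` and pair it with a `torch.optim.AdamW` optimizer.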

Repository Layout

Quick Start

Run commands from the repository root.

1. Create an Environment

python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip

2. Install Dependencies

The pinned requirements.txt expects CUDA 12.1 PyTorch wheels:

pip install --extra-index-url https://download.pytorch.org/whl/cu121 -r requirements.txt

If you are on CPU-only hardware or a different CUDA version, install matching torch and torchvision wheels first, then install the remaining packages:

pip install --index-url https://download.pytorch.org/whl/cpu torch==2.2.0 torchvision==0.17.0
pip install datasets==2.17.1 numpy==2.0.0

3. Train a Model

MNIST with default settings:

python3 training.py --dataset mnist

FashionMNIST with a validation split and validation accuracy reporting:

python3 training.py --dataset fashion_mnist --train_val_split 55000 5000 --get_val_accuracy True
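The command above passes two numbers to `--train_val_split`. The sketch below assumes they are the training and validation sizes and must add up to the full training set. FashionMNIST ships 60,000 training images, hence 55000 + 5000. `split_indices` is a hypothetical helper, not the repo's code. In PyTorch, the same size check is built into `torch.utils.data.random_split`, which raises a `ValueError` when integer lengths do not sum to the dataset size.

```python
import random


def split_indices(num_examples: int, train_size: int, val_size: int, seed: int = 0):
    """Hypothetical helper: shuffle indices and cut them into train/val lists."""
    if train_size + val_size != num_examples:
        raise ValueError(
            f"split sizes {train_size} + {val_size} must sum to {num_examples}"
        )
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    return indices[:train_size], indices[train_size:]


if __name__ == "__main__":
    train_idx, val_idx = split_indices(60_000, 55_000, 5_000)
    print(len(train_idx), len(val_idx))  # 55000 5000
```

The returned index lists could then be wrapped with `torch.utils.data.Subset` to build the two datasets.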

CIFAR-10 with an explicit checkpoint path:

python3 training.py --dataset cifar10 --model_location checkpoints/cifar10_vit.pt

Training downloads the training dataset if needed, saves the best checkpoint seen during training, and then evaluates that checkpoint on the test set.
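Saving the best checkpoint amounts to a loop that tracks the best score seen so far and overwrites the checkpoint only when it improves. This README does not say which metric training.py compares (validation accuracy, training loss, or something else). The sketch therefore assumes one higher-is-better score per epoch. `train_with_best_checkpoint` and its arguments are hypothetical.

```python
from typing import Callable, Iterable, Optional


def train_with_best_checkpoint(
    epoch_scores: Iterable[float],
    save: Callable[[int], None],
) -> Optional[int]:
    """Hypothetical loop: call save(epoch) whenever an epoch beats the best score.

    epoch_scores stands in for a per-epoch metric (assumed higher is better);
    save stands in for writing the checkpoint file. Returns the epoch whose
    checkpoint ends up on disk, or None if no epochs ran.
    """
    best_score = float("-inf")
    best_epoch = None
    for epoch, score in enumerate(epoch_scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
            save(epoch)  # overwrite the single checkpoint file
    return best_epoch


if __name__ == "__main__":
    saved = []
    best = train_with_best_checkpoint([0.81, 0.86, 0.84, 0.90, 0.89], saved.append)
    print(saved, best)  # [0, 1, 3] 3
```

In a PyTorch loop, `save` would typically call `torch.save(model.state_dict(), path)`. The final test evaluation would first reload those weights with `model.load_state_dict(torch.load(path))`.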

Runtime Behavior

  • Device selection is automatic: CUDA if available, otherwise CPU.
  • Dataset defaults come from data/configs.py; CLI arguments override them when supplied.
  • Training data is downloaded into ../datasets relative to the current working directory.
  • The default checkpoint path is model.pt in the current working directory.
  • Validation is only used if train_val_split leaves examples for a validation set.
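The runtime rules above can be sketched with `argparse`:

- dataset defaults are overridden by CLI flags,
- the data folder is resolved relative to the working directory,
- the checkpoint path defaults to `model.pt`,
- validation runs only when the split reserves examples.

This is a hypothetical reconstruction. The values in `DATASET_DEFAULTS` are placeholders, not the contents of data/configs.py. The `--patch_size` and `--epochs` spellings are assumptions; only `--dataset`, `--train_val_split`, and `--model_location` appear in the commands above.

```python
import argparse
import os

# Hypothetical placeholder defaults; the real values live in data/configs.py.
DATASET_DEFAULTS = {
    "mnist": {"patch_size": 7, "epochs": 5},
    "fashion_mnist": {"patch_size": 7, "epochs": 10},
    "cifar10": {"patch_size": 4, "epochs": 50},
}


def resolve_config(argv=None) -> dict:
    """Merge dataset defaults with CLI overrides (hypothetical flag set)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", choices=sorted(DATASET_DEFAULTS), required=True)
    # default=None means "not given on the CLI", so the dataset default survives.
    parser.add_argument("--patch_size", type=int, default=None)  # assumed spelling
    parser.add_argument("--epochs", type=int, default=None)  # assumed spelling
    parser.add_argument("--train_val_split", type=int, nargs=2, default=None)
    parser.add_argument("--model_location", default="model.pt")
    args = parser.parse_args(argv)

    config = dict(DATASET_DEFAULTS[args.dataset])
    for key, value in vars(args).items():
        if value is not None:
            config[key] = value  # an explicit CLI value overrides the dataset default

    # A relative path, so it resolves against the current working directory.
    config["data_root"] = os.path.join("..", "datasets")
    split = config.get("train_val_split")
    # Validation only runs when the split reserves at least one example for it.
    config["use_validation"] = split is not None and split[1] > 0
    return config


if __name__ == "__main__":
    print(resolve_config(["--dataset", "fashion_mnist", "--train_val_split", "55000", "5000"]))
```

Device selection would typically be `torch.device("cuda" if torch.cuda.is_available() else "cpu")`. It is left out here so the sketch needs only the standard library.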

Further Documentation

Background Resources

Use the linked Colab notebook and Medium article to understand the original teaching flow. Use this repository's Python modules and the docs in docs/ to understand the current implementation.

Known Caveats

  • The tutorial materials and notebook are older than the current training pipeline and do not document later additions like validation splits, CIFAR-10 support, warmup, cosine scheduling, or AdamW.
  • requirements.txt pins CUDA-specific PyTorch builds, so installation instructions must be adjusted for CPU-only systems or different CUDA versions.
  • The repository does not include an automated test suite or published benchmark table, so training behavior should still be validated in your target environment.
  • The tutorial materials remain useful for concepts; see docs/tutorial-drift.md for how the current implementation differs from them.
