probe: float16 seg feature cache to fit single-GPU hosts (metric-neutral)#13

Open: buley wants to merge 1 commit into MedARC-AI:main from buley:probe-float16-seg-cache
@buley buley commented Jun 17, 2026

What

The segmentation probe (inline_pathobench_* / _seg_extract_features) caches dense per-image features of shape [N, 1024, 3072] in host RAM as float32 before fitting the MaskTransformer head. At full PanNuke (~5k images) that cache is ~65GB, which OOM-kills the probe worker on an 85GB-RAM single-GPU host, i.e. the train_1gpu.sbatch / single-A100 target the suite is meant to support. PanNuke segmentation never completes there.

This stores the cache as float16 (~32GB) and upcasts per-batch to float32 at the MaskTransformer call sites.
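The cache sizes above follow directly from the feature shape. A minimal sketch of the arithmetic, assuming exactly 5,000 images as a stand-in for "~5k" and decimal gigabytes; the helper name `cache_bytes` is hypothetical and not part of the probe code:

```python
def cache_bytes(n_images: int, tokens: int, dim: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense [n_images, tokens, dim] feature cache."""
    return n_images * tokens * dim * bytes_per_elem


N_IMAGES = 5000            # assumption: stands in for "~5k images"
TOKENS, DIM = 1024, 3072   # per-image feature shape quoted above

fp32_bytes = cache_bytes(N_IMAGES, TOKENS, DIM, 4)  # float32: 4 bytes per element
fp16_bytes = cache_bytes(N_IMAGES, TOKENS, DIM, 2)  # float16: 2 bytes per element

print(f"float32 cache: {fp32_bytes / 1e9:.1f} GB")
print(f"float16 cache: {fp16_bytes / 1e9:.1f} GB")
```

With these inputs the estimate is about 63 GB for float32 and about 31 GB for float16, in line with the ~65GB and ~32GB figures quoted above.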

Why it's metric-neutral

The features come out of a bf16 autocast. float16's 10-bit mantissa losslessly preserves bf16 values (bf16 has a 7-bit mantissa) as long as they fall within float16's narrower exponent range, and they're upcast back to float32 at the head. So the segmentation Jaccard is unchanged; this only lowers the host-RAM peak.
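The mantissa argument can be checked with the standard library alone. The sketch below is an illustration, not the probe's code: it emulates bf16 by truncating a float32 to its top 16 bits (real hardware rounds rather than truncates, but the result is still a valid bf16 value), round-trips through IEEE float16 using `struct`'s `e` format, and compares. The helper names are hypothetical. It also shows two cases where the round trip is not exact: bf16 values outside float16's exponent range, and values that carry full float32 precision.

```python
import struct


def to_bf16(x: float) -> float:
    """Emulate bf16 by keeping only the top 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def to_f32(x: float) -> float:
    """Round a Python float to float32 precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def survives_fp16(x: float) -> bool:
    """True if x is unchanged after being stored as float16 and read back."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0] == x
    except OverflowError:
        # struct refuses values above the float16 range; a tensor cast
        # would typically produce inf instead of raising.
        return False


in_range = [to_bf16(v) for v in (0.1, -3.14159, 123.456, 1e-3)]
print([survives_fp16(v) for v in in_range])   # bf16 values inside float16's range

print(survives_fp16(to_bf16(1e5)))    # too large for float16
print(survives_fp16(to_bf16(1e-10)))  # underflows to zero in float16
print(survives_fp16(to_f32(0.1)))     # full float32 precision, not bf16
```

The in-range bf16 values survive unchanged, while the other three do not, so the lossless property depends on the cached features actually being bf16-valued and of moderate magnitude.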

Verification

End-to-end on an A100 (85GB host): seg-cache RSS dropped from 83GB to ~43GB, and pannuke/monusac/consep complete with identical jaccards. Without the change, the worker is OOM-killed mid-pannuke on that hardware.

Pure infra/cleanup per the contribution policy — no benchmark-surface or scoring change.

CLAassistant commented Jun 17, 2026

CLA assistant check
All committers have signed the CLA.

…ral)

The segmentation probe caches dense per-image features [N, 1024, 3072] in host
RAM as float32 before fitting the MaskTransformer head. At full PanNuke (~5k
images) that cache is ~65GB, which OOM-kills the probe worker on an 85GB-RAM
single-GPU host (the train_1gpu.sbatch / single-A100 target the suite is meant
to support). pannuke segmentation then never completes on that hardware.

Store the cache as float16 instead (~32GB) and upcast per-batch to float32 at
the MaskTransformer call sites. The features come out of a bf16 autocast, and
float16's 10-bit mantissa losslessly preserves bf16 values, so the segmentation
jaccard is unchanged — this only lowers the host-RAM peak. Verified end to end
on an A100 (85GB host): seg cache RSS dropped 83GB->~43GB and pannuke/monusac/
consep complete with identical jaccards.
@buley buley force-pushed the probe-float16-seg-cache branch from 8e8f69e to 41bdb21 June 17, 2026 20:27
@PaulScotti

Contributor

Thanks for looking into this! I have some concerns though:

  1. The “lossless / metric-neutral” rationale: on an idle H100, our autocast to bf16 doesn't mean all outputs are bf16; DinoV2ViT.encode_image() returns float32, for instance. So changing the cache dtype will change the benchmark probe, leading to a mismatch with previously submitted runs.

  2. Have you tried actually running the same downstream eval with and without your changes and verifying results are the same?

  3. For dinov2_vits14_reg, segmentation features are [N, 256, 384]; full PanNuke train+val is about 2 GB float32, not 65 GB. Are you sure you're using the right pannuke download from us? Or are you using some other non-nanopath encoder?
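The two cache sizes quoted in this thread can be compared from the feature shapes alone. A rough sketch, assuming 5,000 images and float32 storage in both cases; the helper name `per_image_bytes` is hypothetical:

```python
def per_image_bytes(tokens: int, dim: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to cache one image's [tokens, dim] feature map."""
    return tokens * dim * bytes_per_elem


N_IMAGES = 5000  # assumption: approximate full PanNuke image count

small = per_image_bytes(256, 384)     # [N, 256, 384] shape from point 3
large = per_image_bytes(1024, 3072)   # [N, 1024, 3072] shape from the PR description

print(f"[256, 384]:   {N_IMAGES * small / 1e9:.2f} GB")
print(f"[1024, 3072]: {N_IMAGES * large / 1e9:.2f} GB")
print(f"per-image ratio: {large // small}x")
```

Under these assumptions the larger shape needs 32 times the memory per image, which accounts for the gap between roughly 2 GB and roughly 65 GB at the same image count.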
