perf(distributed): add retrieval tuning knobs #2452

Open
yuhezhang-ai wants to merge 12 commits into main from yuhez/perf/retrieval-distributed-tuning

@yuhezhang-ai yuhezhang-ai commented Jun 8, 2026

What does this PR do?

Adds distributed tuning and retrieval training fixes used for Nemotron VL retrieval fine-tuning benchmarks. The main goals are to make DDP configurable and faster for retrieval, preserve retrieval optimizer parameter groups, and make loss logging easier to compare with nemo-retriever-research.

Changelog

  • Expose additional DDP config flags:
    • broadcast_buffers
    • find_unused_parameters
    • static_graph
    • bucket_cap_mb
    • gradient_as_bucket_view
  • Forward the new DDP flags into torch.nn.parallel.DistributedDataParallel.
  • Expose FSDP2 reshard_after_forward so no-reshard variants can be configured from YAML.
  • Thread reshard_after_forward through the FSDP2 manager and recursive sharding helper.
  • Speed up DDP gradient clipping by clipping directly on DDP bucket gradients when available.
  • Fix DDP recipe metric logging by reducing metrics across ranks even when no device mesh is configured.
  • Fix retrieval recipe attribute access when the bi-encoder is wrapped by DDP.
  • Preserve retrieval decay/no-decay optimizer parameter groups when constructing typed optimizer configs via build_from_param_groups(...).
  • Add step_scheduler.loss_average_window_steps with default 50.
  • Log retrieval loss_avg_window alongside raw per-step loss, so noisy retrieval loss curves can be compared more easily in W&B or local Slurm logs.
  • Add/update unit coverage for DDP config parsing, DDP manager wiring, FSDP2 reshard override behavior, optimizer param-group construction, retrieval optimizer setup, and averaged retrieval loss logging.
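
As an illustration of the first two changelog items, here is a minimal sketch of how the exposed flags could be collected and forwarded to torch.nn.parallel.DistributedDataParallel. The `DDPOptions` dataclass and the `ddp_kwargs` helper are hypothetical names used only for this example and are not the classes in this PR. The defaults are assumed to mirror PyTorch's own.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class DDPOptions:
    """Hypothetical container for the DDP flags listed in the changelog."""

    broadcast_buffers: bool = True
    find_unused_parameters: bool = False
    static_graph: bool = False
    bucket_cap_mb: Optional[int] = None  # None means "use the PyTorch default"
    gradient_as_bucket_view: bool = False


def ddp_kwargs(options: DDPOptions) -> Dict[str, Any]:
    """Turn the options into keyword arguments for DistributedDataParallel."""
    kwargs = asdict(options)
    # Drop an unset bucket size so PyTorch picks its own default.
    if kwargs["bucket_cap_mb"] is None:
        kwargs.pop("bucket_cap_mb")
    return kwargs


# Intended use inside an initialized process group (not executed here):
#   model = torch.nn.parallel.DistributedDataParallel(
#       model, device_ids=[local_rank], **ddp_kwargs(options)
#   )
```

Keeping the flags in one structure means a YAML section can map onto it field by field, and unset values fall back to the library defaults.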

Notes

  • loss_average_window_steps is intentionally scoped under step_scheduler because it controls training-step logging behavior, similar to log_every_steps and log_remote_every_steps.
  • The averaged loss is logging-only. It does not change the training loss, backward pass, gradient scaling, optimizer step, or scheduler behavior.
  • The optimizer API change keeps the existing one-line simple optimizer construction path intact while adding a path for recipes that need explicit parameter groups.
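
The logging-only windowed average described in the notes above can be sketched as a fixed-size running mean. The `LossWindow` class is a hypothetical name for illustration and is not the implementation in this PR. It assumes the window holds the most recent per-step losses, with the default of 50 steps.

```python
from collections import deque


class LossWindow:
    """Hypothetical running mean over the most recent `window_steps` losses."""

    def __init__(self, window_steps: int = 50) -> None:
        if window_steps < 1:
            raise ValueError("window_steps must be at least 1")
        # deque with maxlen silently discards the oldest value when full.
        self._values = deque(maxlen=window_steps)

    def update(self, loss: float) -> float:
        """Record one per-step loss and return the current windowed mean."""
        self._values.append(float(loss))
        return sum(self._values) / len(self._values)


# Example: log both the raw loss and the windowed average each step.
window = LossWindow(window_steps=50)
for step_loss in (4.0, 2.0, 3.0):
    loss_avg_window = window.update(step_loss)
```

Because the average is computed from detached float values after the step, it cannot affect the backward pass or the optimizer.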

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

  • Related to: Nemotron VL 1B retrieval fine-tuning performance/loss-curve debugging.
  • Branch was rebased onto latest main before the latest update.

Validation on current rebased branch:

  • git diff --check origin/main...HEAD
  • python -m py_compile nemo_automodel/components/optim/optimizer.py nemo_automodel/components/training/step_scheduler.py nemo_automodel/recipes/retrieval/train_bi_encoder.py nemo_automodel/components/distributed/ddp.py nemo_automodel/components/training/utils.py

Earlier validation before the latest rebase:

  • source work/runs/_shared/env.sh && uv run --no-sync ruff check ...
  • source work/runs/_shared/env.sh && uv run --no-sync ruff format --check ...
  • git diff --check
  • source work/runs/_shared/env.sh && uv run --no-sync pytest tests/unit_tests/recipes/test_dist_setup.py tests/unit_tests/distributed/test_parallelizer.py tests/unit_tests/distributed/test_ddp_manager.py -q
  • Result: 126 passed, 17 warnings

Experiment sanity checks:

  • Real-data DDP torch AdamW fp32, weight_decay=0.1, 40-minute run reached step 1763 and ended due to Slurm time limit without a Python exception.
  • Real-data DDP TE FusedAdam bf16, weight_decay=0.1, 40-minute run reached step 2002 and ended due to Slurm time limit without a Python exception.
  • Reconstructed local 50-step averaged loss curves from Slurm logs show both DDP variants descending normally.

Latest Update (2026-06-13)

  • Added configurable retrieval autocast via distributed.autocast_dtype; default remains disabled.
  • Wired top-level compile: config into retrieval bi-encoder model instantiation for DDP compile experiments.

Additional validation:

  • uv run pytest tests/unit_tests/recipes/test_retrieval_bi_encoder_recipe.py tests/unit_tests/recipes/test_dist_utils.py -q (77 passed)
  • uv run ruff check on touched retrieval/distributed files

copy-pr-bot Bot commented Jun 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai yuhezhang-ai force-pushed the yuhez/perf/retrieval-distributed-tuning branch from c564193 to fcfcc72 Compare June 10, 2026 01:30
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai yuhezhang-ai force-pushed the yuhez/perf/retrieval-distributed-tuning branch from fcfcc72 to 3387071 Compare June 10, 2026 01:49
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai

Contributor Author

/ok to test e703720

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai

Contributor Author

/ok to test 3239799

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai

Contributor Author

/ok to test 4d43a66

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai

Contributor Author

/ok to test 55343c9
