[NNX] Fix multi-token-prediction (MTP) eval under pure_nnx #4133

Open
ecnal-cienet wants to merge 1 commit into main from feat/nnx-e2e-fixes

@ecnal-cienet ecnal-cienet commented Jun 10, 2026

Collaborator

Description

Two NNX-only bugs prevent multi-token-prediction (MTP) eval (mtp_num_layers>0 with mtp_eval_target_module) from running under pure_nnx=True. Both already live on main and surface only on the NNX path, since main still defaults to Linen (pure_nnx=False) — invisible by default, but they break the moment NNX is selected. Fixing them now keeps MTP working once the NNX defaults flip lands.

1. Missing logical_axis_rules for concat_embed (and num_activations).
The MTP projection kernel (DenseGeneral, 2*emb_dim → emb_dim) is tagged kernel_axes=("concat_embed", "embed") in multi_token_prediction.py, and the fused gated-MLP wi uses ("embed", "num_activations", "mlp") in linears.py — but neither concat_embed nor num_activations is in logical_axis_rules. On the NNX path, nnx.get_abstract_model (train_utils.py) resolves logical→physical via flax core.spmd.from_sharding_rules, which returns unmapped names verbatim (no replicate fallback), so the literal name reaches NamedSharding and is rejected at state init:

ValueError: Resource axis: concat_embed of P('concat_embed', ('fsdp','context','expert')) is not found in mesh

Linen never hits this because linen.spmd.logical_to_mesh_axes maps unmatched names to None (replicated). Fix: add both as replicated ([]) in base.yml. Note that mapping concat_embed to embed's physical axes would be wrong — they are two axes of the same kernel, so reusing fsdp/context/expert raises DuplicateSpecError.
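The difference between the two resolution paths can be sketched with a small toy model. This is only an illustration under stated assumptions: the helper names (`resolve_replicate_fallback`, `resolve_verbatim`, `check_against_mesh`) are hypothetical and imitate the behaviours described above rather than the flax source, and the duplicate-axis case raises a plain ValueError here instead of DuplicateSpecError.

```python
# Toy model of logical -> physical axis resolution. All helper names are
# hypothetical; they imitate the described behaviours, not the flax source.
MESH_AXES = ("fsdp", "context", "expert")


def resolve_replicate_fallback(logical_axes, rules):
    """Linen-like: a name with no rule (or an empty rule) resolves to None."""
    return tuple(rules.get(name) or None for name in logical_axes)


def resolve_verbatim(logical_axes, rules):
    """NNX-like: a name with no rule is returned unchanged."""
    return tuple(
        (rules[name] or None) if name in rules else name
        for name in logical_axes
    )


def check_against_mesh(spec, mesh_axes=MESH_AXES):
    """Reject names missing from the mesh and mesh axes used more than once."""
    seen = set()
    for entry in spec:
        if entry is None:  # replicated dimension, nothing to check
            continue
        for axis in (entry,) if isinstance(entry, str) else entry:
            if axis not in mesh_axes:
                raise ValueError(f"Resource axis: {axis} is not found in mesh")
            if axis in seen:
                # Stand-in for the DuplicateSpecError mentioned above.
                raise ValueError(f"mesh axis {axis} is used more than once")
            seen.add(axis)
    return spec


KERNEL_AXES = ("concat_embed", "embed")  # the MTP projection kernel
EMBED_RULE = ("fsdp", "context", "expert")

rules_before = {"embed": EMBED_RULE}                             # no concat_embed rule
rules_fixed = {"embed": EMBED_RULE, "concat_embed": ()}          # replicated ([])
rules_wrong = {"embed": EMBED_RULE, "concat_embed": EMBED_RULE}  # reuses embed's axes
```

With `rules_before`, the fallback resolver yields a replicated first dimension, while the verbatim resolver passes `"concat_embed"` through and `check_against_mesh` rejects it. `rules_fixed` passes in both cases, and `rules_wrong` fails because the same mesh axes would shard two dimensions of one kernel.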

2. MTP sown variables leak into the train state and silently zero the MTP loss.
mtp_losses / mtp_acceptance are nnx.Variable subclasses, not nnx.Intermediate, so loss_fn's nnx.pop(model, nnx.Intermediate) never extracts them. Two consequences:

  • They remain on model.mtp_block and leak into the returned train state, so out_shardings (built from the abstract param state) no longer matches at p_train_step.lower(...).compile():
    pytree structure error ... at pjit out_shardings[0]['model']['mtp_block']:
      prefix subtree has 1 child key (mtp_layer_1)
      but the full pytree has 5 (losses, mtp_layer_1, mtp_mask, mtp_preds, weights)
    
  • They are absent from intermediate_outputs, so calculate_mtp_loss reads None and returns 0.0: the MTP loss is silently dropped from training (a correctness bug independent of the crash).

Fix (train.py loss_fn): pop mtp_losses / mtp_acceptance into intermediate_outputs under their collection names so the loss/acceptance extractors find them, and exclude them from the returned state alongside nnx.Intermediate.
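The pop/filter behaviour can be illustrated with a toy model of type-based filtering. This is a sketch, not the train.py change: the classes below are simplified stand-ins for the flax.nnx variable types, `MtpLosses` / `MtpAcceptance` and `pop_by_type` are hypothetical names, and the assignment of the four sown keys to the two collections is an assumption made for illustration.

```python
# Toy stand-ins for the nnx variable types; the real classes carry far more
# machinery. MtpLosses / MtpAcceptance are hypothetical names for the
# Variable subclasses behind the mtp_losses / mtp_acceptance collections.
class Variable:
    def __init__(self, value):
        self.value = value


class Param(Variable):
    pass


class Intermediate(Variable):
    pass


class MtpLosses(Variable):  # sibling of Intermediate, not a subclass of it
    pass


class MtpAcceptance(Variable):
    pass


def pop_by_type(node, var_type):
    """Remove and return the entries of `node` that are instances of var_type."""
    keys = [k for k, v in node.items() if isinstance(v, var_type)]
    return {k: node.pop(k) for k in keys}


def make_mtp_block():
    # Which key belongs to which collection is assumed; values are made up.
    return {
        "mtp_layer_1": Param("layer parameters"),
        "losses": MtpLosses(0.7),
        "weights": MtpLosses(1.0),
        "mtp_preds": MtpAcceptance([3, 1, 4]),
        "mtp_mask": MtpAcceptance([1, 1, 0]),
    }


# Before the fix: filtering on Intermediate alone removes nothing.
block = make_mtp_block()
outputs_before = pop_by_type(block, Intermediate)
keys_before = sorted(block)  # all five keys are still on the block

# After the fix: the MTP variables are popped under their collection names.
block = make_mtp_block()
outputs_after = {
    "intermediates": pop_by_type(block, Intermediate),
    "mtp_losses": pop_by_type(block, MtpLosses),
    "mtp_acceptance": pop_by_type(block, MtpAcceptance),
}
keys_after = sorted(block)  # only the parameter subtree remains
```

In this toy, `keys_before` holds the same five keys reported in the pytree error, while `keys_after` holds only `mtp_layer_1`, matching the prefix that out_shardings expects. The popped values are then reachable under `outputs_after["mtp_losses"]`, which is what lets a loss extractor find them instead of reading None.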

Tests

Validated on CPU. Both errors are structural (sharding resolution and out_shardings pytree shape), so they reproduce without a TPU:

PYTHONPATH=src python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  pure_nnx=True enable_nnx=True pure_nnx_decoder=True \
  model_name=gpt3-52k dataset_type=synthetic steps=6 \
  mtp_num_layers=1 mtp_eval_target_module=1 eval_interval=2 eval_steps=2 \
  attention=dot_product ici_fsdp_parallelism=4 skip_jax_distributed_system=True \
  tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2 \
  base_output_directory=/tmp/mtp run_name=r \
  --xla_force_host_platform_device_count=4

End-to-end TPU run (gpt3-6b, the 27_mtp_eval config) recommended before merge.

Stats

  • Diff: +25 / −4 across 2 files: configs/base.yml (+6, two logical_axis_rules) and trainers/pre_train/train.py (+19 / −4, the loss_fn pop/filter).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions

🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@codecov

codecov Bot commented Jun 10, 2026

Codecov Report

❌ Patch coverage is 40.00000% with 3 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/trainers/pre_train/train.py | 40.00% | 2 Missing and 1 partial ⚠️

@github-actions

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

Two NNX-only bugs blocked MTP eval (mtp_num_layers>0 + mtp_eval_target_module)
under pure_nnx=True. Both exist on main and surface only on the NNX path (main
defaults to Linen).

1. logical_axis_rules: the MTP projection kernel is tagged with logical axis
   'concat_embed' (and the fused gated-MLP wi uses 'num_activations'), but neither
   was in logical_axis_rules. get_abstract_model resolves logical->physical via
   flax's from_sharding_rules, which returns unmapped names verbatim, so the literal
   name reached NamedSharding and was rejected. Linen replicates unmapped axes by
   default. Add both as replicated ([]) in base.yml.

2. loss_fn: mtp_losses/mtp_acceptance are nnx.Variable subclasses, not
   nnx.Intermediate, so nnx.pop(model, nnx.Intermediate) missed them. They leaked
   into the returned train state and broke out_shardings at compile, and were
   absent from intermediate_outputs so calculate_mtp_loss read None and the MTP loss
   was silently dropped. Pop them into intermediate_outputs under their collection
   names and exclude them from the returned state (train.py).

Validated on CPU (gpt3-52k, mtp_num_layers=1 mtp_eval_target_module=1): trains,
mtp_loss is non-zero in the total loss, and mtp_acceptance_rate is reported in eval.