[NNX] Fix multi-token-prediction (MTP) eval under pure_nnx#4133
Open
ecnal-cienet wants to merge 1 commit into
Open
[NNX] Fix multi-token-prediction (MTP) eval under pure_nnx#4133ecnal-cienet wants to merge 1 commit into
ecnal-cienet wants to merge 1 commit into
Conversation
|
🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
2 similar comments
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
bvandermoon
approved these changes
Jun 10, 2026
Two NNX-only bugs blocked MTP eval (mtp_num_layers>0 + mtp_eval_target_module) under pure_nnx=True. Both exist on main and surface only on the NNX path (main defaults to Linen). 1. logical_axis_rules: the MTP projection kernel is tagged with logical axis 'concat_embed' (and the fused gated-MLP wi uses 'num_activations'), but neither was in logical_axis_rules. get_abstract_model resolves logical->physical via flax's from_sharding_rules, which returns unmapped names verbatim, so the literal name reached NamedSharding and was rejected. Linen replicates unmapped axes by default. Add both as replicated ([]) (base.yml). 2. loss_fn: mtp_losses/mtp_acceptance are nnx.Variable subclasses, not nnx.Intermediate, so nnx.pop(model, nnx.Intermediate) missed them. They leaked into the returned train state and broke out_shardings at compile, and were absent from intermediate_outputs so calculate_mtp_loss read None and the MTP loss was silently dropped. Pop them into intermediate_outputs under their collection names and exclude them from the returned state (train.py). Validated on CPU (gpt3-52k, mtp_num_layers=1 mtp_eval_target_module=1): trains, mtp_loss is non-zero in the total loss, and mtp_acceptance_rate is reported in eval.
880e677 to
3bf55a5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Two NNX-only bugs prevent multi-token-prediction (MTP) eval (
mtp_num_layers>0withmtp_eval_target_module) from running underpure_nnx=True. Both already live onmainand surface only on the NNX path, sincemainstill defaults to Linen (pure_nnx=False) — invisible by default, but they break the moment NNX is selected. Fixing them now keeps MTP working once the NNX defaults flip lands.1. Missing
logical_axis_rulesforconcat_embed(andnum_activations).The MTP projection kernel (
DenseGeneral,2*emb_dim → emb_dim) is taggedkernel_axes=("concat_embed", "embed")inmulti_token_prediction.py, and the fused gated-MLPwiuses("embed", "num_activations", "mlp")inlinears.py— but neitherconcat_embednornum_activationsis inlogical_axis_rules. On the NNX path,nnx.get_abstract_model(train_utils.py) resolves logical→physical via flaxcore.spmd.from_sharding_rules, which returns unmapped names verbatim (no replicate fallback), so the literal name reachesNamedShardingand is rejected at state init:Linen never hits this because
linen.spmd.logical_to_mesh_axesmaps unmatched names toNone(replicated). Fix: add both as replicated ([]) inbase.yml. Note that mappingconcat_embedtoembed's physical axes would be wrong — they are two axes of the same kernel, so reusingfsdp/context/expertraisesDuplicateSpecError.2. MTP sown variables leak into the train state and silently zero the MTP loss.
mtp_losses/mtp_acceptancearennx.Variablesubclasses, notnnx.Intermediate, soloss_fn'snnx.pop(model, nnx.Intermediate)never extracts them. Two consequences:model.mtp_blockand leak into the returned train state, soout_shardings(built from the abstract param state) no longer matches atp_train_step.lower(...).compile():intermediate_outputs, socalculate_mtp_lossreadsNoneand returns0.0— the MTP loss is silently dropped from training (a correctness bug independent of the crash).Fix (
train.pyloss_fn): popmtp_losses/mtp_acceptanceintointermediate_outputsunder their collection names so the loss/acceptance extractors find them, and exclude them from the returned state alongsidennx.Intermediate.Tests
Validated on TPU — both errors are structural (sharding resolution and
out_shardingspytree shape), so they reproduce without a TPU:End-to-end TPU run (
gpt3-6b, the27_mtp_evalconfig) recommended before merge.Stats
configs/base.yml(+6, twological_axis_rules) andtrainers/pre_train/train.py(+19 / −4, theloss_fnpop/filter).Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.