Fix rank mismatch in MaxText synthetic data sharding #4122

Open
lukebaumann wants to merge 2 commits into AI-Hypercomputer:main from lukebaumann:fix-synthetic-sharding

lukebaumann commented Jun 9, 2026

Collaborator

Description

This PR fixes a rank mismatch issue in MaxText synthetic data sharding during data loading.

Root Cause

SyntheticDataIterator was using the legacy config.data_sharding which resolved to a 1D sharding spec P(('data', 'fsdp')) (after filtering). When applied to 2D output tensors of shape (batch, seq), JAX sharding validation failed with AssertionError: (1, 2) (rank mismatch) on JAX builds that strictly enforce this check.
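
The mismatch can be illustrated without JAX. The following is a minimal sketch, assuming the strict check compares the number of entries in the sharding spec against the rank of the array. `check_spec_rank` is a hypothetical helper written for this illustration, not JAX's actual validation code, and plain tuples stand in for `PartitionSpec` objects.

```python
def check_spec_rank(spec, shape):
    """Hypothetical strict check: one spec entry per array dimension."""
    assert len(spec) == len(shape), (len(spec), len(shape))


# Plain tuples standing in for PartitionSpec objects (an assumption of this sketch).
legacy_spec = (("data", "fsdp"),)       # 1 entry, like P(('data', 'fsdp'))
fixed_spec = (("data", "fsdp"), None)   # 2 entries, like P(('data', 'fsdp'), None)
batch_shape = (8, 16)                   # (batch, seq)

# The 2D spec matches the 2D tensor, so the check passes silently.
check_spec_rank(fixed_spec, batch_shape)

# The 1D spec has one entry for a rank-2 tensor, so the check fails.
try:
    check_spec_rank(legacy_spec, batch_shape)
    legacy_error = None
except AssertionError as err:
    legacy_error = err.args[0]

print(legacy_error)
```

Under this assumed check, the failing case carries the pair (spec rank, tensor rank), which has the same shape as the `(1, 2)` reported in the error above.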

Solution

Modified SyntheticDataIterator to use sharding.get_input_data_sharding(config, mesh). This helper uses config.input_data_sharding_logical_axes which correctly resolves to a 2D sharding spec P(('data', 'fsdp'), None), matching the rank of the output tensors.

Also removed the unused PartitionSpec as P import in synthetic_data_processing.py.

Tests

Added a new unit test, tests/unit/synthetic_data_test.py, parameterized over both the auto and explicit values of shard_mode. The test:

  1. Forces 4 CPU devices.
  2. Creates a 2x2 mesh.
  3. Initializes SyntheticDataIterator with llama3.1-8b config.
  4. Verifies the output shape is (8, 16) and sharding is exactly P(('data', 'fsdp'), None).
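
The steps above can be approximated in plain JAX. This is a rough sketch, not the test added in this PR: it does not import MaxText, load the llama3.1-8b config, or exercise shard_mode, and a zero-filled array stands in for one tensor produced by SyntheticDataIterator. The ordering of the mesh axes as ('data', 'fsdp') is also an assumption.

```python
import os

# Must be set before JAX initializes its backends, so it precedes the jax import.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 4 forced CPU devices into a 2x2 mesh.
devices = np.array(jax.devices("cpu")[:4]).reshape(2, 2)
mesh = Mesh(devices, ("data", "fsdp"))

# 2D spec: batch dimension split over both mesh axes, sequence dimension replicated.
expected_spec = P(("data", "fsdp"), None)
sharding = NamedSharding(mesh, expected_spec)

# Stand-in for one (batch, seq) tensor from the iterator (assumption of this sketch).
batch = jax.device_put(jnp.zeros((8, 16), dtype=jnp.int32), sharding)

assert batch.shape == (8, 16)
assert batch.sharding.spec == expected_spec
# Each of the 4 devices holds 8 / 4 = 2 rows and the full sequence length.
assert sharding.shard_shape(batch.shape) == (2, 16)
```

Because the spec has one entry per tensor dimension, its rank matches the rank of the (8, 16) output.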

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Comment thread src/maxtext/input_pipeline/synthetic_data_processing.py
Comment thread tests/unit/synthetic_data_test.py
* Change SyntheticDataIterator to use get_input_data_sharding instead of manual 1D sharding.
* This ensures the sharding spec is 2D, matching the rank of the output tensors.
* Fixes AssertionError: (1, 2) in JAX sharding validation on some JAX builds.
* Remove unused PartitionSpec import in synthetic_data_processing.py.
* Add parameterized unit test `tests/unit/synthetic_data_test.py` to verify synthetic data sharding works for both 'auto' and 'explicit' shard modes.
lukebaumann force-pushed the fix-synthetic-sharding branch from 91b79f9 to 2fea2ba on June 10, 2026 20:54