fix(jax): avoid device lookup during type embedding tracing#5736
Conversation
📝 WalkthroughWalkthroughAdds a ChangesDevice Inference Fallback
Estimated code review effort: 1 (Trivial) | ~5 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/tests/jax/test_type_embed.py (1)
30-39: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winGrad test only checks for non-
None, not shape/values.
self.assertIsNotNone(grad(type_embedding))only confirms the traced grad call doesn't crash; it doesn't verify the gradient has the expected structure or is free of NaNs, unlike theforwardoutput check. Consider tightening this to mirror theforwardassertions (e.g., checking gradient leaf shapes/no-NaN) for a more meaningful regression guard.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_type_embed.py` around lines 30 - 39, The grad test in the type embedding test only asserts the result is non-None, so it does not verify the gradient’s structure or numerical validity. Tighten the `grad(model)` assertion to inspect the returned gradient from `nnx.grad(loss)(model)` more thoroughly, using the same `type_embedding` setup to check expected leaf shapes and that the values are finite/no NaNs, so the test actually guards the gradient behavior rather than just successful execution.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/tests/jax/test_type_embed.py`:
- Around line 30-39: The grad test in the type embedding test only asserts the
result is non-None, so it does not verify the gradient’s structure or numerical
validity. Tighten the `grad(model)` assertion to inspect the returned gradient
from `nnx.grad(loss)(model)` more thoroughly, using the same `type_embedding`
setup to check expected leaf shapes and that the values are finite/no NaNs, so
the test actually guards the gradient behavior rather than just successful
execution.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: d196f460-f0bd-40c7-8a4c-e161d1c856cd
📒 Files selected for processing (2)
deepmd/dpmodel/utils/type_embed.pysource/tests/jax/test_type_embed.py
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5736 +/- ##
==========================================
- Coverage 81.29% 81.16% -0.13%
==========================================
Files 990 990
Lines 111020 111022 +2
Branches 4232 4234 +2
==========================================
- Hits 90252 90110 -142
- Misses 19242 19388 +146
+ Partials 1526 1524 -2 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
Fix a JAX/Flax NNX tracing failure in
TypeEmbedNet.call()by avoiding a hard requirement that traced values expose a readable.deviceattribute.Background
JAX training can call
TypeEmbedNet.call()from insidennx.jit/nnx.grad. In that context the type-embedding weights are represented by NNX/JAX traced values such asDynamicJaxprTracer. The previous code passedto
xp.eye, and did the same for the paddingxp.zeros. For an NNX traced parameter,array_api_compat.device(...)eventually tries to read.device; the tracer does not provide that attribute during tracing, so the training step fails with an error like:or, through JAX core wrapping:
Change
Add a small local helper in
deepmd/dpmodel/utils/type_embed.pythat returnsarray_api_compat.device(array)when available, and falls back toNonewhen the backend value has no readable device during tracing. Passingdevice=Nonelets JAX create the constants on its default traced device, while preserving the existing explicit-device behavior for backends that expose it normally.This keeps the change limited to the type-embedding constants that triggered the crash, rather than changing global array API device handling.
Tests
Added
source/tests/jax/test_type_embed.pyto cover both:TypeEmbedNet.call()undernnx.jitTypeEmbedNet.call()undernnx.jit+nnx.gradThe old implementation fails this test during tracing before producing an output.
Validation run locally:
Summary by CodeRabbit
Bug Fixes
Tests