Skip to content

fix(jax): avoid device lookup during type embedding tracing#5736

Open
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:fix/jax-type-embed-tracing
Open

fix(jax): avoid device lookup during type embedding tracing#5736
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:fix/jax-type-embed-tracing

Conversation

@njzjz

@njzjz njzjz commented Jul 5, 2026

Copy link
Copy Markdown
Member

Summary

Fix a JAX/Flax NNX tracing failure in TypeEmbedNet.call() by avoiding a hard requirement that traced values expose a readable .device attribute.

Background

JAX training can call TypeEmbedNet.call() from inside nnx.jit / nnx.grad. In that context the type-embedding weights are represented by NNX/JAX traced values such as DynamicJaxprTracer. The previous code passed

device=array_api_compat.device(sample_array)

to xp.eye, and did the same for the padding xp.zeros. For an NNX traced parameter, array_api_compat.device(...) eventually tries to read .device; the tracer does not provide that attribute during tracing, so the training step fails with an error like:

AttributeError: DynamicJaxprTracer has no attribute device

or, through JAX core wrapping:

AttributeError: 'ShapedArray' object has no attribute 'device'

Change

Add a small local helper in deepmd/dpmodel/utils/type_embed.py that returns array_api_compat.device(array) when available, and falls back to None when the backend value has no readable device during tracing. Passing device=None lets JAX create the constants on its default traced device, while preserving the existing explicit-device behavior for backends that expose it normally.

This keeps the change limited to the type-embedding constants that triggered the crash, rather than changing global array API device handling.

Tests

Added source/tests/jax/test_type_embed.py to cover both:

  • TypeEmbedNet.call() under nnx.jit
  • TypeEmbedNet.call() under nnx.jit + nnx.grad

The old implementation fails this test during tracing before producing an output.

Validation run locally:

pytest source/tests/jax/test_type_embed.py -v
ruff check .
ruff format .

Summary by CodeRabbit

  • Bug Fixes

    • Improved compatibility for model execution on array backends that do not support device detection in all cases.
    • Reduced failures when padding or initializing arrays during type embedding operations.
  • Tests

    • Added JAX coverage to verify the type embedding model runs under JIT, supports gradient tracing, returns the expected output shape, and produces valid numeric results.

Copilot AI review requested due to automatic review settings July 5, 2026 13:12

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

@dosubot dosubot Bot added the bug label Jul 5, 2026
@github-actions github-actions Bot added the Python label Jul 5, 2026
@coderabbitai

coderabbitai Bot commented Jul 5, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Adds a _array_device_or_none helper in deepmd/dpmodel/utils/type_embed.py that safely derives an array's device, falling back to None on AttributeError, and uses it in TypeEmbedNet.call()'s padding logic. Adds a new JAX test verifying call() works under nnx.jit/nnx.grad.

Changes

Device Inference Fallback

Layer / File(s) Summary
Device fallback helper and usage
deepmd/dpmodel/utils/type_embed.py
Adds _array_device_or_none() to safely derive array device, falling back to None on AttributeError; used in TypeEmbedNet.call() for xp.eye and xp.zeros device arguments.
JAX jit/grad compatibility test
source/tests/jax/test_type_embed.py
New test builds TypeEmbedNet, runs call() under nnx.jit, traces gradients via nnx.grad, and asserts output shape, non-None gradients, and no NaNs.

Estimated code review effort: 1 (Trivial) | ~5 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the JAX tracing fix in TypeEmbedNet and is concise and specific.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/jax/test_type_embed.py (1)

30-39: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Grad test only checks for non-None, not shape/values.

self.assertIsNotNone(grad(type_embedding)) only confirms the traced grad call doesn't crash; it doesn't verify the gradient has the expected structure or is free of NaNs, unlike the forward output check. Consider tightening this to mirror the forward assertions (e.g., checking gradient leaf shapes/no-NaN) for a more meaningful regression guard.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/jax/test_type_embed.py` around lines 30 - 39, The grad test in
the type embedding test only asserts the result is non-None, so it does not
verify the gradient’s structure or numerical validity. Tighten the `grad(model)`
assertion to inspect the returned gradient from `nnx.grad(loss)(model)` more
thoroughly, using the same `type_embedding` setup to check expected leaf shapes
and that the values are finite/no NaNs, so the test actually guards the gradient
behavior rather than just successful execution.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/jax/test_type_embed.py`:
- Around line 30-39: The grad test in the type embedding test only asserts the
result is non-None, so it does not verify the gradient’s structure or numerical
validity. Tighten the `grad(model)` assertion to inspect the returned gradient
from `nnx.grad(loss)(model)` more thoroughly, using the same `type_embedding`
setup to check expected leaf shapes and that the values are finite/no NaNs, so
the test actually guards the gradient behavior rather than just successful
execution.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d196f460-f0bd-40c7-8a4c-e161d1c856cd

📥 Commits

Reviewing files that changed from the base of the PR and between ffe57a3 and 219c168.

📒 Files selected for processing (2)
  • deepmd/dpmodel/utils/type_embed.py
  • source/tests/jax/test_type_embed.py

@njzjz njzjz requested a review from wanghan-iapcm July 5, 2026 13:17
@codecov

codecov Bot commented Jul 5, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.16%. Comparing base (ffe57a3) to head (219c168).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5736      +/-   ##
==========================================
- Coverage   81.29%   81.16%   -0.13%     
==========================================
  Files         990      990              
  Lines      111020   111022       +2     
  Branches     4232     4234       +2     
==========================================
- Hits        90252    90110     -142     
- Misses      19242    19388     +146     
+ Partials     1526     1524       -2     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants