
fix(jax): support charge-spin savedmodel export #5737

Open
njzjz wants to merge 2 commits into
deepmodeling:master from
njzjz:fix/jax-charge-spin-savedmodel

Conversation

@njzjz

@njzjz njzjz commented Jul 5, 2026

Member

Summary

  • export charge_spin-aware JAX SavedModel signatures only for models that need charge/spin conditioning
  • add Python and C++ JAX runtime fallback to stored default_chg_spin, while keeping old SavedModel signatures compatible
  • allow C++ header wrapper to broadcast one-frame charge_spin input safely and add focused jax2tf regressions
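
The fallback and broadcast rules in the summary above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the code from this PR: `make_charge_spin_input` here is a hypothetical stand-alone helper, it uses nested lists instead of tensors, and it assumes the model supports charge/spin conditioning (`dim_chg_spin > 0`).

```python
from typing import Optional, Sequence


def make_charge_spin_input(
    charge_spin: Optional[Sequence[Sequence[float]]],
    nframes: int,
    dim_chg_spin: int,
    default_chg_spin: Optional[Sequence[float]] = None,
) -> list[list[float]]:
    """Return one charge/spin row per frame (hypothetical helper, not the PR's code)."""
    if charge_spin is None:
        # Fallback: use the default stored with the model, if there is one.
        if default_chg_spin is None:
            raise ValueError("charge_spin was omitted and the model stores no default")
        rows = [list(default_chg_spin)]
    else:
        rows = [list(row) for row in charge_spin]

    # Every row must carry exactly dim_chg_spin values.
    for row in rows:
        if len(row) != dim_chg_spin:
            raise ValueError(
                f"expected {dim_chg_spin} values per frame, got {len(row)}"
            )

    if len(rows) == 1:
        # Broadcast a single row to every frame.
        return [list(rows[0]) for _ in range(nframes)]
    if len(rows) != nframes:
        raise ValueError(f"expected 1 or {nframes} rows, got {len(rows)}")
    return rows
```

Calling `make_charge_spin_input(None, 2, 2, default_chg_spin=[0.0, 1.0])` returns the default repeated for both frames, while a single explicit row is tiled the same way.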

Tests

  • ruff format .
  • ruff check .
  • pytest source/jax2tf_tests/test_make_model.py source/jax2tf_tests/test_serialization.py source/jax2tf_tests/test_tfmodel.py -q
  • c++ -std=c++17 -Isource/api_c/include -fsyntax-only -x c++ ...
  • c++ -std=c++17 -fsyntax-only ... source/api_cc/src/DeepPotJAX.cc

Notes

  • cmake --build source/build --target deepmd_cc -j2 still fails before compiling this change because local CMake TensorFlow detection stops at cmake/Findtensorflow.cmake:234 (file): file unknown error.
  • Review requested from @anyangml for charge_spin C++ runtime context and @iProzd for charge_spin/DPA3 behavior.

Summary by CodeRabbit

  • New Features
    • Added optional charge/spin conditioning to JAX, TensorFlow, and C++ model interfaces.
    • Models can expose charge/spin embedding support and default values; charge/spin can be supplied per frame or broadcast from a single vector.
  • Bug Fixes
    • Improved behavior when charge/spin is omitted by using model defaults when available.
    • Added stricter validation and clearer errors for unsupported charge/spin usage and invalid shape/dimension mismatches.
  • Tests
    • Added/updated coverage for charge/spin forwarding, SavedModel serialization, and broadcasting/validation.

@njzjz njzjz requested review from anyangml, Copilot and iProzd July 5, 2026 14:48

Copilot AI left a comment

Contributor

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

@coderabbitai

coderabbitai Bot commented Jul 5, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 404432ff-0fd5-4970-8e8d-1dcaec285c49

📥 Commits

Reviewing files that changed from the base of the PR and between 688879e and c3ea9ba.

📒 Files selected for processing (8)
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/jax2tf/make_model.py
  • deepmd/jax/jax2tf/tfmodel.py
  • source/api_c/include/deepmd.hpp
  • source/api_cc/src/DeepPotJAX.cc
  • source/jax2tf_tests/test_deep_eval.py
  • source/jax2tf_tests/test_serialization.py
  • source/jax2tf_tests/test_tfmodel.py
💤 Files with no reviewable changes (1)
  • source/api_c/include/deepmd.hpp
🚧 Files skipped from review as they are similar to previous changes (4)
  • source/jax2tf_tests/test_tfmodel.py
  • deepmd/jax/infer/deep_eval.py
  • source/api_cc/src/DeepPotJAX.cc
  • deepmd/jax/jax2tf/tfmodel.py

📝 Walkthrough

Walkthrough

This PR adds optional charge_spin support across JAX inference, JAX-to-TF export and wrapper code, and the C++ DeepPot/DeepPotJAX APIs, including capability checks, default handling, tiling/broadcasting, and tests.

Changes

Charge-spin propagation

Layer | File(s) | Summary
DeepEval charge_spin support | deepmd/jax/infer/deep_eval.py, source/jax2tf_tests/test_deep_eval.py | eval and _eval_model accept optional charge_spin, forward it when supported, reject unsupported conditioning, and add capability helper methods.
JAX2TF export plumbing | deepmd/jax/jax2tf/make_model.py, deepmd/jax/jax2tf/serialization.py, source/jax2tf_tests/test_make_model.py | model_call_from_call_lower and SavedModel export wiring conditionally include charge_spin in signatures, dispatch, and exported call functions, with tests covering forwarding.
TFModelWrapper charge_spin handling | deepmd/jax/jax2tf/tfmodel.py, source/jax2tf_tests/test_tfmodel.py, source/jax2tf_tests/test_serialization.py | TFModelWrapper loads charge-spin metadata, normalizes and validates charge_spin, and exposes accessors; tests cover defaults, broadcast, invalid shapes, and SavedModel loading.
C API charge_spin validation | source/api_c/include/deepmd.hpp | validate_charge_spin gains a tiled scratch vector, and DeepPot/DeepPotModelDevi compute call sites pass it through in all variants.
DeepPotJAX charge_spin integration | source/api_cc/include/DeepPotJAX.h, source/api_cc/src/DeepPotJAX.cc | DeepPotJAX adds charge-spin accessors and signatures, plus input construction, capability loading, TensorFlow input wiring, and computew forwarding.

Estimated code review effort: 4 (Complex) | ~60 minutes

Possibly related PRs

Suggested labels: bug

Suggested reviewers: wanghan-iapcm, iProzd

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title clearly summarizes the main change: adding charge-spin support to JAX SavedModel export.
Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
source/api_c/include/deepmd.hpp (1)

2326-2387: 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Stale comment contradicts the new charge_spin threading logic.

Line 2337/2427's // charge_spin is not supported via the C-API model-deviation path. directly contradicts the code right below it (and the @param[in] charge_spin doc block at 2321-2324), which now builds charge_spin_tiled_ and forwards charge_spin__ into _DP_DeepPotModelDeviCompute. Either the comment is stale and should be removed, or _DP_DeepPotModelDeviCompute genuinely ignores this parameter for model-deviation (in which case the added tiling/validation here is dead work). Please clarify and fix the comment/doc mismatch.

📝 Suggested comment fix (if the C-API does honor charge_spin here)
-    // charge_spin is not supported via the C-API model-deviation path.
     unsigned int natoms = atype.size();

Also applies to: 2414-2491

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/api_c/include/deepmd.hpp` around lines 2326 - 2387, The charge_spin
handling in compute and the related model-deviation path is inconsistent: the
comment says it is unsupported, but the code now validates, tiles, and forwards
charge_spin__ into _DP_DeepPotModelDeviCompute. Update the stale comment and the
`@param`[in] charge_spin documentation to match the actual behavior, or remove the
new charge_spin_tiled_ / validate_charge_spin wiring in compute if the C-API
path truly ignores it; use the compute method and _DP_DeepPotModelDeviCompute as
the reference points.
🧹 Nitpick comments (6)
deepmd/jax/jax2tf/make_model.py (1)

34-44: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Stale call_lower Callable type hint doesn't reflect the new charge_spin argument.

The call_lower parameter's Callable[[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, bool], dict[...]] signature isn't updated even though the function now conditionally invokes call_lower with an extra charge_spin keyword (Lines 87-94), and downstream call_lower implementations (e.g., source/jax2tf_tests/test_make_model.py's call_lower, serialization.py's dispatch wrappers) explicitly accept a 7th charge_spin: tf.Tensor parameter. Update the type hint so static type checkers reflect the real accepted signature.

Also applies to: 54-54

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/jax2tf/make_model.py` around lines 34 - 44, The call_lower type
hint is stale and still describes the old 6-argument signature, while make_model
now passes charge_spin and downstream call_lower implementations accept it.
Update the Callable annotation in make_model to include the extra charge_spin
tf.Tensor parameter so the signature matches the actual invocation in the
call_lower path and related wrappers/tests.
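
One possible way to express the two accepted signatures is sketched below. It is only an illustration: `Tensor` is a stand-in for `tf.Tensor` so the snippet runs without TensorFlow, the alias names are hypothetical, and because the review elides the return type as `dict[...]`, `dict[str, Tensor]` is a placeholder rather than the real annotation.

```python
from typing import Any, Callable, Union, get_args

# Stand-in for tf.Tensor so the sketch runs without TensorFlow installed.
Tensor = Any

# Old hint: five tensors plus the boolean flag (six positional parameters).
CallLowerBase = Callable[
    [Tensor, Tensor, Tensor, Tensor, Tensor, bool],
    dict[str, Tensor],  # placeholder return type; elided in the review
]

# Extended hint: the same six parameters plus a seventh charge_spin tensor.
CallLowerChargeSpin = Callable[
    [Tensor, Tensor, Tensor, Tensor, Tensor, bool, Tensor],
    dict[str, Tensor],
]

# A parameter annotated with this union accepts either form.
CallLower = Union[CallLowerBase, CallLowerChargeSpin]


def positional_arity(hint: Any) -> int:
    """Count the positional parameters declared by a Callable[[...], R] hint."""
    return len(get_args(hint)[0])
```

A plain `Callable` cannot describe a keyword-only `charge_spin`; if the argument must be passed by keyword, a `typing.Protocol` with an explicit `__call__` signature would be the closer fit.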
deepmd/jax/infer/deep_eval.py (1)

252-262: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Document the new charge_spin kwarg.

charge_spin is now a meaningful, extracted keyword (Line 262) but the docstring's **kwargs section still just says "Other parameters" without mentioning it. Since this is a new public capability, document its expected shape/semantics.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/infer/deep_eval.py` around lines 252 - 262, The docstring for the
evaluation method currently leaves the extracted charge_spin kwarg undocumented,
so update the **kwargs section to explicitly describe charge_spin alongside the
other parameters. Use the method in deep_eval.py where charge_spin is popped
from kwargs to locate the spot, and add its expected shape and semantics so
callers know how to pass it and what it represents.
deepmd/jax/jax2tf/tfmodel.py (1)

112-134: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Docstrings omit new charge_spin parameter.

__call__ and call both gained a charge_spin argument, but their docstring Parameters sections still only document do_atomic_virial and earlier params.

📝 Proposed docstring addition
         do_atomic_virial
             If calculate the atomic virial.
+        charge_spin
+            The charge and spin conditioning input. shape: nf x dim_chg_spin
 
         Returns

Also applies to: 157-181

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/jax2tf/tfmodel.py` around lines 112 - 134, The docstrings for
TFModel.__call__ and call are missing the new charge_spin parameter, so update
each Parameters section to document charge_spin alongside the existing coord,
atype, box, fparam, aparam, and do_atomic_virial entries. Keep the wording
consistent with the method signatures in tfmodel.py and ensure the new argument
is described wherever these methods are documented.
source/jax2tf_tests/test_tfmodel.py (1)

1-47: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add edge-case tests for invalid charge_spin shapes.

Current tests cover the default, single-frame-broadcast, and missing-value branches of _make_charge_spin_input, but not the shape-validation error paths (wrong last dimension, ndim > 2, or a 2D input whose first dimension matches neither 1 nor nframes). These are distinct raise ValueError branches in tfmodel.py (lines 416-420, 423-424) that would otherwise be silently unverified.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/jax2tf_tests/test_tfmodel.py` around lines 1 - 47, Add negative tests
for TFModelWrapper._make_charge_spin_input to cover the shape-validation
branches that are currently untested. In source/jax2tf_tests/test_tfmodel.py,
extend the existing _make_wrapper-based tests to assert ValueError for an input
with the wrong last dimension, for an input with ndim > 2, and for a 2D input
whose first dimension is neither 1 nor nframes. Use the _make_charge_spin_input
method name and the existing charge_spin handling setup to target the distinct
validation paths in tfmodel.py.
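
The three rejection cases named above could be covered with a table-driven check along these lines. This is a sketch under assumptions: `check_charge_spin_shape` is a hypothetical stand-in that works on shape tuples, not `TFModelWrapper._make_charge_spin_input` itself, and it assumes a 1D vector of length `dim_chg_spin` is a valid broadcast input.

```python
def check_charge_spin_shape(
    shape: tuple[int, ...], nframes: int, dim_chg_spin: int
) -> None:
    """Raise ValueError for shapes the wrapper is expected to reject (hypothetical)."""
    if len(shape) not in (1, 2):
        raise ValueError(f"charge_spin must be 1D or 2D, got ndim={len(shape)}")
    if shape[-1] != dim_chg_spin:
        raise ValueError(
            f"last dimension must be {dim_chg_spin}, got {shape[-1]}"
        )
    if len(shape) == 2 and shape[0] not in (1, nframes):
        raise ValueError(
            f"first dimension must be 1 or {nframes}, got {shape[0]}"
        )


def is_rejected(
    shape: tuple[int, ...], nframes: int = 3, dim_chg_spin: int = 2
) -> bool:
    """Return True when the shape check raises ValueError."""
    try:
        check_charge_spin_shape(shape, nframes, dim_chg_spin)
    except ValueError:
        return True
    return False


# One entry per validation branch, with nframes=3 and dim_chg_spin=2.
INVALID_SHAPES = {
    "wrong last dimension": (3, 5),
    "ndim > 2": (3, 1, 2),
    "first dimension neither 1 nor nframes": (2, 2),
}

# Shapes that should pass: a single vector, one row, and one row per frame.
VALID_SHAPES = [(2,), (1, 2), (3, 2)]
```

In a real test module each entry of `INVALID_SHAPES` would become one `pytest.raises(ValueError)` case against the wrapper method.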
source/jax2tf_tests/test_serialization.py (1)

40-113: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy lift

No test exercises the has_chg_spin=True SavedModel export path.

DummyModel reports get_dim_chg_spin() == 0, so deserialize_to_file's entire charge_spin branch (extended TensorSpecs, extended polymorphic_shapes, tf.cond dispatch with charge_spin, and the charge_spin-aware exported call/call_lower wrappers) is never traced by test_savedmodel_export_contains_xla_call_module. Other new tests (test_make_model.py, test_tfmodel.py) only cover lower-level helpers, not the actual SavedModel export/dispatch logic.

Consider adding a variant of DummyModel with get_dim_chg_spin() > 0 (and a call_common_lower accepting charge_spin) to exercise the export path end-to-end, similar to how the base case validates XlaCallModule presence.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/jax2tf_tests/test_serialization.py` around lines 40 - 113, The
SavedModel export tests do not cover the `has_chg_spin=True` path, so the
charge-spin-specific export/dispatch logic in `deserialize_to_file` is untested.
Add a `DummyModel` variant that reports `get_dim_chg_spin() > 0` and whose
`call_common_lower` accepts `charge_spin`, then extend
`test_savedmodel_export_contains_xla_call_module` (or a sibling test) to export
and load it end-to-end, verifying the charge-spin-aware `call`/`call_lower`
wrappers and `tf.cond` branch are exercised alongside the existing base case.
source/api_c/include/deepmd.hpp (1)

863-929: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Remove the unused validate_charge_spin overload
The 3-arg overload is no longer called anywhere, so keeping it alongside the tiling version leaves two different validation rules under the same name. Removing it would avoid future confusion.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/api_c/include/deepmd.hpp` around lines 863 - 929, Remove the unused
3-argument validate_charge_spin overload from deepmd::hpp in deepmd.hpp so only
the tiling-aware version remains. Keep the existing 4-argument
validate_charge_spin as the single source of validation logic, and make sure any
call sites still resolve to that overload so there are no duplicate validation
rules under the same name.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/infer/deep_eval.py`:
- Around line 399-420: The charge_spin argument is being silently ignored in the
JAX inference path when the model does not support charge-spin embeddings.
Update the logic in deep_eval.py around the charge_spin_input/model_kwargs setup
in the inference method to detect when charge_spin is provided but
self.has_chg_spin_ebd() is false, and fail loudly with a ValueError or
RuntimeError (or at minimum emit a warning) instead of dropping the input. Keep
the existing to_jax_array/model_kwargs flow for supported models, but gate it
with an explicit validation using the same charge_spin_input and
has_chg_spin_ebd symbols.

In `@source/api_cc/src/DeepPotJAX.cc`:
- Around line 463-499: The helper make_charge_spin_input currently returns an
empty vector when dchgspin is 0, which silently discards non-empty charge_spin
input instead of rejecting it. Update this function to mirror
deepmd::hpp::validate_charge_spin by checking for non-empty charge_spin or
default_chg_spin before the early return, and throw a deepmd_exception when
charge/spin data is provided but the model does not support it. Keep the
existing size/shape handling in make_charge_spin_input for the supported
dchgspin cases.
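
Both inline comments ask for the same rule: an explicit charge_spin must not be dropped silently when the model cannot use it. A minimal Python sketch of that gate follows; `charge_spin_kwargs` is a hypothetical helper name, and the boolean argument stands in for whatever `has_chg_spin_ebd()` returns on the real model.

```python
from typing import Any, Optional


def charge_spin_kwargs(
    charge_spin: Optional[Any], has_chg_spin_ebd: bool
) -> dict[str, Any]:
    """Build the extra model kwargs, failing loudly on unsupported input (hypothetical)."""
    if charge_spin is None:
        # Nothing was requested, so the call stays compatible with older models.
        return {}
    if not has_chg_spin_ebd:
        # Reject instead of silently ignoring the caller's input.
        raise ValueError(
            "charge_spin was provided, but this model has no charge/spin embedding"
        )
    return {"charge_spin": charge_spin}
```

The result can be splatted into the model call (`model(..., **charge_spin_kwargs(cs, supported))`), so models without charge/spin support never see the extra argument.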

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5b5209f6-f6de-4d2c-9dc3-aa8cf95cab72

📥 Commits

Reviewing files that changed from the base of the PR and between ffe57a3 and 688879e.

📒 Files selected for processing (10)
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/jax2tf/make_model.py
  • deepmd/jax/jax2tf/serialization.py
  • deepmd/jax/jax2tf/tfmodel.py
  • source/api_c/include/deepmd.hpp
  • source/api_cc/include/DeepPotJAX.h
  • source/api_cc/src/DeepPotJAX.cc
  • source/jax2tf_tests/test_make_model.py
  • source/jax2tf_tests/test_serialization.py
  • source/jax2tf_tests/test_tfmodel.py

Comment thread deepmd/jax/infer/deep_eval.py
Comment thread source/api_cc/src/DeepPotJAX.cc
@codecov

codecov Bot commented Jul 5, 2026


Codecov Report

❌ Patch coverage is 56.66667% with 104 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.10%. Comparing base (ffe57a3) to head (688879e).

Files with missing lines | Patch % | Lines
source/api_cc/src/DeepPotJAX.cc | 36.92% | 38 Missing and 3 partials ⚠️
deepmd/jax/jax2tf/serialization.py | 62.66% | 28 Missing ⚠️
deepmd/jax/jax2tf/tfmodel.py | 68.88% | 14 Missing ⚠️
source/api_c/include/deepmd.hpp | 66.66% | 8 Missing and 3 partials ⚠️
deepmd/jax/infer/deep_eval.py | 58.82% | 7 Missing ⚠️
deepmd/jax/jax2tf/make_model.py | 0.00% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5737      +/-   ##
==========================================
- Coverage   81.29%   81.10%   -0.19%     
==========================================
  Files         990      990              
  Lines      111020   111219     +199     
  Branches     4232     4245      +13     
==========================================
- Hits        90252    90207      -45     
- Misses      19242    19482     +240     
- Partials     1526     1530       +4     
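
The headline numbers can be reproduced from the table with integer arithmetic. The sketch below is illustrative only: `percent_floor` is a hypothetical helper, truncation to two decimals is an assumption that happens to match the displayed values, and the patch size of about 240 lines is inferred from the report rather than stated in it.

```python
def percent_floor(numerator: int, denominator: int) -> str:
    """Format numerator/denominator as a percentage truncated to two decimals."""
    hundredths = numerator * 10000 // denominator
    return f"{hundredths // 100}.{hundredths % 100:02d}%"


# Project coverage = hits / lines, taken from the diff table above.
base_coverage = percent_floor(90252, 111020)
head_coverage = percent_floor(90207, 111219)

# Missing patch lines per file, counting partials as missing (assumption):
# DeepPotJAX.cc, serialization.py, tfmodel.py, deepmd.hpp, deep_eval.py, make_model.py
missing_lines = (38 + 3) + 28 + 14 + (8 + 3) + 7 + 3

# The total patch size is not reported; it is inferred from the 56.66667% figure.
patch_lines = round(missing_lines / (1 - 0.5666667))
patch_coverage = percent_floor(patch_lines - missing_lines, patch_lines)
```

Under these assumptions `missing_lines` sums to the 104 lines quoted in the report, and the base and head values come out as 81.29% and 81.10%.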

@njzjz

njzjz commented Jul 5, 2026

Member Author

Addressed the review comments in c3ea9ba:

  • reject explicit charge_spin when the JAX model does not support charge/spin conditioning in both Python inference and the C++ JAX backend;
  • remove the stale C++ header comment and unused validate_charge_spin overload;
  • update charge_spin docs/type hints;
  • add tests for unsupported charge_spin, invalid TFModelWrapper charge_spin shapes, and the charge-spin SavedModel export/call path.
