fix(jax): support charge-spin savedmodel export#5737
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (8)
💤 Files with no reviewable changes (1)
🚧 Files skipped from review as they are similar to previous changes (4)
📝 WalkthroughWalkthroughThis PR adds optional ChangesCharge-spin propagation
Estimated code review effort: 4 (Complex) | ~60 minutes Possibly related PRs
Suggested labels: Suggested reviewers: 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
source/api_c/include/deepmd.hpp (1)
2326-2387: 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick winStale comment contradicts the new charge_spin threading logic.
Line 2337/2427's
// charge_spin is not supported via the C-API model-deviation path.directly contradicts the code right below it (and the@param[in] charge_spindoc block at 2321-2324), which now buildscharge_spin_tiled_and forwardscharge_spin__into_DP_DeepPotModelDeviCompute. Either the comment is stale and should be removed, or_DP_DeepPotModelDeviComputegenuinely ignores this parameter for model-deviation (in which case the added tiling/validation here is dead work). Please clarify and fix the comment/doc mismatch.📝 Suggested comment fix (if the C-API does honor charge_spin here)
- // charge_spin is not supported via the C-API model-deviation path. unsigned int natoms = atype.size();Also applies to: 2414-2491
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/api_c/include/deepmd.hpp` around lines 2326 - 2387, The charge_spin handling in compute and the related model-deviation path is inconsistent: the comment says it is unsupported, but the code now validates, tiles, and forwards charge_spin__ into _DP_DeepPotModelDeviCompute. Update the stale comment and the `@param`[in] charge_spin documentation to match the actual behavior, or remove the new charge_spin_tiled_ / validate_charge_spin wiring in compute if the C-API path truly ignores it; use the compute method and _DP_DeepPotModelDeviCompute as the reference points.
🧹 Nitpick comments (6)
deepmd/jax/jax2tf/make_model.py (1)
34-44: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winStale
call_lowerCallabletype hint doesn't reflect the newcharge_spinargument.The
call_lowerparameter'sCallable[[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, bool], dict[...]]signature isn't updated even though the function now conditionally invokescall_lowerwith an extracharge_spinkeyword (Lines 87-94), and downstreamcall_lowerimplementations (e.g.,source/jax2tf_tests/test_make_model.py'scall_lower,serialization.py's dispatch wrappers) explicitly accept a 7thcharge_spin: tf.Tensorparameter. Update the type hint so static type checkers reflect the real accepted signature.Also applies to: 54-54
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/jax2tf/make_model.py` around lines 34 - 44, The call_lower type hint is stale and still describes the old 6-argument signature, while make_model now passes charge_spin and downstream call_lower implementations accept it. Update the Callable annotation in make_model to include the extra charge_spin tf.Tensor parameter so the signature matches the actual invocation in the call_lower path and related wrappers/tests.deepmd/jax/infer/deep_eval.py (1)
252-262: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winDocument the new
charge_spinkwarg.
charge_spinis now a meaningful, extracted keyword (Line 262) but the docstring's**kwargssection still just says "Other parameters" without mentioning it. Since this is a new public capability, document its expected shape/semantics.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/infer/deep_eval.py` around lines 252 - 262, The docstring for the evaluation method currently leaves the extracted charge_spin kwarg undocumented, so update the **kwargs section to explicitly describe charge_spin alongside the other parameters. Use the method in deep_eval.py where charge_spin is popped from kwargs to locate the spot, and add its expected shape and semantics so callers know how to pass it and what it represents.deepmd/jax/jax2tf/tfmodel.py (1)
112-134: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winDocstrings omit new
charge_spinparameter.
__call__andcallboth gained acharge_spinargument, but their docstringParameterssections still only documentdo_atomic_virialand earlier params.📝 Proposed docstring addition
do_atomic_virial If calculate the atomic virial. + charge_spin + The charge and spin conditioning input. shape: nf x dim_chg_spin ReturnsAlso applies to: 157-181
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/jax2tf/tfmodel.py` around lines 112 - 134, The docstrings for TFModel.__call__ and call are missing the new charge_spin parameter, so update each Parameters section to document charge_spin alongside the existing coord, atype, box, fparam, aparam, and do_atomic_virial entries. Keep the wording consistent with the method signatures in tfmodel.py and ensure the new argument is described wherever these methods are documented.source/jax2tf_tests/test_tfmodel.py (1)
1-47: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winAdd edge-case tests for invalid
charge_spinshapes.Current tests cover the default, single-frame-broadcast, and missing-value branches of
_make_charge_spin_input, but not the shape-validation error paths (wrong last dimension,ndim > 2, or a 2D input whose first dimension matches neither1nornframes). These are distinctraise ValueErrorbranches intfmodel.py(lines 416-420, 423-424) that would otherwise be silently unverified.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/jax2tf_tests/test_tfmodel.py` around lines 1 - 47, Add negative tests for TFModelWrapper._make_charge_spin_input to cover the shape-validation branches that are currently untested. In source/jax2tf_tests/test_tfmodel.py, extend the existing _make_wrapper-based tests to assert ValueError for an input with the wrong last dimension, for an input with ndim > 2, and for a 2D input whose first dimension is neither 1 nor nframes. Use the _make_charge_spin_input method name and the existing charge_spin handling setup to target the distinct validation paths in tfmodel.py.source/jax2tf_tests/test_serialization.py (1)
40-113: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy liftNo test exercises the
has_chg_spin=TrueSavedModel export path.
DummyModelreportsget_dim_chg_spin() == 0, sodeserialize_to_file's entire charge_spin branch (extended TensorSpecs, extended polymorphic_shapes,tf.conddispatch withcharge_spin, and the charge_spin-aware exportedcall/call_lowerwrappers) is never traced bytest_savedmodel_export_contains_xla_call_module. Other new tests (test_make_model.py,test_tfmodel.py) only cover lower-level helpers, not the actual SavedModel export/dispatch logic.Consider adding a variant of
DummyModelwithget_dim_chg_spin() > 0(and acall_common_loweracceptingcharge_spin) to exercise the export path end-to-end, similar to how the base case validatesXlaCallModulepresence.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/jax2tf_tests/test_serialization.py` around lines 40 - 113, The SavedModel export tests do not cover the `has_chg_spin=True` path, so the charge-spin-specific export/dispatch logic in `deserialize_to_file` is untested. Add a `DummyModel` variant that reports `get_dim_chg_spin() > 0` and whose `call_common_lower` accepts `charge_spin`, then extend `test_savedmodel_export_contains_xla_call_module` (or a sibling test) to export and load it end-to-end, verifying the charge-spin-aware `call`/`call_lower` wrappers and `tf.cond` branch are exercised alongside the existing base case.source/api_c/include/deepmd.hpp (1)
863-929: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueRemove the unused
validate_charge_spinoverload
The 3-arg overload is no longer called anywhere, so keeping it alongside the tiling version leaves two different validation rules under the same name. Removing it would avoid future confusion.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/api_c/include/deepmd.hpp` around lines 863 - 929, Remove the unused 3-argument validate_charge_spin overload from deepmd::hpp in deepmd.hpp so only the tiling-aware version remains. Keep the existing 4-argument validate_charge_spin as the single source of validation logic, and make sure any call sites still resolve to that overload so there are no duplicate validation rules under the same name.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/infer/deep_eval.py`:
- Around line 399-420: The charge_spin argument is being silently ignored in the
JAX inference path when the model does not support charge-spin embeddings.
Update the logic in deep_eval.py around the charge_spin_input/model_kwargs setup
in the inference method to detect when charge_spin is provided but
self.has_chg_spin_ebd() is false, and fail loudly with a ValueError or
RuntimeError (or at minimum emit a warning) instead of dropping the input. Keep
the existing to_jax_array/model_kwargs flow for supported models, but gate it
with an explicit validation using the same charge_spin_input and
has_chg_spin_ebd symbols.
In `@source/api_cc/src/DeepPotJAX.cc`:
- Around line 463-499: The helper make_charge_spin_input currently returns an
empty vector when dchgspin is 0, which silently discards non-empty charge_spin
input instead of rejecting it. Update this function to mirror
deepmd::hpp::validate_charge_spin by checking for non-empty charge_spin or
default_chg_spin before the early return, and throw a deepmd_exception when
charge/spin data is provided but the model does not support it. Keep the
existing size/shape handling in make_charge_spin_input for the supported
dchgspin cases.
---
Outside diff comments:
In `@source/api_c/include/deepmd.hpp`:
- Around line 2326-2387: The charge_spin handling in compute and the related
model-deviation path is inconsistent: the comment says it is unsupported, but
the code now validates, tiles, and forwards charge_spin__ into
_DP_DeepPotModelDeviCompute. Update the stale comment and the `@param`[in]
charge_spin documentation to match the actual behavior, or remove the new
charge_spin_tiled_ / validate_charge_spin wiring in compute if the C-API path
truly ignores it; use the compute method and _DP_DeepPotModelDeviCompute as the
reference points.
---
Nitpick comments:
In `@deepmd/jax/infer/deep_eval.py`:
- Around line 252-262: The docstring for the evaluation method currently leaves
the extracted charge_spin kwarg undocumented, so update the **kwargs section to
explicitly describe charge_spin alongside the other parameters. Use the method
in deep_eval.py where charge_spin is popped from kwargs to locate the spot, and
add its expected shape and semantics so callers know how to pass it and what it
represents.
In `@deepmd/jax/jax2tf/make_model.py`:
- Around line 34-44: The call_lower type hint is stale and still describes the
old 6-argument signature, while make_model now passes charge_spin and downstream
call_lower implementations accept it. Update the Callable annotation in
make_model to include the extra charge_spin tf.Tensor parameter so the signature
matches the actual invocation in the call_lower path and related wrappers/tests.
In `@deepmd/jax/jax2tf/tfmodel.py`:
- Around line 112-134: The docstrings for TFModel.__call__ and call are missing
the new charge_spin parameter, so update each Parameters section to document
charge_spin alongside the existing coord, atype, box, fparam, aparam, and
do_atomic_virial entries. Keep the wording consistent with the method signatures
in tfmodel.py and ensure the new argument is described wherever these methods
are documented.
In `@source/api_c/include/deepmd.hpp`:
- Around line 863-929: Remove the unused 3-argument validate_charge_spin
overload from deepmd::hpp in deepmd.hpp so only the tiling-aware version
remains. Keep the existing 4-argument validate_charge_spin as the single source
of validation logic, and make sure any call sites still resolve to that overload
so there are no duplicate validation rules under the same name.
In `@source/jax2tf_tests/test_serialization.py`:
- Around line 40-113: The SavedModel export tests do not cover the
`has_chg_spin=True` path, so the charge-spin-specific export/dispatch logic in
`deserialize_to_file` is untested. Add a `DummyModel` variant that reports
`get_dim_chg_spin() > 0` and whose `call_common_lower` accepts `charge_spin`,
then extend `test_savedmodel_export_contains_xla_call_module` (or a sibling
test) to export and load it end-to-end, verifying the charge-spin-aware
`call`/`call_lower` wrappers and `tf.cond` branch are exercised alongside the
existing base case.
In `@source/jax2tf_tests/test_tfmodel.py`:
- Around line 1-47: Add negative tests for
TFModelWrapper._make_charge_spin_input to cover the shape-validation branches
that are currently untested. In source/jax2tf_tests/test_tfmodel.py, extend the
existing _make_wrapper-based tests to assert ValueError for an input with the
wrong last dimension, for an input with ndim > 2, and for a 2D input whose first
dimension is neither 1 nor nframes. Use the _make_charge_spin_input method name
and the existing charge_spin handling setup to target the distinct validation
paths in tfmodel.py.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 5b5209f6-f6de-4d2c-9dc3-aa8cf95cab72
📒 Files selected for processing (10)
deepmd/jax/infer/deep_eval.pydeepmd/jax/jax2tf/make_model.pydeepmd/jax/jax2tf/serialization.pydeepmd/jax/jax2tf/tfmodel.pysource/api_c/include/deepmd.hppsource/api_cc/include/DeepPotJAX.hsource/api_cc/src/DeepPotJAX.ccsource/jax2tf_tests/test_make_model.pysource/jax2tf_tests/test_serialization.pysource/jax2tf_tests/test_tfmodel.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5737 +/- ##
==========================================
- Coverage 81.29% 81.10% -0.19%
==========================================
Files 990 990
Lines 111020 111219 +199
Branches 4232 4245 +13
==========================================
- Hits 90252 90207 -45
- Misses 19242 19482 +240
- Partials 1526 1530 +4 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
|
Addressed the review comments in c3ea9ba:\n\n- reject explicit charge_spin when the JAX model does not support charge/spin conditioning in both Python inference and the C++ JAX backend;\n- remove the stale C++ header comment and unused validate_charge_spin overload;\n- update charge_spin docs/type hints;\n- add tests for unsupported charge_spin, invalid TFModelWrapper charge_spin shapes, and the charge-spin SavedModel export/call path. |
Summary
Tests
Notes
cmake --build source/build --target deepmd_cc -j2still fails before compiling this change because local CMake TensorFlow detection stops atcmake/Findtensorflow.cmake:234 (file): file unknown error.Summary by CodeRabbit