feat(dpa4): multiple updates for dpa4#5734
Conversation
feat(pt/dpa4): add bf16 infer
feat(pt): full validation supports ener_spin feat(pt): allow spin label missing fix(pt/spin): spin stat bug when using lmdb
perf(pt/dpa4): 3xfp16 chore(pt/dpa4): relocate custom kernels
for more information, see https://pre-commit.ci
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (11)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughThis PR adds native per-atom spin support, a configurable residual readout stack, focus-major tensor-layout updates, fused Triton/CuTe SeZM kernels, and spin-aware training/export/validation wiring. ChangesNative Spin Descriptor and Readout Stack
Focus-Major (fndc) Layout Refactor and Wigner-D Optimization
Fused Triton/CuTe Kernel Backends for SeZM SO(2)
Native Spin Model, Training, and Export Integration
Estimated code review effort: 5 (Critical) | ~150 minutes Sequence Diagram(s)sequenceDiagram
participant Caller
participant DescrptDPA4
participant SpinEmbedding
participant GIE as GeometricInitialEmbedding
participant Readout as _apply_readout
Caller->>DescrptDPA4: call(coord, atype, spin)
DescrptDPA4->>SpinEmbedding: _apply_spin_embedding(type_feat, spin)
SpinEmbedding-->>DescrptDPA4: scalar l=0, vector l=1
DescrptDPA4->>GIE: call(..., spin_l1_message)
GIE-->>DescrptDPA4: non_scalar_message with folded spin l=1
DescrptDPA4->>Readout: _apply_readout(x)
Readout-->>Caller: descriptor output
sequenceDiagram
participant Trainer
participant SeZMNativeSpinModel
participant SeZMModel as SeZMModel.core_compute
participant TransformOutput as edge_energy_deriv
Trainer->>SeZMNativeSpinModel: forward(coord, atype, spin)
SeZMNativeSpinModel->>SeZMModel: forward_common(spin=spin)
SeZMModel->>TransformOutput: edge_energy_deriv(edge_vec, spin_leaf=spin)
TransformOutput-->>SeZMModel: force, virial, energy_derv_r_mag
SeZMModel-->>SeZMNativeSpinModel: energy, force, force_mag
SeZMNativeSpinModel-->>Trainer: atom_energy, energy, force, force_mag, mask_mag
Possibly related PRs
Suggested labels: Suggested reviewers: 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (2)
deepmd/kernels/triton/sezm/__init__.py (1)
9-42: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
flash_atten.py's availability flag isn't included in the aggregate.
TRITON_AVAILABLEANDs 7 per-kernel flags but omitsFLASH_ATTEN-related availability fromflash_atten.py(part of this cohort). Functionally equivalent today since every flag reduces to "is triton importable" per the comment, but ifflash_atten.pyever gains its own stricter gating (e.g. a Triton-version check), the package flag would silently miss it.♻️ Proposed fix
+from .flash_atten import ( + FLASH_ATTEN_TRITON_AVAILABLE, +) from .force_assembly import ( FORCE_ASSEMBLY_TRITON_AVAILABLE, ) @@ TRITON_AVAILABLE = ( TRITON_ROTATION_AVAILABLE and RADIAL_MIX_TRITON_AVAILABLE and SO2_BLOCK_GEMM_TRITON_AVAILABLE and SO2_VALUE_PATH_TRITON_AVAILABLE and STACK_FP16X3_TRITON_AVAILABLE and WIGNER_MONOMIALS_TRITON_AVAILABLE and FORCE_ASSEMBLY_TRITON_AVAILABLE + and FLASH_ATTEN_TRITON_AVAILABLE )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/kernels/triton/sezm/__init__.py` around lines 9 - 42, The aggregate TRITON_AVAILABLE flag in the sezm package is missing the availability guard from flash_atten.py, so the package-level check can drift out of sync with all kernels in this cohort. Update the imports and the TRITON_AVAILABLE conjunction in deepmd/kernels/triton/sezm/__init__.py to include the flash_atten module’s availability symbol alongside TRITON_ROTATION_AVAILABLE, RADIAL_MIX_TRITON_AVAILABLE, SO2_BLOCK_GEMM_TRITON_AVAILABLE, SO2_VALUE_PATH_TRITON_AVAILABLE, STACK_FP16X3_TRITON_AVAILABLE, WIGNER_MONOMIALS_TRITON_AVAILABLE, and FORCE_ASSEMBLY_TRITON_AVAILABLE, keeping the package flag as the full AND of every per-kernel availability check.deepmd/pt/model/model/__init__.py (1)
142-156: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueOptional: validate index range before scattering into the mask.
For the magnetic-index form,
mask[use_spin] = Truewill raise a raw NumPyIndexError(and silently accept negatives) when an index is out of range fortype_map. A small explicit check would give a clearer, actionable error consistent with the symbol-form validation just above.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/__init__.py` around lines 142 - 156, The spin.use_spin normalization in deepmd/pt/model/model/__init__.py should validate magnetic indices before scattering them into the mask. In the branch that handles index lists in the model parameter processing logic around use_spin and type_map, add an explicit range check for every index in use_spin so out-of-range values (including negatives) raise a clear ValueError instead of a raw NumPy IndexError. Keep the existing symbol-name validation path intact and update the mask assignment only after the indices are confirmed valid.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@deepmd/kernels/triton/sezm/__init__.py`:
- Around line 9-42: The aggregate TRITON_AVAILABLE flag in the sezm package is
missing the availability guard from flash_atten.py, so the package-level check
can drift out of sync with all kernels in this cohort. Update the imports and
the TRITON_AVAILABLE conjunction in deepmd/kernels/triton/sezm/__init__.py to
include the flash_atten module’s availability symbol alongside
TRITON_ROTATION_AVAILABLE, RADIAL_MIX_TRITON_AVAILABLE,
SO2_BLOCK_GEMM_TRITON_AVAILABLE, SO2_VALUE_PATH_TRITON_AVAILABLE,
STACK_FP16X3_TRITON_AVAILABLE, WIGNER_MONOMIALS_TRITON_AVAILABLE, and
FORCE_ASSEMBLY_TRITON_AVAILABLE, keeping the package flag as the full AND of
every per-kernel availability check.
In `@deepmd/pt/model/model/__init__.py`:
- Around line 142-156: The spin.use_spin normalization in
deepmd/pt/model/model/__init__.py should validate magnetic indices before
scattering them into the mask. In the branch that handles index lists in the
model parameter processing logic around use_spin and type_map, add an explicit
range check for every index in use_spin so out-of-range values (including
negatives) raise a clear ValueError instead of a raw NumPy IndexError. Keep the
existing symbol-name validation path intact and update the mask assignment only
after the indices are confirmed valid.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 1a30b519-28ad-43fe-835a-90307a68c825
📒 Files selected for processing (97)
deepmd/dpmodel/descriptor/dpa4.pydeepmd/dpmodel/descriptor/dpa4_nn/__init__.pydeepmd/dpmodel/descriptor/dpa4_nn/activation.pydeepmd/dpmodel/descriptor/dpa4_nn/embedding.pydeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/norm.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pydeepmd/dpmodel/descriptor/dpa4_nn/wignerd.pydeepmd/kernels/__init__.pydeepmd/kernels/cute/__init__.pydeepmd/kernels/cute/sezm/__init__.pydeepmd/kernels/cute/sezm/backward.pydeepmd/kernels/cute/sezm/forward.pydeepmd/kernels/cute/sezm/operator.pydeepmd/kernels/triton/__init__.pydeepmd/kernels/triton/sezm/__init__.pydeepmd/kernels/triton/sezm/flash_atten.pydeepmd/kernels/triton/sezm/force_assembly.pydeepmd/kernels/triton/sezm/radial_mix.pydeepmd/kernels/triton/sezm/so2_block_gemm.pydeepmd/kernels/triton/sezm/so2_rotation.pydeepmd/kernels/triton/sezm/so2_stack_fp16x3.pydeepmd/kernels/triton/sezm/so2_value_path.pydeepmd/kernels/triton/sezm/sweep_tile_configs.pydeepmd/kernels/triton/sezm/tile_config_data.pydeepmd/kernels/triton/sezm/tile_configs.pydeepmd/kernels/triton/sezm/wigner_monomials.pydeepmd/kernels/utils.pydeepmd/pt/entrypoints/freeze_pt2.pydeepmd/pt/infer/deep_eval.pydeepmd/pt/loss/ener_spin.pydeepmd/pt/model/atomic_model/sezm_atomic_model.pydeepmd/pt/model/descriptor/sezm.pydeepmd/pt/model/descriptor/sezm_nn/__init__.pydeepmd/pt/model/descriptor/sezm_nn/activation.pydeepmd/pt/model/descriptor/sezm_nn/cute/__init__.pydeepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.pydeepmd/pt/model/descriptor/sezm_nn/embedding.pydeepmd/pt/model/descriptor/sezm_nn/grid_net.pydeepmd/pt/model/descriptor/sezm_nn/norm.pydeepmd/pt/model/descriptor/sezm_nn/so2.pydeepmd/pt/model/descriptor/sezm_nn/triton/__init__.pydeepmd/pt/model/descriptor/sezm_nn/utils.pydeepmd/pt/model/descriptor/sezm_nn/wignerd.pydeepmd/pt/model/model/__init__.pydeepmd/pt/model/model/sezm_model.pydeepmd/pt/model/model/sezm_native_spin_model.pydeepmd/pt/model/model/sezm_property_model.pydeepmd/pt/model/model/sezm_spin_model.pydeepmd/pt/model/model/spin_model.pydeepmd/pt/model/model/transform_output.pydeepmd/pt/train/training.pydeepmd/pt/train/validation.pydeepmd/pt/utils/compile_compat.pydeepmd/pt/utils/serialization.pydeepmd/pt_expt/descriptor/dpa4.pydeepmd/pt_expt/descriptor/dpa4_nn/__init__.pydeepmd/pt_expt/descriptor/dpa4_nn/so2.pydeepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.pydeepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.pydeepmd/pt_expt/descriptor/dpa4_nn/triton/so2_rotation.pydeepmd/pt_expt/descriptor/dpa4_nn/wignerd.pydeepmd/pt_expt/infer/deep_eval.pydeepmd/pt_expt/utils/edge_schema.pydeepmd/utils/argcheck.pydeepmd/utils/eval_metrics.pydeepmd/utils/spin.pydoc/model/dpa4.mdexamples/spin/dpa4/input-deepspin.jsonexamples/spin/dpa4/input.jsonexamples/spin/dpa4/lmp/README.mdexamples/spin/dpa4/lmp/in.lammpsexamples/spin/dpa4/lmp/init.dataexamples/water/dpa4/README.mdexamples/water/dpa4/input.jsonpyproject.tomlsource/api_cc/include/DeepPot.hsource/api_cc/include/DeepPotPTExpt.hsource/api_cc/include/DeepSpinPTExpt.hsource/api_cc/src/DeepPot.ccsource/api_cc/src/DeepPotPTExpt.ccsource/api_cc/src/DeepSpinPTExpt.ccsource/lmp/pair_deepmd.cppsource/lmp/pair_deepspin.cppsource/op/pt/comm.ccsource/tests/common/dpmodel/test_dpa4_so3_projector.pysource/tests/common/test_examples.pysource/tests/pt/model/test_descriptor_sezm.pysource/tests/pt/model/test_descriptor_sezm_triton.pysource/tests/pt/model/test_dpa4_dpmodel_parity.pysource/tests/pt/model/test_dpa4_ptexpt_grad_parity.pysource/tests/pt/model/test_sezm_export.pysource/tests/pt/model/test_sezm_model.pysource/tests/pt/model/test_sezm_parallel.pysource/tests/pt/model/test_sezm_spin_model.pysource/tests/pt/test_validation.pysource/tests/pt_expt/utils/test_border_op_backward.py
💤 Files with no reviewable changes (6)
- deepmd/pt/model/descriptor/sezm_nn/cute/init.py
- deepmd/pt_expt/descriptor/dpa4_nn/triton/init.py
- deepmd/pt/model/descriptor/sezm_nn/triton/init.py
- deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py
- deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
- deepmd/pt/model/descriptor/sezm_nn/utils.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5734 +/- ##
==========================================
- Coverage 81.29% 79.26% -2.03%
==========================================
Files 990 1001 +11
Lines 111020 113879 +2859
Branches 4232 4272 +40
==========================================
+ Hits 90252 90268 +16
- Misses 19242 22067 +2825
- Partials 1526 1544 +18 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Since each commits relies on each other to pass the parity tests, so this is a fusion pr.
Summary by CodeRabbit
readout_layersand native per-atom spin support across SeZM/DPA4 descriptors, including newspinruntime input and spin-aware scalar readout.SpinEmbeddingcomponent for the native spin scheme."fndc"support for DPA4/SeZM tensor layouts and expanded optional fused Triton acceleration (incl. flash-attention-style aggregation) for supported shapes.sezm_native_spin; removed an older experimental CuTe rotation path.