feat(pt): support spin virial #4545


Draft
iProzd wants to merge 8 commits into base: devel

Conversation

@iProzd (Collaborator) commented Jan 10, 2025

Summary by CodeRabbit

  • New Features

    • Added support for virial loss calculations in spin energy models.
    • Enhanced model's ability to process spin-related coordinate corrections.
    • Improved handling of virial outputs in model computations.
    • Introduced dynamic selection mechanism for neighbor and angle processing in repflow descriptors.
    • Added utility functions for graph index computations and aggregation in model networks.
    • Introduced a new testing framework for evaluating spin virial functionality.
    • Added an alternative exponential switch function for environment matrix calculations.
    • Enabled optional use of exponential switching in environment matrix and descriptor computations.
  • Bug Fixes

    • Corrected virial array population in model deviation and computation methods.
    • Fixed issues with virial output processing in various model implementations.
  • Tests

    • Added new test cases for spin energy models with virial calculations.
    • Extended testing framework to support spin-related model evaluations.
    • Added consistency tests for descriptor models across precisions and configurations, with and without dynamic selection.

coderabbitai bot (Contributor) commented Jan 10, 2025

## Walkthrough

This pull request introduces enhanced support for virial calculations in the DeepMD-kit framework, specifically focusing on spin-related models. The changes span multiple files across the project, adding new functionality to handle virial outputs, coordinate corrections, and spin-related computations. The modifications enable more comprehensive energy and virial loss calculations, with updates to model processing, loss computation, and testing frameworks. Additionally, a dynamic neighbor selection mechanism is introduced in the repflow descriptor layers, controlled by new parameters and supported by new utility functions and tests.
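
As a rough illustration of the flattened-edge bookkeeping behind the dynamic selection described above, the sketch below converts a padded neighbor list into an edge-index list; the function name, signature, and the assumption that neighbor entries index extended atoms are illustrative only, not the PR's exact get_graph_index API.

import torch

def nlist_to_edge_index(nlist: torch.Tensor, nall: int) -> torch.Tensor:
    # nlist: nf x nloc x nnei neighbor list padded with -1, entries indexing extended atoms
    # returns n_edge x 2 pairs of (flattened local receiver, flattened extended sender)
    nf, nloc, nnei = nlist.shape
    mask = nlist != -1
    frame = torch.arange(nf, device=nlist.device).view(nf, 1, 1)
    recv = (frame * nloc + torch.arange(nloc, device=nlist.device).view(1, nloc, 1)).expand(nf, nloc, nnei)
    send = frame * nall + nlist.clamp(min=0)
    return torch.stack([recv[mask], send[mask]], dim=-1)

Per-edge tensors can then be flattened with the same mask so that padded neighbors drop out before the message-passing updates.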

## Changes

| File | Change Summary |
|------|----------------|
| `deepmd/pt/loss/ener_spin.py` | Added virial loss calculation support in `EnergySpinLoss` class, updated `label_requirement` for "virial". |
| `deepmd/pt/model/model/make_model.py` | Introduced `coord_corr_for_virial` parameter in `forward_common` and `forward_common_lower` methods. |
| `deepmd/pt/model/model/spin_model.py` | Enhanced spin input processing and model forward methods to include spin-related corrections and virial output handling. |
| `deepmd/pt/model/model/transform_output.py` | Added `extended_coord_corr` parameter to `fit_output_to_model_output` function for derivative correction. |
| `deepmd/pt/model/descriptor/repflows.py` | Added dynamic neighbor selection support with `use_dynamic_sel` and `sel_reduce_factor` parameters; refactored forward method to handle dynamic graph indices. |
| `deepmd/dpmodel/descriptor/dpa3.py` | Added `edge_init_use_dist`, `use_exp_switch`, `use_dynamic_sel`, and `sel_reduce_factor` parameters to `RepFlowArgs` constructor. |
| `deepmd/pt/model/descriptor/dpa3.py` | Passed dynamic selection parameters from `repflow_args` to `DescrptBlockRepflows` in `DescrptDPA3` constructor. |
| `deepmd/pt/model/descriptor/repflow_layer.py` | Added dynamic selection logic and new methods for dynamic aggregation and updates; extended constructor and forward method to support dynamic neighbor and angle indexing. |
| `deepmd/pt/model/network/utils.py` | Added `aggregate` and `get_graph_index` utility functions for dynamic graph indexing and aggregation. |
| `deepmd/utils/argcheck.py` | Added optional arguments `edge_init_use_dist`, `use_exp_switch`, `use_dynamic_sel`, and `sel_reduce_factor` to `dpa3_repflow_args` function. |
| `deepmd/pt/model/descriptor/env_mat.py` | Added `use_exp_switch` parameter to environment matrix functions to select between smoothing weight functions. |
| `deepmd/pt/utils/preprocess.py` | Added `compute_exp_sw` function implementing an exponential switch function alternative. |
| `source/api_c/include/deepmd.hpp` | Fixed virial array population in `DeepSpinModelDevi` class's `compute` methods. |
| `source/api_c/src/c_api.cc` | Restored virial output handling in `DP_DeepSpinModelDeviCompute_variant` function. |
| `source/api_cc/src/DeepSpinPT.cc` | Implemented virial tensor retrieval and assignment in `DeepSpinPT` class's `compute` methods; changed comm_dict insertion to `insert_or_assign`. |
| `source/tests/pt/model/test_autodiff.py` | Added spin-related test cases and logic in `VirialTest` and `ForceTest` classes; introduced `TestEnergyModelSpinSeAVirial` class. |
| `source/tests/pt/model/test_ener_spin_model.py` | Updated spin input processing tests in `SpinTest` class to handle additional return values. |
| `source/tests/universal/common/cases/model/utils.py` | Modified virial calculation logic to include new conditions for spin virial testing. |
| `source/tests/universal/pt/model/test_model.py` | Added `test_spin_virial` property to `TestSpinEnergyModelDP` class to enable spin virial testing. |
| `source/tests/pt/model/test_nosel.py` | Added tests for `DescrptDPA3` descriptor with dynamic neighbor selection, validating output consistency. |

## Suggested Reviewers

- njzjz
- wanghan-iapcm

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 106c973 and 8fdc524.

📒 Files selected for processing (2)
  • source/api_cc/src/DeepPotPT.cc (1 hunks)
  • source/api_cc/src/DeepSpinPT.cc (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/api_cc/src/DeepPotPT.cc
  • source/api_cc/src/DeepSpinPT.cc

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (12)
deepmd/pt/model/model/spin_model.py (3)

58-59: Enhance code readability by reshaping tensor in a single step

In line 58, consider reshaping self.virtual_scale_mask.to(atype.device)[atype] directly without wrapping it in parentheses for better readability.


379-381: Avoid unnecessary computation by not calling process_spin_input during stat computation

In the compute_or_load_stat method, calling process_spin_input may introduce unnecessary computational overhead if coord_corr is not used. Consider modifying the code to exclude coord_corr when it's not needed.


591-594: Consider consistency in handling do_grad_c checks

In the forward method, ensure that the handling of do_grad_c("energy") and subsequent assignments align with the changes made in translated_output_def. This maintains consistency across the methods.

source/tests/pt/model/test_autodiff.py (4)

144-144: Initialize the spin variable only when necessary

The spin variable is initialized even when test_spin is False. Consider moving the initialization inside the conditional block to optimize performance.

Apply this diff to adjust the initialization:

-        spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)

Move the initialization to after line 150, within the if test_spin block.


148-148: Ensure spin is only converted to NumPy when necessary

Similar to the previous comment, the conversion of spin to a NumPy array should be conditional based on test_spin to avoid unnecessary computations.


151-154: Simplify the conditional assignment of test_keys

The assignment of test_keys can be streamlined for clarity.

Apply this diff to simplify the code:

-        if not test_spin:
-            test_keys = ["energy", "force", "virial"]
-        else:
-            test_keys = ["energy", "force", "force_mag", "virial"]
+        test_keys = ["energy", "force", "virial"]
+        if test_spin:
+            test_keys.append("force_mag")

263-268: Add a newline for code style consistency

Include a blank line after the class definition to follow PEP 8 style guidelines for better readability.

Apply this diff:

 class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
 
+    def setUp(self) -> None:
         model_params = copy.deepcopy(model_spin)
deepmd/pt/model/model/transform_output.py (3)

159-159: Update function documentation to include new parameter

The fit_output_to_model_output function has a new parameter extended_coord_corr. Update the docstring to describe this parameter and its role in the computation.


195-195: Avoid using # noqa comments for line length

Instead of using # noqa: RUF005 to suppress line length warnings, refactor the code to comply with style guidelines for better maintainability.

Apply this diff to split the line:

-                            ).view(list(dc.shape[:-2]) + [1, 9])  # noqa: RUF005
+                            )
+                        dc = dc.view(list(dc.shape[:-2]) + [1, 9])

Line range hint 226-226: Consider adding type annotations for function returns

Adding type annotations to functions enhances code clarity and aids in static analysis. Consider specifying the return types for the functions in this module.

deepmd/pt/loss/ener_spin.py (1)

271-286: LGTM! The virial loss calculation is well implemented.

The implementation follows the established pattern for loss calculations, with proper scaling and optional MAE computation. The code is clean and well-structured.

Consider extracting the common pattern of loss calculation (L2, MAE, scaling) into a helper method to reduce code duplication across energy, force, and virial loss calculations.
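
A minimal sketch of such a helper, kept deliberately generic (it does not use the real EnergySpinLoss attributes or naming):

import torch

def weighted_l2_and_mae(pred: torch.Tensor, label: torch.Tensor, pref: float):
    # returns the prefactor-scaled mean-squared error and the plain mean absolute error
    diff = label - pred
    l2 = pref * torch.mean(torch.square(diff))
    mae = torch.mean(torch.abs(diff))
    return l2, mae

The energy, force, and virial branches would then each reduce to one call plus the bookkeeping of the more_loss dictionary.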

source/tests/pt/model/test_ener_spin_model.py (1)

118-118: Document the purpose of the ignored return values.

The additional return values (marked with _) from process_spin_input and process_spin_input_lower are silently ignored. Consider adding a comment explaining what these values represent and why they can be safely ignored in these tests.

Also applies to: 177-177

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9af197c and 8fd9565.

📒 Files selected for processing (11)
  • deepmd/pt/loss/ener_spin.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (7 hunks)
  • deepmd/pt/model/model/spin_model.py (11 hunks)
  • deepmd/pt/model/model/transform_output.py (2 hunks)
  • source/api_c/include/deepmd.hpp (2 hunks)
  • source/api_c/src/c_api.cc (1 hunks)
  • source/api_cc/src/DeepSpinPT.cc (4 hunks)
  • source/tests/pt/model/test_autodiff.py (3 hunks)
  • source/tests/pt/model/test_ener_spin_model.py (2 hunks)
  • source/tests/universal/common/cases/model/utils.py (3 hunks)
  • source/tests/universal/pt/model/test_model.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (21)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (13)
deepmd/pt/model/model/spin_model.py (4)

63-64: Ensure proper alignment of coordinate corrections

The concatenation of tensors in coord_corr must maintain correct alignment with the corresponding atoms. Verify that torch.zeros_like(coord) and -spin_dist are correctly ordered, ensuring that the coordinate corrections apply to the appropriate atoms.


92-95: Validate the consistency of virtual atom handling

When creating extended_coord_corr, confirm that the virtual atoms are correctly accounted for, and that the concatenation preserves the intended structure. This is crucial for accurate virial calculations involving spin corrections.


410-412: Handle the new output coord_corr_for_virial appropriately

Ensure that all downstream methods that receive coord_corr_for_virial can handle this new parameter without errors. Verify that self.backbone_model.forward_common accepts coord_corr_for_virial as an argument.


631-636: Verify accurate squeezing of tensors and assignment

In the forward_lower method, confirm that the squeeze operations correctly reduce tensor dimensions and that the results are assigned to the appropriate keys in model_predict.

source/tests/pt/model/test_autodiff.py (1)

150-150: Conditionally handle spin and test_spin variables

Verify that all usages of spin and test_spin within the VirialTest class are properly guarded by conditionals to prevent errors when test_spin is False.

deepmd/pt/model/model/transform_output.py (1)

191-196: Ensure tensor shapes are compatible during matrix multiplication

In the computation of dc_corr, validate that the shapes of the tensors involved in the matrix multiplication are compatible to prevent runtime errors.
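
For orientation, the correction term is of the outer-product form sketched below; the shapes and einsum layout are assumptions for illustration, not the exact transform_output.py code.

import torch

def virial_correction(denergy_dcoord: torch.Tensor, coord_corr: torch.Tensor) -> torch.Tensor:
    # denergy_dcoord: nf x nall x 3 energy derivatives; coord_corr: nf x nall x 3 corrections
    # returns an nf x 9 additive virial correction (3x3 outer product summed over atoms)
    corr = torch.einsum("fai,faj->fij", denergy_dcoord, coord_corr)
    return corr.reshape(corr.shape[0], 9)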

deepmd/pt/model/model/make_model.py (2)

138-138: LGTM! The parameter addition is well documented.

The new coord_corr_for_virial parameter is properly documented with clear shape information.

Also applies to: 157-159


187-194: LGTM! The coordinate correction handling is robust.

The implementation properly handles both cases:

  • When correction is provided: converts to correct dtype and gathers extended coordinates
  • When correction is not provided: sets extended_coord_corr to None
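
A hedged sketch of that branching (parameter names follow the PR summary; the use of a mapping tensor and the exact shapes are assumptions):

import torch

def extend_coord_corr(coord_corr_for_virial, mapping, dtype):
    # coord_corr_for_virial: nf x nloc x 3 or None; mapping: nf x nall extended-to-local index
    if coord_corr_for_virial is None:
        return None
    cc = coord_corr_for_virial.to(dtype)
    return torch.gather(cc, 1, mapping.unsqueeze(-1).expand(-1, -1, 3))
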
source/api_cc/src/DeepSpinPT.cc (1)

254-254: LGTM! The virial output handling is properly implemented.

The implementation correctly:

  • Retrieves virial tensor from outputs
  • Converts tensor to appropriate float type
  • Transfers data to CPU
  • Assigns to output vector

Also applies to: 269-273, 417-417, 432-435

source/tests/universal/pt/model/test_model.py (1)

716-717: Note the temporary nature of this flag.

The added test_spin_virial flag enables spin virial testing for this backend. As indicated by the comment, this can be removed once other backends support spin virial calculations.

Let's verify which backends currently lack spin virial support:

✅ Verification successful

Temporary flag is correctly annotated

The comment accurately reflects the current state of the codebase. The spin virial implementation is currently only available in the PyTorch (PT) backend, and the flag is appropriately marked as temporary until other backends (TensorFlow, PaddlePaddle) implement this functionality.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for spin virial implementations across backends
rg -A 5 "spin.*virial" source/api_cc/src/

Length of output: 43


Script:

#!/bin/bash
# Broader search for spin virial across the codebase
echo "=== Searching for spin virial across the codebase ==="
rg -i "spin.*virial" --type py

echo -e "\n=== Searching for related test files ==="
fd "test.*model.*py$"

echo -e "\n=== Looking for backend-specific implementations ==="
rg -i "spin.*virial" --type py -A 5 source/tests/

Length of output: 5891

source/tests/universal/common/cases/model/utils.py (1)

895-898: Improved control flow for spin virial testing.

The condition ensures virial computation only happens when appropriate - either for non-spin models or when spin virial testing is explicitly enabled.

source/api_c/src/c_api.cc (1)

865-869: Re-enabled virial data propagation.

The previously commented-out code for copying virial data has been restored, allowing proper propagation of virial values in spin model deviation computations.

source/api_c/include/deepmd.hpp (1)

2605-2607: Re-enabled virial data propagation in C++ interface.

The previously commented-out code for copying virial data has been restored, maintaining consistency with the C API changes.

@@ -408,6 +419,7 @@ def forward_common(
             fparam=fparam,
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
+            coord_corr_for_virial=coord_corr_for_virial,

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Confirm the compatibility of the new argument coord_corr_for_virial

Check if the backbone model's forward_common method is designed to accept coord_corr_for_virial. If not, update the backbone model accordingly or modify the call to prevent runtime errors.

@@ -469,6 +482,7 @@
             do_atomic_virial=do_atomic_virial,
             comm_dict=comm_dict,
             extra_nlist_sort=extra_nlist_sort,
+            extended_coord_corr=extended_coord_corr_for_virial,

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Ensure extended_coord_corr is accepted by forward_common_lower

Similar to the previous comment, verify that self.backbone_model.forward_common_lower accepts extended_coord_corr as an argument. This prevents potential issues during model execution.

Comment on lines +558 to +562
        if self.do_grad_c("energy"):
            output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
            output_def["virial"].squeeze(-2)
            output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
            output_def["atom_virial"].squeeze(-3)

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Assign the result of squeeze operations to reduce tensor dimensions

The squeeze methods in lines 560 and 562 do not modify tensors in place. Assign the results to ensure the dimensions are correctly reduced.

Apply this diff to fix the issue:

-        output_def["virial"].squeeze(-2)
+        output_def["virial"] = output_def["virial"].squeeze(-2)
-        output_def["atom_virial"].squeeze(-3)
+        output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)

Comment on lines +166 to 169
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Ensure compatibility of tensor devices

When creating tensors within the np_infer function, ensure that all tensors are on the same device to prevent device mismatch errors, especially when env.DEVICE differs from "cpu".

Apply this diff to correct the device assignment:

-                    spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+                    spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),

Committable suggestion skipped: line range outside the PR's diff.

iProzd marked this pull request as draft January 13, 2025 06:05
njzjz linked an issue Feb 9, 2025 that may be closed by this pull request

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
source/api_cc/src/DeepSpinPT.cc (1)

269-274: Consider future support for per-atom virial.

While global virial support is now implemented, per-atom virial remains commented out (lines 295-300). Consider whether this is a planned future enhancement for spin models.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8fd9565 and 6496194.

📒 Files selected for processing (3)
  • source/api_cc/src/DeepSpinPT.cc (4 hunks)
  • source/tests/universal/common/cases/model/utils.py (3 hunks)
  • source/tests/universal/pt/model/test_model.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/universal/pt/model/test_model.py
  • source/tests/universal/common/cases/model/utils.py
⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (4)
source/api_cc/src/DeepSpinPT.cc (4)

254-254: Enabled virial tensor retrieval.

The code now retrieves the virial tensor from the model's outputs, enabling support for virial calculations in the DeepSpinPT class.


270-273: Successfully implemented virial tensor processing.

Added proper tensor conversion, CPU transfer, and data assignment for the virial tensor, following the same pattern used for energy and forces. This enables complete virial support for spin models.


417-417: Enabled virial tensor retrieval in overloaded method.

Consistent implementation of virial retrieval in the overloaded compute method, maintaining parity with the primary implementation.


432-435: Successfully implemented virial tensor processing in overloaded method.

The implementation correctly follows the same pattern used in the primary compute method, ensuring consistent virial handling across both implementations.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🔭 Outside diff range comments (2)
deepmd/dpmodel/descriptor/dpa3.py (1)

191-215: Add new parameters to the serialize method.

The newly added parameters use_dynamic_sel and sel_reduce_factor are not included in the serialize method, which will cause issues when saving and loading models.

Add these parameters to the serialize method:

     def serialize(self) -> dict:
         return {
             "n_dim": self.n_dim,
             "e_dim": self.e_dim,
             "a_dim": self.a_dim,
             "nlayers": self.nlayers,
             "e_rcut": self.e_rcut,
             "e_rcut_smth": self.e_rcut_smth,
             "e_sel": self.e_sel,
             "a_rcut": self.a_rcut,
             "a_rcut_smth": self.a_rcut_smth,
             "a_sel": self.a_sel,
             "a_compress_rate": self.a_compress_rate,
             "a_compress_e_rate": self.a_compress_e_rate,
             "a_compress_use_split": self.a_compress_use_split,
             "n_multi_edge_message": self.n_multi_edge_message,
             "axis_neuron": self.axis_neuron,
             "update_angle": self.update_angle,
             "update_style": self.update_style,
             "update_residual": self.update_residual,
             "update_residual_init": self.update_residual_init,
             "fix_stat_std": self.fix_stat_std,
             "optim_update": self.optim_update,
             "smooth_edge_update": self.smooth_edge_update,
+            "use_dynamic_sel": self.use_dynamic_sel,
+            "sel_reduce_factor": self.sel_reduce_factor,
         }
deepmd/pt/model/descriptor/repflow_layer.py (1)

1179-1243: Add new dynamic selection parameters to serialization.

The use_dynamic_sel and sel_reduce_factor parameters are not included in the serialization, which will cause issues when saving and loading models that use dynamic selection.

Update the serialize method to include these parameters:

 def serialize(self) -> dict:
     """Serialize the networks to a dict.
     
     Returns
     -------
     dict
         The serialized networks.
     """
     data = {
         "@class": "RepFlowLayer",
         "@version": 1,
         "e_rcut": self.e_rcut,
         "e_rcut_smth": self.e_rcut_smth,
         "e_sel": self.e_sel,
         "a_rcut": self.a_rcut,
         "a_rcut_smth": self.a_rcut_smth,
         "a_sel": self.a_sel,
         "ntypes": self.ntypes,
         "n_dim": self.n_dim,
         "e_dim": self.e_dim,
         "a_dim": self.a_dim,
         "a_compress_rate": self.a_compress_rate,
         "a_compress_e_rate": self.a_compress_e_rate,
         "a_compress_use_split": self.a_compress_use_split,
         "n_multi_edge_message": self.n_multi_edge_message,
         "axis_neuron": self.axis_neuron,
         "activation_function": self.activation_function,
         "update_angle": self.update_angle,
         "update_style": self.update_style,
         "update_residual": self.update_residual,
         "update_residual_init": self.update_residual_init,
         "precision": self.precision,
         "optim_update": self.optim_update,
         "smooth_edge_update": self.smooth_edge_update,
+        "use_dynamic_sel": self.use_dynamic_sel,
+        "sel_reduce_factor": self.sel_reduce_factor,
         "node_self_mlp": self.node_self_mlp.serialize(),
🧹 Nitpick comments (7)
deepmd/pt/model/network/utils.py (2)

43-45: Consider using broadcasting for better readability.

The transpose operations for division could be simplified using broadcasting.

Apply this diff to simplify the division:

-    if average:
-        output = (output.T / bin_count).T
-    return output
+    if average:
+        output = output / bin_count.unsqueeze(1)
+    return output
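
For context, a hedged re-implementation of this kind of owner-wise aggregation (not the exact deepmd.pt.model.network.utils.aggregate code):

import torch

def aggregate_sketch(data, owners, average=True, num_owner=None):
    # scatter-add rows of data (n_row x d) into per-owner sums (num_owner x d),
    # then optionally divide each row by the owner's count
    if num_owner is None:
        num_owner = int(owners.max().item()) + 1
    out = torch.zeros(num_owner, data.shape[1], dtype=data.dtype, device=data.device)
    out.index_add_(0, owners, data)
    if average:
        counts = torch.bincount(owners, minlength=num_owner).clamp(min=1)
        out = out / counts.unsqueeze(1).to(data.dtype)
    return out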

86-91: Clean up commented code for clarity.

There are commented-out lines that should be removed if not needed, or uncommented if they serve a purpose.

-    # nf x nloc x nnei x nnei
-    # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :]
     a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
     n_edge = nlist_mask.sum().item()
-    # n_angle = a_nlist_mask_3d.sum().item()
deepmd/pt/model/descriptor/repflows.py (1)

479-482: Consider adding a comment explaining the JIT workaround.

The dummy tensor creation for JIT compatibility could benefit from more explanation.

         else:
-            # avoid jit assertion
+            # Create dummy tensors to avoid JIT assertion errors when dynamic selection is disabled.
+            # These tensors are not used in the computation but are required for type consistency.
             edge_index = angle_index = torch.zeros(
                 [1, 3], device=nlist.device, dtype=nlist.dtype
             )
source/tests/pt/model/test_nosel.py (2)

90-91: Consider adding tests with different sel_reduce_factor values.

The test currently only uses sel_reduce_factor=1.0. Consider adding test cases with other values to ensure the dynamic selection works correctly with different reduction factors.


143-206: Remove or enable the commented test_jit method.

The JIT test is completely commented out. Either remove it if it's not needed, or fix and enable it if JIT compatibility is important.

Would you like me to help implement a working JIT test or should this commented code be removed?

deepmd/pt/model/descriptor/repflow_layer.py (2)

806-806: Consider numerical stability when using float division results as scaling factors.

The dynamic_e_sel and dynamic_a_sel are computed as float divisions and used in scaling operations. For better numerical stability and clarity, consider storing the integer values separately.

 self.use_dynamic_sel = use_dynamic_sel
 self.sel_reduce_factor = sel_reduce_factor
-self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
-self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
+self.dynamic_e_sel = int(self.nnei / self.sel_reduce_factor)
+self.dynamic_a_sel = int(self.a_sel / self.sel_reduce_factor)
+self.dynamic_e_scale = self.dynamic_e_sel ** (-0.5)
+self.dynamic_a_scale = self.dynamic_a_sel ** (-0.5)

Then use the pre-computed scale factors instead of computing them repeatedly:

-scale_factor=self.dynamic_e_sel ** (-0.5),
+scale_factor=self.dynamic_e_scale,

Also applies to: 825-825, 894-896, 1064-1064


731-743: Clarify when edge_index and angle_index parameters are required.

The documentation should explicitly state that these parameters are required when use_dynamic_sel is True.

Update the parameter documentation:

-        edge_index : Optional for dynamic sel, n_edge x 2
+        edge_index : n_edge x 2 (required when use_dynamic_sel is True)
             n2e_index : n_edge
                 Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
             n_ext2e_index : n_edge
                 Broadcast indices from extended node(j) to edge(ij).
-        angle_index : Optional for dynamic sel, n_angle x 3
+        angle_index : n_angle x 3 (required when use_dynamic_sel is True)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cbbce64 and 107cecd.

📒 Files selected for processing (7)
  • deepmd/dpmodel/descriptor/dpa3.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa3.py (1 hunks)
  • deepmd/pt/model/descriptor/repflow_layer.py (17 hunks)
  • deepmd/pt/model/descriptor/repflows.py (7 hunks)
  • deepmd/pt/model/network/utils.py (1 hunks)
  • deepmd/utils/argcheck.py (1 hunks)
  • source/tests/pt/model/test_nosel.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
deepmd/pt/model/descriptor/repflow_layer.py

754-754: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🔇 Additional comments (5)
deepmd/pt/model/descriptor/dpa3.py (1)

153-154: LGTM!

The new parameters are correctly passed from repflow_args to the DescrptBlockRepflows constructor.

deepmd/pt/model/descriptor/repflows.py (1)

574-586: Verify the scale factor calculation.

The scale factor (self.nnei / self.sel_reduce_factor) ** (-0.5) is used in dynamic mode. Please ensure this scaling is appropriate for maintaining numerical stability across different neighbor counts.

Could you provide documentation or references for why this specific scale factor formula was chosen?

deepmd/pt/model/descriptor/repflow_layer.py (3)

330-381: Well-implemented dynamic calculation method.

The _cal_hg_dynamic method is properly documented and implements the dynamic version of the transposed rotation matrix calculation correctly. The use of the aggregate function with proper reshaping maintains mathematical equivalence with the original method.


454-506: Clean implementation of dynamic symmetrization.

The method properly implements the dynamic version of the symmetrization operation by reusing existing components (_cal_hg_dynamic and _cal_grrg). Good code reuse and clear documentation.


549-607: Well-implemented dynamic angle update with proper validation.

The optim_angle_update_dynamic method correctly implements the dynamic version of angle updates with:

  • Clear index range calculations for sub-matrices
  • Proper dimension validation via assertion
  • Efficient use of torch.index_select for gathering
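
A minimal, self-contained illustration of the index_select gathering pattern mentioned above (toy tensors; not the actual optim_angle_update_dynamic code):

import torch

edge_ebd = torch.randn(10, 4)           # n_edge x e_dim flat edge embeddings
eij_index = torch.tensor([0, 3, 3, 7])  # one edge index per angle
angle_edges = torch.index_select(edge_ebd, 0, eij_index)  # n_angle x e_dim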

Comment on lines +1600 to +1611
        Argument(
            "use_dynamic_sel",
            bool,
            optional=True,
            default=False,
        ),
        Argument(
            "sel_reduce_factor",
            float,
            optional=True,
            default=10.0,
        ),

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Add documentation for the new dynamic selection arguments.

Both new arguments use_dynamic_sel and sel_reduce_factor are missing documentation strings. All other arguments in this function include comprehensive doc parameters to explain their purpose and usage. This is particularly important since this module generates user-facing documentation.

Please add documentation for both arguments:

        Argument(
            "use_dynamic_sel",
            bool,
            optional=True,
            default=False,
+           doc="Whether to enable dynamic neighbor selection mechanism. When enabled, the neighbor selection will be dynamically adjusted during training.",
        ),
        Argument(
            "sel_reduce_factor",
            float,
            optional=True,
            default=10.0,
+           doc="The reduction factor for dynamic neighbor selection. Controls how aggressively the neighbor count is reduced when dynamic selection is enabled.",
        ),

Note: The exact documentation content may need adjustment based on the specific implementation details of these features.


Comment on lines +58 to +59
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Add validation for sel_reduce_factor to prevent division by zero.

The sel_reduce_factor parameter is used as a divisor without validation. If a user passes 0 or a negative value, it will cause runtime errors.

Apply this diff to add parameter validation:

 def __init__(
     self,
     # ... other parameters ...
     use_dynamic_sel: bool = False,
     sel_reduce_factor: float = 10.0,
     # ... other parameters ...
 ) -> None:
     super().__init__()
+    if sel_reduce_factor <= 0:
+        raise ValueError(f"sel_reduce_factor must be positive, got {sel_reduce_factor}")
     self.epsilon = 1e-4  # protection of 1./nnei

Also applies to: 106-109


Comment on lines +705 to +706
edge_index: torch.Tensor, # n_edge x 2
angle_index: torch.Tensor, # n_angle x 3

coderabbitai bot (Contributor) commented:

⚠️ Potential issue

❌ Incorrect review comment

Add validation for required dynamic selection parameters.

When use_dynamic_sel is True, the edge_index and angle_index parameters are required but this isn't validated.

Run the following script to check if there are any existing validations in the codebase:

Consider adding validation at the beginning of the forward method:

 def forward(
     self,
     # ... parameters ...
     edge_index: torch.Tensor,  # n_edge x 2
     angle_index: torch.Tensor,  # n_angle x 3
 ):
+    if self.use_dynamic_sel:
+        if edge_index is None or angle_index is None:
+            raise ValueError("edge_index and angle_index must be provided when use_dynamic_sel is True")

Also applies to: 758-761


🏁 Script executed:

#!/bin/bash
# Description: Check for edge_index and angle_index validation in the codebase

# Search for validation of edge_index and angle_index when use_dynamic_sel is True
rg -B 5 -A 5 "edge_index.*angle_index.*use_dynamic_sel|use_dynamic_sel.*edge_index.*angle_index"

# Check how these parameters are passed in the descriptor files
fd -e py -x rg -l "RepFlowLayer" {} \; | xargs rg -B 3 -A 3 "edge_index.*angle_index"

Length of output: 1279



Comment on lines 1066 to +1072
        if not self.smooth_edge_update:
            # will be deprecated in the future
            # not support dynamic index, will pass anyway
            if self.use_dynamic_sel:
                raise NotImplementedError(
                    "smooth_edge_update must be True when use_dynamic_sel is True!"
                )

coderabbitai bot (Contributor) commented:

🛠️ Refactor suggestion

Move the compatibility check to initialization for earlier error detection.

The check for smooth_edge_update compatibility with use_dynamic_sel should be performed during initialization rather than at runtime.

Move this validation to the __init__ method:

 def __init__(
     self,
     # ... parameters ...
 ) -> None:
     super().__init__()
+    if use_dynamic_sel and not smooth_edge_update:
+        raise ValueError("smooth_edge_update must be True when use_dynamic_sel is True")
     if sel_reduce_factor <= 0:
         raise ValueError(f"sel_reduce_factor must be positive, got {sel_reduce_factor}")

And remove the runtime check:

-        if not self.smooth_edge_update:
-            # will be deprecated in the future
-            # not support dynamic index, will pass anyway
-            if self.use_dynamic_sel:
-                raise NotImplementedError(
-                    "smooth_edge_update must be True when use_dynamic_sel is True!"
-                )
+        if not self.smooth_edge_update:

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/pt/utils/preprocess.py (1)

20-29: Enhance documentation and explain design choices.

The function needs better documentation for the following:

  1. Hardcoded constant: The value C = 20 is hardcoded without explanation. Consider adding a comment explaining its significance or making it configurable.

  2. Clamping behavior: Unlike compute_smooth_weight which clamps to [rmin, rmax], this function clamps to [0.0, rmax]. This difference should be documented.

  3. Mathematical formula: The docstring should explain the exponential double-exponential formula and its purpose.

 def compute_exp_sw(distance, rmin: float, rmax: float):
-    """Compute the exponential switch function for neighbor update."""
+    """Compute the exponential switch function for neighbor update.
+    
+    Implements exp(-exp(a * (distance - b))) where a = C/rmin, b = rmin.
+    Note: Unlike compute_smooth_weight, distance is clamped to [0.0, rmax].
+    
+    Parameters
+    ----------
+    distance : torch.Tensor
+        Input distances
+    rmin : float
+        Minimum radius for switch function
+    rmax : float  
+        Maximum cutoff radius
+    """
     if rmin >= rmax:
         raise ValueError("rmin should be less than rmax.")
     distance = torch.clamp(distance, min=0.0, max=rmax)
-    C = 20
+    C = 20  # Exponential switch parameter - controls decay rate
     a = C / rmin
     b = rmin
     exp_sw = torch.exp(-torch.exp(a * (distance - b)))
     return exp_sw
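
As a quick numerical reference for the switch above (a standalone restatement assuming the same signature and the C = 20 constant): below rmin the weight stays close to 1, at rmin it equals exp(-1) (about 0.37), and beyond rmin it decays double-exponentially toward 0.

import torch

def exp_sw_reference(distance, rmin: float, rmax: float, C: float = 20.0):
    # exp(-exp(C/rmin * (d - rmin))) with d clamped to [0, rmax]
    d = torch.clamp(distance, min=0.0, max=rmax)
    return torch.exp(-torch.exp((C / rmin) * (d - rmin)))

print(exp_sw_reference(torch.tensor([0.2, 1.0, 1.5]), rmin=1.0, rmax=6.0))
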
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 107cecd and 3587d07.

📒 Files selected for processing (6)
  • deepmd/dpmodel/descriptor/dpa3.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa3.py (1 hunks)
  • deepmd/pt/model/descriptor/env_mat.py (6 hunks)
  • deepmd/pt/model/descriptor/repflows.py (8 hunks)
  • deepmd/pt/utils/preprocess.py (1 hunks)
  • deepmd/utils/argcheck.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/pt/model/descriptor/dpa3.py
  • deepmd/utils/argcheck.py
  • deepmd/dpmodel/descriptor/dpa3.py
🔇 Additional comments (6)
deepmd/pt/model/descriptor/env_mat.py (1)

6-6: Well-implemented integration of exponential switch function.

The implementation correctly:

  • Imports the new function
  • Adds the parameter with appropriate default (False) for backward compatibility
  • Uses clear conditional logic to select the switch function
  • Propagates the parameter through the call stack
  • Updates documentation appropriately

Also applies to: 18-18, 38-42, 61-61, 74-74, 87-87

deepmd/pt/model/descriptor/repflows.py (5)

22-24: LGTM: Import addition.

The import of get_graph_index from deepmd.pt.model.network.utils is correctly added to support the new dynamic selection functionality.


189-191: Good parameter design with sensible defaults.

The three new parameters are well-designed:

  • use_exp_switch: Boolean flag with False default maintains backward compatibility
  • use_dynamic_sel: Boolean flag for new dynamic selection feature
  • sel_reduce_factor: Float with 10.0 default provides reasonable scaling

The parameters are properly stored as instance variables for use throughout the class.

Also applies to: 224-226


279-280: Correct parameter propagation.

The parameters are properly propagated:

  • use_dynamic_sel and sel_reduce_factor passed to RepFlowLayer initialization
  • use_exp_switch passed to both prod_env_mat calls for edge and angle environment matrices

Also applies to: 409-409, 431-431


437-448: Verify the logic change for padding and node embedding preparation.

The code has been reorganized to:

  1. Set padding positions to 0 instead of -1
  2. Move node embedding preparation after angle neighbor list processing
  3. Add assertion for extended_atype_embd

This reorganization appears to support the dynamic selection logic, but ensure that setting padding indices to 0 doesn't cause indexing issues elsewhere in the codebase.

#!/bin/bash
# Description: Check if there are other parts of the codebase that expect nlist padding to be -1
# Expected: Find usages that might be affected by changing padding from -1 to 0

echo "Searching for code that checks nlist == -1:"
rg -A 3 -B 3 "nlist.*== -1"

echo "Searching for code that uses nlist with -1 padding:"
rg -A 3 -B 3 "nlist.*-1"

578-590: Verification script:

#!/bin/bash
# Show the full signature and docstring of _cal_hg_dynamic in RepFlowLayer
rg -n -A10 "def _cal_hg_dynamic" deepmd/pt/model/descriptor/repflow_layer.py

Comment on lines 464 to 491
        if self.use_dynamic_sel:
            # get graph index
            edge_index, angle_index = get_graph_index(
                nlist, nlist_mask, a_nlist_mask, nall
            )
            # flat all the tensors
            # n_edge x 1
            edge_input = edge_input[nlist_mask]
            # n_edge x 3
            h2 = h2[nlist_mask]
            # n_edge x 1
            sw = sw[nlist_mask]
            # nb x nloc x a_nnei x a_nnei
            a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
            # n_angle x 1
            angle_input = angle_input[a_nlist_mask]
            # n_angle x 1
            a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
        else:
            # avoid jit assertion
            edge_index = angle_index = torch.zeros(
                [1, 3], device=nlist.device, dtype=nlist.dtype
            )
        # get edge and angle embedding
        # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
        edge_ebd = self.act(self.edge_embd(edge_input))
        # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
        angle_ebd = self.angle_embd(angle_input)

coderabbitai bot (Contributor) commented:

🛠️ Refactor suggestion

Complex dynamic selection logic needs verification.

The dynamic selection implementation introduces significant conditional branching:

  1. Graph index computation: Uses get_graph_index to compute edge_index and angle_index
  2. Tensor flattening: Flattens tensors when use_dynamic_sel=True
  3. Masking logic: Complex angle mask computation on line 477
  4. Fallback tensors: Creates dummy zero tensors when dynamic selection is disabled

The logic appears correct but is complex. Verify that:

  • The angle masking logic a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] produces the expected tensor shape
  • The flattened tensors maintain correct correspondence with graph indices

Consider adding comments to explain the complex tensor manipulations, especially:

         if self.use_dynamic_sel:
             # get graph index
             edge_index, angle_index = get_graph_index(
                 nlist, nlist_mask, a_nlist_mask, nall
             )
-            # flat all the tensors
+            # Flatten tensors for dynamic selection - only keep valid entries
             # n_edge x 1
             edge_input = edge_input[nlist_mask]
             # n_edge x 3
             h2 = h2[nlist_mask]
             # n_edge x 1
             sw = sw[nlist_mask]
-            # nb x nloc x a_nnei x a_nnei
+            # Create angle pair mask: valid if both neighbors are valid
             a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
             # n_angle x 1
             angle_input = angle_input[a_nlist_mask]
-            # n_angle x 1
+            # Angle switch weights: product of individual neighbor weights
             a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]

iProzd added 2 commits July 18, 2025 23:52
  • Replaces comm_dict.insert with comm_dict.insert_or_assign for the 'send_list' key in both DeepPotPT.cc and DeepSpinPT.cc. This ensures that the value is updated if the key already exists, preventing potential issues with duplicate key insertion.
  • Replaces comm_dict.insert with comm_dict.insert_or_assign for all tensor assignments in DeepPotPT.cc and DeepSpinPT.cc. This ensures that existing keys are updated rather than causing errors or duplications, improving robustness when keys may already exist.
Development

Successfully merging this pull request may close these issues.

[Feature Request] pt: support virial for the spin model