feat(pt): support spin virial #4545
base: devel
## Walkthrough
This pull request introduces enhanced support for virial calculations in the DeepMD-kit framework, specifically focusing on spin-related models. The changes span multiple files across the project, adding new functionality to handle virial outputs, coordinate corrections, and spin-related computations. The modifications enable more comprehensive energy and virial loss calculations, with updates to model processing, loss computation, and testing frameworks. Additionally, a dynamic neighbor selection mechanism is introduced in the repflow descriptor layers, controlled by new parameters and supported by new utility functions and tests.
## Changes
| File | Change Summary |
|------|----------------|
| `deepmd/pt/loss/ener_spin.py` | Added virial loss calculation support in `EnergySpinLoss` class, updated `label_requirement` for "virial". |
| `deepmd/pt/model/model/make_model.py` | Introduced `coord_corr_for_virial` parameter in `forward_common` and `forward_common_lower` methods. |
| `deepmd/pt/model/model/spin_model.py` | Enhanced spin input processing and model forward methods to include spin-related corrections and virial output handling. |
| `deepmd/pt/model/model/transform_output.py` | Added `extended_coord_corr` parameter to `fit_output_to_model_output` function for derivative correction. |
| `deepmd/pt/model/descriptor/repflows.py` | Added dynamic neighbor selection support with `use_dynamic_sel` and `sel_reduce_factor` parameters; refactored forward method to handle dynamic graph indices. |
| `deepmd/dpmodel/descriptor/dpa3.py` | Added `edge_init_use_dist`, `use_exp_switch`, `use_dynamic_sel`, and `sel_reduce_factor` parameters to `RepFlowArgs` constructor. |
| `deepmd/pt/model/descriptor/dpa3.py` | Passed dynamic selection parameters from `repflow_args` to `DescrptBlockRepflows` in `DescrptDPA3` constructor. |
| `deepmd/pt/model/descriptor/repflow_layer.py` | Added dynamic selection logic and new methods for dynamic aggregation and updates; extended constructor and forward method to support dynamic neighbor and angle indexing. |
| `deepmd/pt/model/network/utils.py` | Added `aggregate` and `get_graph_index` utility functions for dynamic graph indexing and aggregation. |
| `deepmd/utils/argcheck.py` | Added optional arguments `edge_init_use_dist`, `use_exp_switch`, `use_dynamic_sel`, and `sel_reduce_factor` to `dpa3_repflow_args` function. |
| `deepmd/pt/model/descriptor/env_mat.py` | Added `use_exp_switch` parameter to environment matrix functions to select between smoothing weight functions. |
| `deepmd/pt/utils/preprocess.py` | Added `compute_exp_sw` function implementing an exponential switch function alternative. |
| `source/api_c/include/deepmd.hpp` | Fixed virial array population in `DeepSpinModelDevi` class's `compute` methods. |
| `source/api_c/src/c_api.cc` | Restored virial output handling in `DP_DeepSpinModelDeviCompute_variant` function. |
| `source/api_cc/src/DeepSpinPT.cc` | Implemented virial tensor retrieval and assignment in `DeepSpinPT` class's `compute` methods; changed comm_dict insertion to `insert_or_assign`. |
| `source/tests/pt/model/test_autodiff.py` | Added spin-related test cases and logic in `VirialTest` and `ForceTest` classes; introduced `TestEnergyModelSpinSeAVirial` class. |
| `source/tests/pt/model/test_ener_spin_model.py` | Updated spin input processing tests in `SpinTest` class to handle additional return values. |
| `source/tests/universal/common/cases/model/utils.py` | Modified virial calculation logic to include new conditions for spin virial testing. |
| `source/tests/universal/pt/model/test_model.py` | Added `test_spin_virial` property to `TestSpinEnergyModelDP` class to enable spin virial testing. |
| `source/tests/pt/model/test_nosel.py` | Added tests for `DescrptDPA3` descriptor with dynamic neighbor selection, validating output consistency. |
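Two of the rows above add smoothing options (`use_exp_switch` in the environment-matrix functions and `compute_exp_sw` in `deepmd/pt/utils/preprocess.py`). The PR's actual functional form is not shown in this summary; purely as an illustration of what an exponential switch between `rcut_smth` and `rcut` can look like (the shape and the `alpha` sharpness parameter are assumptions, not the real `compute_exp_sw`):

```python
import numpy as np

def exp_switch(r, rcut_smth, rcut, alpha=5.0):
    """Illustrative exponential switch: 1 inside rcut_smth, decaying to 0 at rcut."""
    r = np.asarray(r, dtype=float)
    # normalized position inside the switching window, clipped to [0, 1]
    t = np.clip((r - rcut_smth) / (rcut - rcut_smth), 0.0, 1.0)
    # exponential decay, renormalized so w(rcut_smth) = 1 and w(rcut) = 0
    w = (np.exp(-alpha * t) - np.exp(-alpha)) / (1.0 - np.exp(-alpha))
    return np.where(r >= rcut, 0.0, np.where(r <= rcut_smth, 1.0, w))

r = np.array([0.5, 2.0, 4.0, 5.5, 6.0, 7.0])
print(exp_switch(r, rcut_smth=2.0, rcut=6.0))  # decreases monotonically from 1 to 0
```

Whatever the exact form, the key properties a descriptor switch needs are the ones asserted here: exactly 1 below `rcut_smth`, exactly 0 at and beyond `rcut`, and a monotone decay in between.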
## Suggested Reviewers
- njzjz
- wanghan-iapcm
Actionable comments posted: 4
🧹 Nitpick comments (12)
deepmd/pt/model/model/spin_model.py (3)

58-59: Enhance code readability by reshaping the tensor in a single step

In line 58, consider reshaping `self.virtual_scale_mask.to(atype.device)[atype]` directly without wrapping it in parentheses for better readability.

379-381: Avoid unnecessary computation by not calling `process_spin_input` during stat computation

In the `compute_or_load_stat` method, calling `process_spin_input` may introduce unnecessary computational overhead if `coord_corr` is not used. Consider modifying the code to exclude `coord_corr` when it is not needed.

591-594: Consider consistency in handling `do_grad_c` checks

In the `forward` method, ensure that the handling of `do_grad_c("energy")` and subsequent assignments align with the changes made in `translated_output_def`. This maintains consistency across the methods.

source/tests/pt/model/test_autodiff.py (4)
144-144: Initialize the `spin` variable only when necessary

The `spin` variable is initialized even when `test_spin` is `False`. Consider moving the initialization inside the conditional block to optimize performance.

Apply this diff to adjust the initialization:

```diff
- spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
```

Move the initialization to after line 150, within the `if test_spin` block.

148-148: Ensure `spin` is only converted to NumPy when necessary

Similar to the previous comment, the conversion of `spin` to a NumPy array should be conditional on `test_spin` to avoid unnecessary computations.

151-154: Simplify the conditional assignment of `test_keys`

The assignment of `test_keys` can be streamlined for clarity.

Apply this diff to simplify the code:

```diff
- if not test_spin:
-     test_keys = ["energy", "force", "virial"]
- else:
-     test_keys = ["energy", "force", "force_mag", "virial"]
+ test_keys = ["energy", "force", "virial"]
+ if test_spin:
+     test_keys.append("force_mag")
```

263-268: Add a newline for code style consistency

Include a blank line after the class definition to follow PEP 8 style guidelines for better readability.

Apply this diff:

```diff
 class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
+
     def setUp(self) -> None:
         model_params = copy.deepcopy(model_spin)
```
deepmd/pt/model/model/transform_output.py (3)

159-159: Update the function documentation to include the new parameter

The `fit_output_to_model_output` function has a new parameter `extended_coord_corr`. Update the docstring to describe this parameter and its role in the computation.

195-195: Avoid using `# noqa` comments

Instead of using `# noqa: RUF005` to suppress the lint warning, refactor the code to comply with style guidelines for better maintainability.

Apply this diff to split the line:

```diff
- ).view(list(dc.shape[:-2]) + [1, 9])  # noqa: RUF005
+ )
+ dc = dc.view(list(dc.shape[:-2]) + [1, 9])
```

Line range hint 226-226: Consider adding type annotations for function returns

Adding type annotations to functions enhances code clarity and aids in static analysis. Consider specifying the return types for the functions in this module.

deepmd/pt/loss/ener_spin.py (1)
deepmd/pt/loss/ener_spin.py (1)

271-286: LGTM! The virial loss calculation is well implemented.

The implementation follows the established pattern for loss calculations, with proper scaling and optional MAE computation. The code is clean and well-structured.

Consider extracting the common pattern of loss calculation (L2, MAE, scaling) into a helper method to reduce code duplication across energy, force, and virial loss calculations.
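Such a helper could look roughly like the following sketch. All names here are hypothetical, invented for illustration; this is not the actual `EnergySpinLoss` API, only the shape of the repeated L2/MAE/prefactor pattern, shown with NumPy instead of torch tensors:

```python
import numpy as np

def add_component_loss(total, metrics, name, pred, label, pref, use_mae=False):
    """Hypothetical helper factoring out the repeated per-component pattern:
    L2 loss, optional MAE metric, and scaling by a prefactor."""
    diff = pred - label
    l2 = float(np.mean(np.square(diff)))
    total = total + pref * l2                   # scale by the component prefactor
    metrics[f"l2_{name}_loss"] = l2             # record the unscaled L2 metric
    if use_mae:
        metrics[f"mae_{name}"] = float(np.mean(np.abs(diff)))
    return total, metrics

total, metrics = 0.0, {}
total, metrics = add_component_loss(
    total, metrics, "virial",
    pred=np.array([1.0, 2.0]), label=np.array([1.0, 1.0]),
    pref=0.5, use_mae=True,
)
print(total, metrics)  # 0.25 {'l2_virial_loss': 0.5, 'mae_virial': 0.5}
```

The energy, force, `force_mag`, and virial branches would then each reduce to one call with their own prefactor.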
source/tests/pt/model/test_ener_spin_model.py (1)

118-118: Document the purpose of the ignored return values.

The additional return values (marked with `_`) from `process_spin_input` and `process_spin_input_lower` are silently ignored. Consider adding a comment explaining what these values represent and why they can be safely ignored in these tests.

Also applies to: 177-177
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (11)

- deepmd/pt/loss/ener_spin.py (1 hunks)
- deepmd/pt/model/model/make_model.py (7 hunks)
- deepmd/pt/model/model/spin_model.py (11 hunks)
- deepmd/pt/model/model/transform_output.py (2 hunks)
- source/api_c/include/deepmd.hpp (2 hunks)
- source/api_c/src/c_api.cc (1 hunks)
- source/api_cc/src/DeepSpinPT.cc (4 hunks)
- source/tests/pt/model/test_autodiff.py (3 hunks)
- source/tests/pt/model/test_ener_spin_model.py (2 hunks)
- source/tests/universal/common/cases/model/utils.py (3 hunks)
- source/tests/universal/pt/model/test_model.py (1 hunks)
🔇 Additional comments (13)
deepmd/pt/model/model/spin_model.py (4)

63-64: Ensure proper alignment of coordinate corrections

The concatenation of tensors in `coord_corr` must maintain correct alignment with the corresponding atoms. Verify that `torch.zeros_like(coord)` and `-spin_dist` are correctly ordered, ensuring that the coordinate corrections apply to the appropriate atoms.

92-95: Validate the consistency of virtual atom handling

When creating `extended_coord_corr`, confirm that the virtual atoms are correctly accounted for, and that the concatenation preserves the intended structure. This is crucial for accurate virial calculations involving spin corrections.

410-412: Handle the new output `coord_corr_for_virial` appropriately

Ensure that all downstream methods that receive `coord_corr_for_virial` can handle this new parameter without errors. Verify that `self.backbone_model.forward_common` accepts `coord_corr_for_virial` as an argument.

631-636: Verify accurate squeezing of tensors and assignment

In the `forward_lower` method, confirm that the `squeeze` operations correctly reduce tensor dimensions and that the results are assigned to the appropriate keys in `model_predict`.

source/tests/pt/model/test_autodiff.py (1)

150-150: Conditionally handle `spin` and `test_spin` variables

Verify that all usages of `spin` and `test_spin` within the `VirialTest` class are properly guarded by conditionals to prevent errors when `test_spin` is `False`.

deepmd/pt/model/model/transform_output.py (1)

191-196: Ensure tensor shapes are compatible during matrix multiplication

In the computation of `dc_corr`, validate that the shapes of the tensors involved in the matrix multiplication are compatible to prevent runtime errors.

deepmd/pt/model/model/make_model.py (2)

138-138: LGTM! The parameter addition is well documented.

The new `coord_corr_for_virial` parameter is properly documented with clear shape information.

Also applies to: 157-159

187-194: LGTM! The coordinate correction handling is robust.

The implementation properly handles both cases:

- When the correction is provided: converts it to the correct dtype and gathers extended coordinates
- When the correction is not provided: sets `extended_coord_corr` to `None`

source/api_cc/src/DeepSpinPT.cc (1)

254-254: LGTM! The virial output handling is properly implemented.

The implementation correctly:

- Retrieves the virial tensor from the outputs
- Converts the tensor to the appropriate float type
- Transfers the data to CPU
- Assigns it to the output vector

Also applies to: 269-273, 417-417, 432-435

source/tests/universal/pt/model/test_model.py (1)

716-717: Note the temporary nature of this flag.

The added `test_spin_virial` flag enables spin virial testing for this backend. As indicated by the comment, it can be removed once other backends support spin virial calculations.

Let's verify which backends currently lack spin virial support:

✅ Verification successful

Temporary flag is correctly annotated. The comment accurately reflects the current state of the codebase: the spin virial implementation is currently only available in the PyTorch (PT) backend, and the flag is appropriately marked as temporary until other backends (TensorFlow, PaddlePaddle) implement this functionality.

🏁 Scripts executed

```shell
#!/bin/bash
# Search for spin virial implementations across backends
rg -A 5 "spin.*virial" source/api_cc/src/
```

```shell
#!/bin/bash
# Broader search for spin virial across the codebase
echo "=== Searching for spin virial across the codebase ==="
rg -i "spin.*virial" --type py
echo -e "\n=== Searching for related test files ==="
fd "test.*model.*py$"
echo -e "\n=== Looking for backend-specific implementations ==="
rg -i "spin.*virial" --type py -A 5 source/tests/
```

source/tests/universal/common/cases/model/utils.py (1)

895-898: Improved control flow for spin virial testing.

The condition ensures virial computation only happens when appropriate: either for non-spin models or when spin virial testing is explicitly enabled.

source/api_c/src/c_api.cc (1)

865-869: Re-enabled virial data propagation.

The previously commented-out code for copying virial data has been restored, allowing proper propagation of virial values in spin model deviation computations.

source/api_c/include/deepmd.hpp (1)

2605-2607: Re-enabled virial data propagation in the C++ interface.

The previously commented-out code for copying virial data has been restored, maintaining consistency with the C API changes.
```diff
@@ -408,6 +419,7 @@ def forward_common(
             fparam=fparam,
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
+            coord_corr_for_virial=coord_corr_for_virial,
```

Confirm the compatibility of the new argument `coord_corr_for_virial`

Check whether the backbone model's `forward_common` method is designed to accept `coord_corr_for_virial`. If not, update the backbone model accordingly or modify the call to prevent runtime errors.

```diff
@@ -469,6 +482,7 @@
             do_atomic_virial=do_atomic_virial,
             comm_dict=comm_dict,
             extra_nlist_sort=extra_nlist_sort,
+            extended_coord_corr=extended_coord_corr_for_virial,
```

Ensure `extended_coord_corr` is accepted by `forward_common_lower`

Similar to the previous comment, verify that `self.backbone_model.forward_common_lower` accepts `extended_coord_corr` as an argument. This prevents potential issues during model execution.
```python
        if self.do_grad_c("energy"):
            output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
            output_def["virial"].squeeze(-2)
            output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
            output_def["atom_virial"].squeeze(-3)
```

Assign the result of the `squeeze` operations to reduce tensor dimensions

The `squeeze` calls in lines 560 and 562 do not modify tensors in place. Assign the results to ensure the dimensions are correctly reduced.

Apply this diff to fix the issue:

```diff
- output_def["virial"].squeeze(-2)
+ output_def["virial"] = output_def["virial"].squeeze(-2)
- output_def["atom_virial"].squeeze(-3)
+ output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
```
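The underlying pitfall: `squeeze` returns a new object rather than modifying its operand in place, so a bare call on its own line is a no-op. The same behavior can be demonstrated with NumPy arrays (used here only so the demo is self-contained; torch tensors behave the same way for this method):

```python
import numpy as np

a = np.zeros((2, 1, 3))
a.squeeze(1)              # returns a new array; `a` itself is unchanged
print(a.shape)            # (2, 1, 3) -- the bare call did nothing to `a`

a = a.squeeze(1)          # assigning the result is what actually drops the axis
print(a.shape)            # (2, 3)
```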
```python
        spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
    )
    # detach
    ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
```

Ensure compatibility of tensor devices

When creating tensors within the `np_infer` function, ensure that all tensors are on the same device to prevent device mismatch errors, especially when `env.DEVICE` differs from `"cpu"`.

Apply this diff to correct the device assignment:

```diff
- spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+ spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),
```

Committable suggestion skipped: line range outside the PR's diff.
Actionable comments posted: 0
🧹 Nitpick comments (1)
source/api_cc/src/DeepSpinPT.cc (1)

269-274: Consider future support for per-atom virial.

While global virial support is now implemented, per-atom virial remains commented out (lines 295-300). Consider whether this is a planned future enhancement for spin models.
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (3)

- source/api_cc/src/DeepSpinPT.cc (4 hunks)
- source/tests/universal/common/cases/model/utils.py (3 hunks)
- source/tests/universal/pt/model/test_model.py (1 hunks)

🚧 Files skipped from review as they are similar to previous changes (2)

- source/tests/universal/pt/model/test_model.py
- source/tests/universal/common/cases/model/utils.py
🔇 Additional comments (4)

source/api_cc/src/DeepSpinPT.cc (4)

254-254: Enabled virial tensor retrieval.

The code now retrieves the virial tensor from the model's outputs, enabling support for virial calculations in the `DeepSpinPT` class.

270-273: Successfully implemented virial tensor processing.

Added proper tensor conversion, CPU transfer, and data assignment for the virial tensor, following the same pattern used for energy and forces. This enables complete virial support for spin models.

417-417: Enabled virial tensor retrieval in the overloaded method.

Consistent implementation of virial retrieval in the overloaded `compute` method, maintaining parity with the primary implementation.

432-435: Successfully implemented virial tensor processing in the overloaded method.

The implementation correctly follows the same pattern used in the primary `compute` method, ensuring consistent virial handling across both implementations.
Actionable comments posted: 4
🔭 Outside diff range comments (2)
deepmd/dpmodel/descriptor/dpa3.py (1)

191-215: Add the new parameters to the `serialize` method.

The newly added parameters `use_dynamic_sel` and `sel_reduce_factor` are not included in the `serialize` method, which will cause issues when saving and loading models.

Add these parameters to the serialize method:

```diff
 def serialize(self) -> dict:
     return {
         "n_dim": self.n_dim,
         "e_dim": self.e_dim,
         "a_dim": self.a_dim,
         "nlayers": self.nlayers,
         "e_rcut": self.e_rcut,
         "e_rcut_smth": self.e_rcut_smth,
         "e_sel": self.e_sel,
         "a_rcut": self.a_rcut,
         "a_rcut_smth": self.a_rcut_smth,
         "a_sel": self.a_sel,
         "a_compress_rate": self.a_compress_rate,
         "a_compress_e_rate": self.a_compress_e_rate,
         "a_compress_use_split": self.a_compress_use_split,
         "n_multi_edge_message": self.n_multi_edge_message,
         "axis_neuron": self.axis_neuron,
         "update_angle": self.update_angle,
         "update_style": self.update_style,
         "update_residual": self.update_residual,
         "update_residual_init": self.update_residual_init,
         "fix_stat_std": self.fix_stat_std,
         "optim_update": self.optim_update,
         "smooth_edge_update": self.smooth_edge_update,
+        "use_dynamic_sel": self.use_dynamic_sel,
+        "sel_reduce_factor": self.sel_reduce_factor,
     }
```

deepmd/pt/model/descriptor/repflow_layer.py (1)
1179-1243: Add the new dynamic selection parameters to serialization.

The `use_dynamic_sel` and `sel_reduce_factor` parameters are not included in the serialization, which will cause issues when saving and loading models that use dynamic selection.

Update the serialize method to include these parameters:

```diff
 def serialize(self) -> dict:
     """Serialize the networks to a dict.

     Returns
     -------
     dict
         The serialized networks.
     """
     data = {
         "@class": "RepFlowLayer",
         "@version": 1,
         "e_rcut": self.e_rcut,
         "e_rcut_smth": self.e_rcut_smth,
         "e_sel": self.e_sel,
         "a_rcut": self.a_rcut,
         "a_rcut_smth": self.a_rcut_smth,
         "a_sel": self.a_sel,
         "ntypes": self.ntypes,
         "n_dim": self.n_dim,
         "e_dim": self.e_dim,
         "a_dim": self.a_dim,
         "a_compress_rate": self.a_compress_rate,
         "a_compress_e_rate": self.a_compress_e_rate,
         "a_compress_use_split": self.a_compress_use_split,
         "n_multi_edge_message": self.n_multi_edge_message,
         "axis_neuron": self.axis_neuron,
         "activation_function": self.activation_function,
         "update_angle": self.update_angle,
         "update_style": self.update_style,
         "update_residual": self.update_residual,
         "update_residual_init": self.update_residual_init,
         "precision": self.precision,
         "optim_update": self.optim_update,
         "smooth_edge_update": self.smooth_edge_update,
+        "use_dynamic_sel": self.use_dynamic_sel,
+        "sel_reduce_factor": self.sel_reduce_factor,
         "node_self_mlp": self.node_self_mlp.serialize(),
```
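Why missed `serialize` keys matter: a save/load round-trip silently resets any unserialized setting to its constructor default. A minimal sketch with a toy class (not the actual `RepFlowArgs` or `RepFlowLayer`):

```python
class ToyArgs:
    """Toy stand-in for an argument container with serialize/deserialize."""

    def __init__(self, n_dim=128, use_dynamic_sel=False):
        self.n_dim = n_dim
        self.use_dynamic_sel = use_dynamic_sel

    def serialize(self):
        # The bug being illustrated: "use_dynamic_sel" is missing from the
        # dict, so a round-trip silently resets it to its default.
        return {"n_dim": self.n_dim}

    @classmethod
    def deserialize(cls, data):
        return cls(**data)

orig = ToyArgs(use_dynamic_sel=True)
restored = ToyArgs.deserialize(orig.serialize())
print(restored.use_dynamic_sel)  # False -- the setting was silently lost
```

A serialization unit test that round-trips a non-default configuration would catch this class of bug automatically.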
🧹 Nitpick comments (7)
deepmd/pt/model/network/utils.py (2)

43-45: Consider using broadcasting for better readability.

The transpose operations for the division could be simplified using broadcasting.

Apply this diff to simplify the division:

```diff
-    if average:
-        output = (output.T / bin_count).T
-    return output
+    if average:
+        output = output / bin_count.unsqueeze(1)
+    return output
```

86-91: Clean up commented code for clarity.

There are commented-out lines that should be removed if not needed, or uncommented if they serve a purpose.

```diff
-    # nf x nloc x nnei x nnei
-    # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :]
     a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
     n_edge = nlist_mask.sum().item()
-    # n_angle = a_nlist_mask_3d.sum().item()
```

deepmd/pt/model/descriptor/repflows.py (1)

479-482: Consider adding a comment explaining the JIT workaround.

The dummy tensor creation for JIT compatibility could benefit from more explanation.

```diff
 else:
-    # avoid jit assertion
+    # Create dummy tensors to avoid JIT assertion errors when dynamic selection is disabled.
+    # These tensors are not used in the computation but are required for type consistency.
     edge_index = angle_index = torch.zeros(
         [1, 3], device=nlist.device, dtype=nlist.dtype
     )
```

source/tests/pt/model/test_nosel.py (2)

90-91: Consider adding tests with different `sel_reduce_factor` values.

The test currently only uses `sel_reduce_factor=1.0`. Consider adding test cases with other values to ensure the dynamic selection works correctly with different reduction factors.

143-206: Remove or enable the commented `test_jit` method.

The JIT test is completely commented out. Either remove it if it is not needed, or fix and enable it if JIT compatibility is important.

Would you like me to help implement a working JIT test, or should this commented code be removed?

deepmd/pt/model/descriptor/repflow_layer.py (2)

806-806: Consider numerical stability when using float division results as scaling factors.

The `dynamic_e_sel` and `dynamic_a_sel` are computed as float divisions and used in scaling operations. For better numerical stability and clarity, consider storing the integer values separately.

```diff
 self.use_dynamic_sel = use_dynamic_sel
 self.sel_reduce_factor = sel_reduce_factor
-self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
-self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
+self.dynamic_e_sel = int(self.nnei / self.sel_reduce_factor)
+self.dynamic_a_sel = int(self.a_sel / self.sel_reduce_factor)
+self.dynamic_e_scale = self.dynamic_e_sel ** (-0.5)
+self.dynamic_a_scale = self.dynamic_a_sel ** (-0.5)
```

Then use the pre-computed scale factors instead of computing them repeatedly:

```diff
-scale_factor=self.dynamic_e_sel ** (-0.5),
+scale_factor=self.dynamic_e_scale,
```

Also applies to: 825-825, 894-896, 1064-1064

731-743: Clarify when `edge_index` and `angle_index` parameters are required.

The documentation should explicitly state that these parameters are required when `use_dynamic_sel` is `True`.

Update the parameter documentation:

```diff
-    edge_index : Optional for dynamic sel, n_edge x 2
+    edge_index : n_edge x 2 (required when use_dynamic_sel is True)
         n2e_index : n_edge
             Broadcast indices from node(i) to edge(ij), or reduction
             indices from edge(ij) to node(i).
         n_ext2e_index : n_edge
             Broadcast indices from extended node(j) to edge(ij).
-    angle_index : Optional for dynamic sel, n_angle x 3
+    angle_index : n_angle x 3 (required when use_dynamic_sel is True)
```
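Several comments above concern the new `aggregate` utility that reduces per-edge rows onto their owner nodes. A toy NumPy analogue of that scatter-sum/mean, including the broadcast division suggested in place of the double transpose (the real PyTorch signature in `deepmd/pt/model/network/utils.py` may differ; shapes are assumed from the diffs above):

```python
import numpy as np

def aggregate(data, owners, num_owner, average=True):
    """Toy scatter-sum/mean: reduce n_edge x dim rows onto num_owner nodes."""
    output = np.zeros((num_owner, data.shape[1]))
    np.add.at(output, owners, data)           # scatter-add each edge row onto its owner node
    if average:
        bin_count = np.bincount(owners, minlength=num_owner)
        bin_count = np.maximum(bin_count, 1)  # guard nodes that own no edges
        output = output / bin_count[:, None]  # broadcast divide, no transposes needed
    return output

edges = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
owners = np.array([0, 0, 2])                  # edges 0,1 belong to node 0; edge 2 to node 2
print(aggregate(edges, owners, num_owner=3))
# node 0 -> mean of the first two rows, node 1 -> zeros, node 2 -> the last row
```

`np.add.at` is the unbuffered scatter-add, so repeated owner indices accumulate correctly, which is the property the torch implementation needs as well.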
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (7)

- deepmd/dpmodel/descriptor/dpa3.py (2 hunks)
- deepmd/pt/model/descriptor/dpa3.py (1 hunks)
- deepmd/pt/model/descriptor/repflow_layer.py (17 hunks)
- deepmd/pt/model/descriptor/repflows.py (7 hunks)
- deepmd/pt/model/network/utils.py (1 hunks)
- deepmd/utils/argcheck.py (1 hunks)
- source/tests/pt/model/test_nosel.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
deepmd/pt/model/descriptor/repflow_layer.py

754-754: Local variable `nall` is assigned to but never used. Remove the assignment to the unused variable `nall` (F841).
🔇 Additional comments (5)
deepmd/pt/model/descriptor/dpa3.py (1)

153-154: LGTM!

The new parameters are correctly passed from `repflow_args` to the `DescrptBlockRepflows` constructor.

deepmd/pt/model/descriptor/repflows.py (1)

574-586: Verify the scale factor calculation.

The scale factor `(self.nnei / self.sel_reduce_factor) ** (-0.5)` is used in dynamic mode. Please ensure this scaling is appropriate for maintaining numerical stability across different neighbor counts. Could you provide documentation or references for why this specific scale factor formula was chosen?

deepmd/pt/model/descriptor/repflow_layer.py (3)

330-381: Well-implemented dynamic calculation method.

The `_cal_hg_dynamic` method is properly documented and implements the dynamic version of the transposed rotation matrix calculation correctly. The use of the aggregate function with proper reshaping maintains mathematical equivalence with the original method.

454-506: Clean implementation of dynamic symmetrization.

The method properly implements the dynamic version of the symmetrization operation by reusing existing components (`_cal_hg_dynamic` and `_cal_grrg`). Good code reuse and clear documentation.

549-607: Well-implemented dynamic angle update with proper validation.

The `optim_angle_update_dynamic` method correctly implements the dynamic version of angle updates with:

- Clear index range calculations for sub-matrices
- Proper dimension validation via assertion
- Efficient use of `torch.index_select` for gathering
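On the scale-factor question raised above: one standard justification for an inverse-square-root factor is that a sum of roughly N independent, unit-scale contributions grows like sqrt(N) in standard deviation, so multiplying by N^(-1/2) keeps the aggregated output at unit scale as the neighbor budget changes. A quick NumPy check of that property (an illustration of the general principle, not the PR's own derivation):

```python
import numpy as np

rng = np.random.default_rng(0)
for n in (16, 64, 256):
    # sum n unit-variance contributions per sample, then rescale by n**-0.5
    agg = rng.standard_normal((100_000, n)).sum(axis=1) * n ** -0.5
    print(n, agg.std())  # the standard deviation stays near 1.0 for every n
```

Without the rescaling, the spread would grow by a factor of four going from 16 to 256 neighbors, which is the kind of drift the `sel_reduce_factor`-dependent scale presumably guards against.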
```python
                Argument(
                    "use_dynamic_sel",
                    bool,
                    optional=True,
                    default=False,
                ),
                Argument(
                    "sel_reduce_factor",
                    float,
                    optional=True,
                    default=10.0,
                ),
```
Add documentation for the new dynamic selection arguments.

Both new arguments `use_dynamic_sel` and `sel_reduce_factor` are missing documentation strings. All other arguments in this function include comprehensive `doc` parameters to explain their purpose and usage. This is particularly important since this module generates user-facing documentation.

Please add documentation for both arguments:

```diff
 Argument(
     "use_dynamic_sel",
     bool,
     optional=True,
     default=False,
+    doc="Whether to enable dynamic neighbor selection mechanism. When enabled, the neighbor selection will be dynamically adjusted during training.",
 ),
 Argument(
     "sel_reduce_factor",
     float,
     optional=True,
     default=10.0,
+    doc="The reduction factor for dynamic neighbor selection. Controls how aggressively the neighbor count is reduced when dynamic selection is enabled.",
 ),
```

Note: The exact documentation content may need adjustment based on the specific implementation details of these features.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Argument(
    "use_dynamic_sel",
    bool,
    optional=True,
    default=False,
    doc="Whether to enable dynamic neighbor selection mechanism. When enabled, the neighbor selection will be dynamically adjusted during training.",
),
Argument(
    "sel_reduce_factor",
    float,
    optional=True,
    default=10.0,
    doc="The reduction factor for dynamic neighbor selection. Controls how aggressively the neighbor count is reduced when dynamic selection is enabled.",
),
🤖 Prompt for AI Agents
In deepmd/utils/argcheck.py around lines 1600 to 1611, the new arguments
use_dynamic_sel and sel_reduce_factor lack docstrings explaining their purpose
and usage. Add appropriate doc parameters to both Argument instances, providing
clear and concise descriptions that align with their functionality, similar to
other documented arguments in this function, to ensure they are properly
documented for user-facing documentation generation.
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
Add validation for sel_reduce_factor to prevent division by zero.
The sel_reduce_factor parameter is used as a divisor without validation. If a user passes 0 or a negative value, it will cause runtime errors.
Apply this diff to add parameter validation:
def __init__(
self,
# ... other parameters ...
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
# ... other parameters ...
) -> None:
super().__init__()
+ if sel_reduce_factor <= 0:
+ raise ValueError(f"sel_reduce_factor must be positive, got {sel_reduce_factor}")
self.epsilon = 1e-4 # protection of 1./nnei
Also applies to: 106-109
🤖 Prompt for AI Agents
In deepmd/pt/model/descriptor/repflow_layer.py around lines 58-59 and 106-109,
the sel_reduce_factor parameter is used as a divisor but lacks validation,
risking division by zero or errors with negative values. Add validation checks
to ensure sel_reduce_factor is greater than zero before it is used, raising a
ValueError or similar if the condition is not met to prevent runtime errors.
edge_index: torch.Tensor,  # n_edge x 2
angle_index: torch.Tensor,  # n_angle x 3
❌ Incorrect review comment
Add validation for required dynamic selection parameters.
When use_dynamic_sel is True, the edge_index and angle_index parameters are required but this isn't validated.
Run the following script to check if there are any existing validations in the codebase:
Consider adding validation at the beginning of the forward method:
def forward(
self,
# ... parameters ...
edge_index: torch.Tensor, # n_edge x 2
angle_index: torch.Tensor, # n_angle x 3
):
+ if self.use_dynamic_sel:
+ if edge_index is None or angle_index is None:
+ raise ValueError("edge_index and angle_index must be provided when use_dynamic_sel is True")
Also applies to: 758-761
🏁 Script executed:
#!/bin/bash
# Description: Check for edge_index and angle_index validation in the codebase
# Search for validation of edge_index and angle_index when use_dynamic_sel is True
rg -B 5 -A 5 "edge_index.*angle_index.*use_dynamic_sel|use_dynamic_sel.*edge_index.*angle_index"
# Check how these parameters are passed in the descriptor files
fd -e py -x rg -l "RepFlowLayer" {} \; | xargs rg -B 3 -A 3 "edge_index.*angle_index"
Length of output: 1279
🤖 Prompt for AI Agents
In deepmd/pt/model/descriptor/repflow_layer.py around lines 705-706 and also
758-761, there is no validation to ensure that edge_index and angle_index are
provided when use_dynamic_sel is True. Add explicit checks at the start of the
forward method to raise an error if use_dynamic_sel is True but either
edge_index or angle_index is None or missing. This will enforce the requirement
and prevent runtime errors due to missing parameters.
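The guard proposed in this thread can be sketched with a hypothetical stand-in class (only the validation logic is shown; the real forward signature carries many more tensors):

```python
from typing import Optional


class RepFlowLayerSketch:
    """Hypothetical stand-in showing only the forward-time guard."""

    def __init__(self, use_dynamic_sel: bool) -> None:
        self.use_dynamic_sel = use_dynamic_sel

    def forward(self, edge_index=None, angle_index=None):
        # Fail fast: the dynamic path cannot run without flattened graph indices.
        if self.use_dynamic_sel and (edge_index is None or angle_index is None):
            raise ValueError(
                "edge_index and angle_index must be provided "
                "when use_dynamic_sel is True"
            )
        return edge_index, angle_index


# Static path: missing indices are tolerated.
RepFlowLayerSketch(use_dynamic_sel=False).forward()
# Dynamic path: missing indices raise a clear error.
try:
    RepFlowLayerSketch(use_dynamic_sel=True).forward()
except ValueError as err:
    print(err)
```

Note that the review verification above found this comment incorrect for the actual call sites, so this sketch only illustrates the proposed pattern, not a confirmed gap.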
if not self.smooth_edge_update:
    # will be deprecated in the future
    # not support dynamic index, will pass anyway
    if self.use_dynamic_sel:
        raise NotImplementedError(
            "smooth_edge_update must be True when use_dynamic_sel is True!"
        )
🛠️ Refactor suggestion
Move the compatibility check to initialization for earlier error detection.
The check for smooth_edge_update compatibility with use_dynamic_sel should be performed during initialization rather than at runtime. Move this validation to the __init__ method:
def __init__(
self,
# ... parameters ...
) -> None:
super().__init__()
+ if use_dynamic_sel and not smooth_edge_update:
+ raise ValueError("smooth_edge_update must be True when use_dynamic_sel is True")
if sel_reduce_factor <= 0:
raise ValueError(f"sel_reduce_factor must be positive, got {sel_reduce_factor}")
And remove the runtime check:
- if not self.smooth_edge_update:
- # will be deprecated in the future
- # not support dynamic index, will pass anyway
- if self.use_dynamic_sel:
- raise NotImplementedError(
- "smooth_edge_update must be True when use_dynamic_sel is True!"
- )
+ if not self.smooth_edge_update:
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
--- a/deepmd/pt/model/descriptor/repflow_layer.py
+++ b/deepmd/pt/model/descriptor/repflow_layer.py
@@ class RepFlowLayer:
     def __init__(
         self,
         # ... parameters ...
     ) -> None:
         super().__init__()
+        # fail fast on incompatible flags
+        if use_dynamic_sel and not smooth_edge_update:
+            raise ValueError("smooth_edge_update must be True when use_dynamic_sel is True")
         if sel_reduce_factor <= 0:
             raise ValueError(f"sel_reduce_factor must be positive, got {sel_reduce_factor}")
         # ... rest of init ...
@@
-        if not self.smooth_edge_update:
-            # will be deprecated in the future
-            # not support dynamic index, will pass anyway
-            if self.use_dynamic_sel:
-                raise NotImplementedError(
-                    "smooth_edge_update must be True when use_dynamic_sel is True!"
-                )
+        if not self.smooth_edge_update:
+            # (remaining logic under this branch)
🤖 Prompt for AI Agents
In deepmd/pt/model/descriptor/repflow_layer.py around lines 1066 to 1072, the
compatibility check between smooth_edge_update and use_dynamic_sel is currently
done at runtime, which delays error detection. Move this validation into the
__init__ method of the class to catch incompatible configurations earlier during
object creation, and remove the runtime check from this section of the code.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/utils/preprocess.py (1)
Lines 20-29: Enhance documentation and explain design choices. The function needs better documentation for the following:

- Hardcoded constant: the value C = 20 is hardcoded without explanation. Consider adding a comment explaining its significance or making it configurable.
- Clamping behavior: unlike compute_smooth_weight, which clamps to [rmin, rmax], this function clamps to [0.0, rmax]. This difference should be documented.
- Mathematical formula: the docstring should explain the double-exponential formula and its purpose.

 def compute_exp_sw(distance, rmin: float, rmax: float):
-    """Compute the exponential switch function for neighbor update."""
+    """Compute the exponential switch function for neighbor update.
+
+    Implements exp(-exp(a * (distance - b))) where a = C/rmin, b = rmin.
+    Note: Unlike compute_smooth_weight, distance is clamped to [0.0, rmax].
+
+    Parameters
+    ----------
+    distance : torch.Tensor
+        Input distances
+    rmin : float
+        Minimum radius for switch function
+    rmax : float
+        Maximum cutoff radius
+    """
     if rmin >= rmax:
         raise ValueError("rmin should be less than rmax.")
     distance = torch.clamp(distance, min=0.0, max=rmax)
-    C = 20
+    C = 20  # Exponential switch parameter - controls decay rate
     a = C / rmin
     b = rmin
     exp_sw = torch.exp(-torch.exp(a * (distance - b)))
     return exp_sw
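The switch behavior discussed in this nitpick can be checked numerically with a scalar, pure-Python sketch of the same formula (a stand-in for the torch version, keeping the hardcoded C = 20 from the source):

```python
import math


def compute_exp_sw_scalar(distance: float, rmin: float, rmax: float) -> float:
    """Scalar sketch of compute_exp_sw: exp(-exp(a*(d - b))), a = C/rmin, b = rmin."""
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    d = min(max(distance, 0.0), rmax)  # clamp to [0, rmax], not [rmin, rmax]
    C = 20  # hardcoded decay-rate constant flagged in the review
    a = C / rmin
    b = rmin
    return math.exp(-math.exp(a * (d - b)))


# Near zero separation the switch is ~1; past rmin it collapses toward 0.
print(compute_exp_sw_scalar(0.0, rmin=0.5, rmax=1.0))  # ~1.0
print(compute_exp_sw_scalar(1.0, rmin=0.5, rmax=1.0))  # ~0.0
```

This makes the double-exponential's behavior concrete: the function is a smooth, monotonically non-increasing gate that turns off sharply around rmin, which is why the constant C (the sharpness) deserves a comment.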
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
deepmd/dpmodel/descriptor/dpa3.py
(2 hunks)deepmd/pt/model/descriptor/dpa3.py
(1 hunks)deepmd/pt/model/descriptor/env_mat.py
(6 hunks)deepmd/pt/model/descriptor/repflows.py
(8 hunks)deepmd/pt/utils/preprocess.py
(1 hunks)deepmd/utils/argcheck.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/pt/model/descriptor/dpa3.py
- deepmd/utils/argcheck.py
- deepmd/dpmodel/descriptor/dpa3.py
🔇 Additional comments (6)
deepmd/pt/model/descriptor/env_mat.py (1)
Line 6: Well-implemented integration of the exponential switch function. The implementation correctly:

- Imports the new function
- Adds the parameter with an appropriate default (False) for backward compatibility
- Uses clear conditional logic to select the switch function
- Propagates the parameter through the call stack
- Updates documentation appropriately

Also applies to: 18, 38-42, 61, 74, 87
deepmd/pt/model/descriptor/repflows.py (5)
Lines 22-24: LGTM: Import addition. The import of get_graph_index from deepmd.pt.model.network.utils is correctly added to support the new dynamic selection functionality.
Lines 189-191: Good parameter design with sensible defaults. The three new parameters are well-designed:

- use_exp_switch: boolean flag whose False default maintains backward compatibility
- use_dynamic_sel: boolean flag for the new dynamic selection feature
- sel_reduce_factor: float with a 10.0 default that provides reasonable scaling

The parameters are properly stored as instance variables for use throughout the class.

Also applies to: 224-226
Lines 279-280: Correct parameter propagation. The parameters are properly propagated:

- use_dynamic_sel and sel_reduce_factor are passed to RepFlowLayer initialization
- use_exp_switch is passed to both prod_env_mat calls for edge and angle environment matrices

Also applies to: 409, 431
Lines 437-448: Verify the logic change for padding and node embedding preparation. The code has been reorganized to:

- Set padding positions to 0 instead of -1
- Move node embedding preparation after angle neighbor list processing
- Add an assertion for extended_atype_embd

This reorganization appears to support the dynamic selection logic, but ensure that setting padding indices to 0 doesn't cause indexing issues elsewhere in the codebase.

#!/bin/bash
# Description: Check if there are other parts of the codebase that expect nlist padding to be -1
# Expected: Find usages that might be affected by changing padding from -1 to 0
echo "Searching for code that checks nlist == -1:"
rg -A 3 -B 3 "nlist.*== -1"
echo "Searching for code that uses nlist with -1 padding:"
rg -A 3 -B 3 "nlist.*-1"
Lines 578-590:

#!/bin/bash
# Show the full signature and docstring of _cal_hg_dynamic in RepFlowLayer
rg -n -A10 "def _cal_hg_dynamic" deepmd/pt/model/descriptor/repflow_layer.py
if self.use_dynamic_sel:
    # get graph index
    edge_index, angle_index = get_graph_index(
        nlist, nlist_mask, a_nlist_mask, nall
    )
    # flat all the tensors
    # n_edge x 1
    edge_input = edge_input[nlist_mask]
    # n_edge x 3
    h2 = h2[nlist_mask]
    # n_edge x 1
    sw = sw[nlist_mask]
    # nb x nloc x a_nnei x a_nnei
    a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
    # n_angle x 1
    angle_input = angle_input[a_nlist_mask]
    # n_angle x 1
    a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
    # avoid jit assertion
    edge_index = angle_index = torch.zeros(
        [1, 3], device=nlist.device, dtype=nlist.dtype
    )
# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
edge_ebd = self.act(self.edge_embd(edge_input))
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
angle_ebd = self.angle_embd(angle_input)
🛠️ Refactor suggestion
Complex dynamic selection logic needs verification.
The dynamic selection implementation introduces significant conditional branching:

- Graph index computation: uses get_graph_index to compute edge_index and angle_index
- Tensor flattening: flattens tensors when use_dynamic_sel=True
- Masking logic: complex angle mask computation on line 477
- Fallback tensors: creates dummy zero tensors when dynamic selection is disabled

The logic appears correct but is complex. Verify that:

- The angle masking logic a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] produces the expected tensor shape
- The flattened tensors maintain correct correspondence with graph indices
Consider adding comments to explain the complex tensor manipulations, especially:
if self.use_dynamic_sel:
# get graph index
edge_index, angle_index = get_graph_index(
nlist, nlist_mask, a_nlist_mask, nall
)
- # flat all the tensors
+ # Flatten tensors for dynamic selection - only keep valid entries
# n_edge x 1
edge_input = edge_input[nlist_mask]
# n_edge x 3
h2 = h2[nlist_mask]
# n_edge x 1
sw = sw[nlist_mask]
- # nb x nloc x a_nnei x a_nnei
+ # Create angle pair mask: valid if both neighbors are valid
a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
# n_angle x 1
angle_input = angle_input[a_nlist_mask]
- # n_angle x 1
+ # Angle switch weights: product of individual neighbor weights
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if self.use_dynamic_sel:
    # get graph index
    edge_index, angle_index = get_graph_index(
        nlist, nlist_mask, a_nlist_mask, nall
    )
    # Flatten tensors for dynamic selection - only keep valid entries
    # n_edge x 1
    edge_input = edge_input[nlist_mask]
    # n_edge x 3
    h2 = h2[nlist_mask]
    # n_edge x 1
    sw = sw[nlist_mask]
    # Create angle pair mask: valid if both neighbors are valid
    a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
    # n_angle x 1
    angle_input = angle_input[a_nlist_mask]
    # Angle switch weights: product of individual neighbor weights
    a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
    # avoid jit assertion
    edge_index = angle_index = torch.zeros(
        [1, 3], device=nlist.device, dtype=nlist.dtype
    )
# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
edge_ebd = self.act(self.edge_embd(edge_input))
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
angle_ebd = self.angle_embd(angle_input)
🤖 Prompt for AI Agents
In deepmd/pt/model/descriptor/repflows.py around lines 464 to 491, verify that
the angle masking logic using a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :,
None, :] correctly produces the intended tensor shape and that the flattened
tensors like edge_input[nlist_mask] and angle_input[a_nlist_mask] maintain
proper alignment with the computed graph indices. Add clear comments explaining
the purpose and expected shapes of these tensor manipulations to improve code
readability and maintainability.
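The shape question raised in this thread can be answered with a small NumPy mock-up (toy sizes, not the real descriptor tensors): broadcasting a (nb, nloc, a_nnei) boolean mask against itself yields a (nb, nloc, a_nnei, a_nnei) pair mask, and boolean indexing then flattens valid entries in row-major order, so an atom with k valid neighbors contributes k*k angle pairs.

```python
import numpy as np

# Toy per-atom neighbor validity mask: 1 batch, 2 atoms, 3 neighbor slots.
a_nlist_mask = np.array([[[True, True, False],
                          [True, False, False]]])  # shape (1, 2, 3)

# Pairwise angle mask: an (i, j) pair is valid only if both neighbors are valid.
pair_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
print(pair_mask.shape)  # (1, 2, 3, 3)

# Atom 0 has 2 valid neighbors -> 4 valid pairs; atom 1 has 1 -> 1 pair.
print(pair_mask[0, 0].sum(), pair_mask[0, 1].sum())

# Boolean indexing flattens valid entries in row-major (C) order, which is the
# fixed scan order that the flattened angle tensors follow.
angle_input = np.arange(pair_mask.size).reshape(pair_mask.shape)
flat = angle_input[pair_mask]
print(flat)
```

Because both `angle_input[a_nlist_mask]` and the index construction scan the same mask in the same C order, the flattened tensors and graph indices stay aligned entry-for-entry; whether get_graph_index in the actual codebase uses the same order is the thing to verify.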
Replaces comm_dict.insert with comm_dict.insert_or_assign for the 'send_list' key in both DeepPotPT.cc and DeepSpinPT.cc. This ensures that the value is updated if the key already exists, preventing potential issues with duplicate key insertion.
Replaces comm_dict.insert with comm_dict.insert_or_assign for all tensor assignments in DeepPotPT.cc and DeepSpinPT.cc. This ensures that existing keys are updated rather than causing errors or duplications, improving robustness when keys may already exist.
Summary by CodeRabbit
New Features
Bug Fixes
Tests