Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix)Make the weighted avarange fit for all kinds of systems #4593

Open
wants to merge 41 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b3a925d
fix and ut
SumGuo-88 Feb 10, 2025
08e4a55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2025
bdc260a
fix ut bug (make the def name not start with 'test'
SumGuo-88 Feb 11, 2025
875ee01
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 11, 2025
7d137d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2025
8ab5ab9
Make code simple
SumGuo-88 Feb 11, 2025
33c4161
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 11, 2025
db637f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2025
a6e5ee1
Merge branch 'devel' into debug-weightedavg
SumGuo-88 Feb 11, 2025
4ac6b45
Merge branch 'devel' into debug-weightedavg
SumGuo-88 Feb 11, 2025
fa0aa4f
check change
SumGuo-88 Feb 12, 2025
ab9ec6e
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 12, 2025
d2c9c4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2025
f87ef49
Merge branch 'devel' into debug-weightedavg
SumGuo-88 Feb 12, 2025
99d6942
reverse
SumGuo-88 Feb 12, 2025
984a78e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2025
603b6f2
reverse2
SumGuo-88 Feb 12, 2025
0f669b8
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 12, 2025
5273168
remake the ut
SumGuo-88 Feb 18, 2025
57ba28f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
322b4c0
remake ut- still with error 2
SumGuo-88 Feb 18, 2025
8407b0c
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 18, 2025
11d0c68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
1942dfa
remake ut --still with error 3
SumGuo-88 Feb 18, 2025
c7ca682
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 18, 2025
80a8589
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
4e882f3
remake test with two def
SumGuo-88 Feb 19, 2025
66bb904
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 19, 2025
687b08d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
2b1f4bb
Merge branch 'devel' into debug-weightedavg
SumGuo-88 Feb 19, 2025
8d48ba4
remake all
SumGuo-88 Feb 19, 2025
b95ad66
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 19, 2025
824472a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
accc321
make ops input right
SumGuo-88 Feb 19, 2025
7ea19b2
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 19, 2025
daf9235
change def name
SumGuo-88 Feb 19, 2025
a949ebc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
7653e34
change to dict_to_return
SumGuo-88 Feb 19, 2025
edc1744
Merge branch 'debug-weightedavg' of https://github.com/SumGuo-88/deep…
SumGuo-88 Feb 19, 2025
5bfa822
coverage
SumGuo-88 Feb 20, 2025
96c2108
no coverage
SumGuo-88 Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 65 additions & 28 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test(
)

if isinstance(dp, DeepPot):
err = test_ener(
err, find_energy, find_force, find_virial = test_ener(
dp,
data,
system,
Expand All @@ -143,6 +143,29 @@ def test(
atomic,
append_detail=(cc != 0),
)
err_part = {}

if find_energy == 1:
err_part["mae_e"] = err["mae_e"]
err_part["mae_ea"] = err["mae_ea"]
err_part["rmse_e"] = err["rmse_e"]
err_part["rmse_ea"] = err["rmse_ea"]

if find_force == 1:
if "rmse_f" in err:
err_part["mae_f"] = err["mae_f"]
err_part["rmse_f"] = err["rmse_f"]
else:
err_part["mae_fr"] = err["mae_fr"]
err_part["rmse_fr"] = err["rmse_fr"]
err_part["mae_fm"] = err["mae_fm"]
err_part["rmse_fm"] = err["rmse_fm"]
if find_virial == 1:
err_part["mae_v"] = err["mae_v"]
err_part["rmse_v"] = err["rmse_v"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note your modification is not covered by UT.


err = err_part

elif isinstance(dp, DeepDOS):
err = test_dos(
dp,
Expand Down Expand Up @@ -307,6 +330,9 @@ def test_ener(
data.add("hessian", 1, atomic=True, must=True, high_prec=False)

test_data = data.get_test()
find_energy = test_data.get("find_energy")
find_force = test_data.get("find_force")
find_virial = test_data.get("find_virial")
mixed_type = data.mixed_type
natoms = len(test_data["type"][0])
nframes = test_data["box"].shape[0]
Expand Down Expand Up @@ -448,7 +474,7 @@ def test_ener(
log.info(f"Force spin MAE : {mae_fm:e} eV/uB")
log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB")

if data.pbc and not out_put_spin:
if data.pbc and not out_put_spin and find_virial == 1:
log.info(f"Virial MAE : {mae_v:e} eV")
log.info(f"Virial RMSE : {rmse_v:e} eV")
log.info(f"Virial MAE/Natoms : {mae_va:e} eV")
Expand Down Expand Up @@ -559,39 +585,50 @@ def test_ener(
append=append_detail,
)
if not out_put_spin:
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_f": (mae_f, force.size),
"mae_v": (mae_v, virial.size),
"mae_va": (mae_va, virial.size),
"rmse_e": (rmse_e, energy.size),
"rmse_ea": (rmse_ea, energy.size),
"rmse_f": (rmse_f, force.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
return (
{
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_f": (mae_f, force.size),
"mae_v": (mae_v, virial.size),
"mae_va": (mae_va, virial.size),
"rmse_e": (rmse_e, energy.size),
"rmse_ea": (rmse_ea, energy.size),
"rmse_f": (rmse_f, force.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
},
find_energy,
find_force,
find_virial,
)
else:
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_fr": (mae_fr, force_r.size),
"mae_fm": (mae_fm, force_m.size),
"mae_v": (mae_v, virial.size),
"mae_va": (mae_va, virial.size),
"rmse_e": (rmse_e, energy.size),
"rmse_ea": (rmse_ea, energy.size),
"rmse_fr": (rmse_fr, force_r.size),
"rmse_fm": (rmse_fm, force_m.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
return (
{
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_fr": (mae_fr, force_r.size),
"mae_fm": (mae_fm, force_m.size),
"mae_v": (mae_v, virial.size),
"mae_va": (mae_va, virial.size),
"rmse_e": (rmse_e, energy.size),
"rmse_ea": (rmse_ea, energy.size),
"rmse_fr": (rmse_fr, force_r.size),
"rmse_fm": (rmse_fm, force_m.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
},
find_energy,
find_force,
find_virial,
)
if dp.has_hessian:
dict_to_return["mae_h"] = (mae_h, hessian.size)
dict_to_return["rmse_h"] = (rmse_h, hessian.size)
return dict_to_return



def print_ener_sys_avg(avg: dict[str, float]) -> None:
"""Print errors summary for energy type potential.

Expand Down
100 changes: 100 additions & 0 deletions source/tests/pt/test_weighted_avg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import unittest
from unittest.mock import patch, MagicMock
from deepmd.entrypoints.test import test # Import the test function
from deepmd.infer.deep_pot import DeepPot # Import DeepPot

class TestDeepPotModel(unittest.TestCase):

@patch('deepmd.entrypoints.test.DeepEval') # Mock DeepEval class
@patch('deepmd.entrypoints.test.DeepmdData') # Mock DeepmdData class
@patch('deepmd.entrypoints.test.test_ener') # Mock test_ener function
@patch('deepmd.entrypoints.test.weighted_average') # Mock weighted_average function
@patch('builtins.open') # Mock the open function to avoid FileNotFoundError
def test_deep_pot(self, mock_open, mock_weighted_avg, mock_test_ener, mock_deepmd_data, mock_deep_eval):
# Mock the file reading behavior to return mock data instead
mock_open.return_value.__enter__.return_value.read.return_value = "mock_system_1\nmock_system_2"

# Setup mock return values
mock_deep_eval_instance = MagicMock()
mock_deep_eval.return_value = mock_deep_eval_instance
mock_deep_eval_instance.get_type_map.return_value = "mock_type_map"

mock_deepmd_data_instance = MagicMock()
mock_deepmd_data.return_value = mock_deepmd_data_instance

# Define the base_data to simulate the test_ener output
base_data = [
{ # System 1
"mae_e": (2.0, 5),
"mae_ea": (1.5, 5),
"rmse_e": (2.5, 5),
"rmse_ea": (2.0, 5),
"mae_f": (0.3, 15),
"rmse_f": (0.4, 15),
"mae_v": (1.2, 5),
"rmse_v": (1.5, 5),
"mae_va": (0.8, 5),
"rmse_va": (1.0, 5)
},
{ # System 2
"mae_e": (3.0, 10),
"mae_ea": (2.5, 10),
"rmse_e": (3.5, 10),
"rmse_ea": (3.0, 10),
"mae_f": (0.5, 30),
"rmse_f": (0.6, 30),
"mae_v": (2.0, 10),
"rmse_v": (2.5, 10),
"mae_va": (1.5, 10),
"rmse_va": (2.0, 10)
},
{ # System 3
"mae_e": (4.0, 15),
"mae_ea": (3.5, 15),
"rmse_e": (4.5, 15),
"rmse_ea": (4.0, 15),
"mae_f": (0.7, 45),
"rmse_f": (0.8, 45),
"mae_v": (3.0, 15),
"rmse_v": (3.5, 15),
"mae_va": (2.5, 15),
"rmse_va": (3.0, 15)
}
]

# Simulate err values for each system, adding the (1, 1, 1) triplet
mock_test_ener.return_value = (
base_data[0], # Using the first system's base data
1, # find_energy
1, # find_force
1 # find_virial
)

# Call the function with mock data
test(model="mock_model_path",
system="mock_system_path",
datafile="mock_datafile.txt", # Still passing mock file name
numb_test=10,
rand_seed=None,
shuffle_test=True,
detail_file="mock_detail.txt",
atomic=True)

# Check if mocks are called as expected
mock_deep_eval.assert_called_once_with("mock_model_path", head=None)
mock_deepmd_data.assert_called_once_with(
"mock_system_path", set_prefix="set", shuffle_test=True,
type_map="mock_type_map", sort_atoms=False
)
mock_test_ener.assert_called_once() # Check if test_ener was called for DeepPot
mock_weighted_avg.assert_called_once()

# Check if the file was opened (mocked)
mock_open.assert_called_once_with("mock_datafile.txt", 'r')

# Check results
self.assertEqual(mock_weighted_avg.return_value['mae_e'], 0.7)
self.assertEqual(mock_weighted_avg.return_value['rmse_e'], 0.4)

if __name__ == '__main__':
unittest.main()
Loading