Skip to content

Commit 0e19e83

Browse files
author
Yury Lysogorskiy
committed
Update to 0.4.5:
- add df2extxyz, grace_predict and grace_preprocess - update gracemaker (including bugfix for finetuning) - implement distributed fit and grace_preprocess - add layer normalization to GRACE models - optimized padding strategy for TPCalculator - add pyproject.toml - add two OAM foundation models
1 parent 7aae775 commit 0e19e83

33 files changed

+3859
-543
lines changed

bin/df2extxyz

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#!/usr/bin/env python
2+
import argparse
3+
import logging
4+
import os
5+
import sys
6+
7+
import pandas as pd
8+
from ase.calculators.singlepoint import SinglePointCalculator
9+
10+
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
11+
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")
12+
logger = logging.getLogger()
13+
14+
15+
from ase.io import write
16+
17+
18+
def build_parser():
19+
parser = argparse.ArgumentParser(
20+
prog="df2extxyz", description="Conversion from df.pkl.gz to extxyz"
21+
)
22+
23+
parser.add_argument("input", help="input pkl.gz file", type=str)
24+
25+
parser.add_argument(
26+
"-e",
27+
"--energy-column",
28+
help="name of energy column",
29+
type=str,
30+
default="energy_corrected",
31+
)
32+
33+
parser.add_argument(
34+
"-f", "--force-column", help="name of forces column", type=str, default="forces"
35+
)
36+
37+
parser.add_argument(
38+
"-s",
39+
"--stress-column",
40+
help="name of stress column",
41+
type=str,
42+
default="stress",
43+
)
44+
45+
parser.add_argument(
46+
"-o", "--output", help="output file name", type=str, default=None
47+
)
48+
return parser
49+
50+
51+
def sizeof_fmt(file_name_or_size, suffix="B"):
52+
if isinstance(file_name_or_size, str):
53+
file_name_or_size = os.path.getsize(file_name_or_size)
54+
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
55+
if abs(file_name_or_size) < 1024.0:
56+
return "%3.1f%s%s" % (file_name_or_size, unit, suffix)
57+
file_name_or_size /= 1024.0
58+
return "%.1f%s%s" % (file_name_or_size, "Yi", suffix)
59+
60+
61+
def main(args):
62+
parser = build_parser()
63+
args_parse = parser.parse_args(args)
64+
input_fname = args_parse.input
65+
energy_col = args_parse.energy_column
66+
force_col = args_parse.force_column
67+
stress_col = args_parse.stress_column
68+
69+
logging.info(f"Reading input file : {input_fname} ({sizeof_fmt(input_fname)})")
70+
df = pd.read_pickle(input_fname, compression="gzip")
71+
logging.info(f"Dataframe shape: {df.shape}")
72+
df.reset_index(drop=True, inplace=True)
73+
74+
atoms_list = []
75+
for _, row in df.iterrows():
76+
at = row["ase_atoms"]
77+
stress = row[stress_col] if stress_col in row else None
78+
at.info.update(
79+
{
80+
"REF_energy": row[energy_col],
81+
}
82+
)
83+
at.arrays.update({"REF_forces": row[force_col]})
84+
if stress_col in row:
85+
at.info["REF_stress"] = row[stress_col]
86+
87+
at.calc = SinglePointCalculator(
88+
at,
89+
energy=row[energy_col],
90+
forces=row[force_col],
91+
stress=stress,
92+
)
93+
94+
atoms_list.append(at)
95+
96+
output_fname = args_parse.output
97+
if output_fname is None:
98+
ext_to_replace = [".pckl.gzip", "pckl.gz", ".pkl.gzip", ".pkl.gz"]
99+
output_fname = input_fname
100+
for ext in ext_to_replace:
101+
output_fname = output_fname.replace(ext, ".xyz")
102+
logging.info(f"Writing to output filename: {output_fname}")
103+
write(output_fname, atoms_list, format="extxyz")
104+
logging.info(f"Saved to {output_fname} ({sizeof_fmt(output_fname)})")
105+
106+
107+
if __name__ == "__main__":
108+
main(sys.argv[1:])

bin/grace_predict

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env python
2+
3+
import warnings
4+
5+
warnings.filterwarnings("ignore", category=FutureWarning)
6+
warnings.filterwarnings("ignore", category=DeprecationWarning)
7+
8+
import sys
9+
import os
10+
import pandas as pd
11+
import argparse
12+
13+
from tensorpotential.calculator import TPCalculator
14+
15+
import logging
16+
17+
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
18+
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")
19+
logger = logging.getLogger()
20+
21+
from tqdm import tqdm
22+
23+
tqdm.pandas()
24+
25+
26+
def predict(row, calc):
27+
at = row["ase_atoms"].copy()
28+
if "mag_mom" in row:
29+
at.set_initial_magnetic_moments(row["mag_mom"])
30+
at.calc = calc
31+
e = at.get_potential_energy()
32+
f = at.get_forces()
33+
return {"energy": e, "forces": f}
34+
35+
36+
def main(args):
37+
parser = argparse.ArgumentParser()
38+
39+
parser.add_argument(
40+
"-m",
41+
"--model_path",
42+
help="provide path to the saved_model directory",
43+
type=str,
44+
default="saved_model",
45+
dest="model_path",
46+
)
47+
48+
parser.add_argument(
49+
"-d",
50+
"--dataset",
51+
help="path to the dataset.pkl.gzip containing ase_atoms structures",
52+
type=str,
53+
default="dataset.pkl.gz",
54+
dest="dataset_file",
55+
)
56+
57+
parser.add_argument(
58+
"-o",
59+
"--output",
60+
help="path to the OUTPUT dataset (pkl.gzip) containing energy_predicted and forces_predicted",
61+
type=str,
62+
default="predicted_dataset.pkl.gz",
63+
dest="output",
64+
)
65+
66+
args_parse = parser.parse_args(args)
67+
68+
model_path = os.path.abspath(args_parse.model_path)
69+
dataset_file = args_parse.dataset_file
70+
output_file = args_parse.output
71+
72+
logger.info(f"Loading model from: {model_path}")
73+
calc = TPCalculator(
74+
model=model_path,
75+
pad_atoms_number=20,
76+
pad_neighbors_fraction=0.30,
77+
# max_number_reduction_recompilation=3,
78+
)
79+
80+
logger.info(f"Loading dataset from: {dataset_file}")
81+
df = pd.read_pickle(dataset_file, compression="gzip")
82+
83+
logger.info(f"Starting prediction")
84+
85+
df["prediction"] = df.progress_apply(predict, axis=1, args=(calc,))
86+
df["energy_predicted"] = df["prediction"].map(lambda x: x["energy"])
87+
df["forces_predicted"] = df["prediction"].map(lambda x: x["forces"])
88+
df = df.drop(columns=["ase_atoms", "prediction"])
89+
90+
logger.info(f"Saving dataset to {output_file}")
91+
df.drop(columns=["ase_atoms"]).to_pickle(output_file, compression="gzip")
92+
93+
94+
if __name__ == "__main__":
95+
main(sys.argv[1:])

0 commit comments

Comments
 (0)