Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adios dataset name 319 #320

Open
wants to merge 42 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
083ae32
adds dataset_name attribute to the base class and reads it into a dat…
Jan 27, 2025
edafeb5
adds dataset_name to the Data class a kwarg #319
Jan 27, 2025
8fca19f
dataset_name example for csce #319
Jan 27, 2025
c859eaf
GraphGPS arguments added to JSON file for CSCE example
allaffa Jan 27, 2025
4d62ab9
Merge remote-tracking branch 'max_fork/update_csce_ising_examples' in…
Jan 27, 2025
ca57aa4
formatting edits #319
Jan 27, 2025
c62961d
formatting #319
Jan 28, 2025
4c3e53c
checking to see if dataset_name is in keys before attempting to remov…
Feb 5, 2025
90f9c46
Dev hierarchical multibranch output heads (#323)
pzhanggit Feb 10, 2025
14b0add
Predictive GFM 2025 (#318)
allaffa Feb 10, 2025
b425edc
Merge remote-tracking branch 'upstream/main' into adios_dataset_name_319
Feb 11, 2025
924532e
Add multidataset example with deepspeed support (#316)
licj15 Feb 5, 2025
f03174e
data attributes updated for consistency across datasets
allaffa Jan 11, 2025
7689405
non-normalized chemical composition added as data attribute
allaffa Jan 11, 2025
a09f17f
scripts updated
Jan 11, 2025
214c28e
development of transition1x scripts continues
allaffa Jan 11, 2025
95150b8
black formatting fixed
allaffa Jan 11, 2025
d65a975
detach().clone() used to define normalized energy per atom and black…
allaffa Jan 11, 2025
7898c88
smiles_string added as data attribute
allaffa Jan 16, 2025
1edc5fd
Reverted smiles_utils.py to version from commit 3c3c434f544d1a042775a…
allaffa Jan 23, 2025
a0919dc
xyz2mol functionalities put in a separate file
allaffa Jan 23, 2025
99f883f
total_energy and total_energy_per_atom replaced with energy and energ…
Feb 11, 2025
2147739
commented out fields for transition1x
allaffa Feb 12, 2025
e8c313f
duplicated lines removed
allaffa Feb 12, 2025
0eecdf7
First draft full
Feb 13, 2025
d916ae8
Merge remote-tracking branch 'upstream/main' into adios_dataset_name_319
Feb 13, 2025
517bd1b
Utility functions to read pbc uniformly as a tensor and applying posi…
Feb 14, 2025
75014f1
Convert data.pbc bool --> int during writing and int --> bool during …
Feb 14, 2025
3038ee8
Alexandria fully updated to have the same data_object attributes, typ…
Feb 14, 2025
2991aa4
re-order lines
Feb 14, 2025
386edff
make sure on cpu for ASE
Feb 14, 2025
7108e84
Update examples full
Feb 14, 2025
8e4592b
Revise comment
Feb 14, 2025
ab52d94
correct bool check for tensor
Feb 14, 2025
dae490f
Remove unnecessary imports
Feb 14, 2025
faad02e
adjust comments
Feb 14, 2025
3bec7c9
Use PBC as int instead full
Feb 14, 2025
0b78cf2
Make sure to view cell as (3,3)
Feb 14, 2025
385c62d
Keep samples with default if we can't read pbc and tweak pbc transforms
Feb 15, 2025
a6e01a3
Merge pull request #25 from RylieWeaver/fix-examples-pbc
allaffa Feb 15, 2025
8c9cf28
Merge branch 'Predictive_GFM_2025_fix_examples' into adios_dataset_na…
kshitij-v-mehta Feb 17, 2025
8232537
Removed adding a blank attribute 'dataset_name' to the Data class. I…
Feb 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/alexandria/alexandria_energy.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
},
"NeuralNetwork": {
"Architecture": {
"global_attn_engine": "",
"global_attn_type": "",
"max_neighbours": 20,
"mpnn_type": "EGNN",
"equivariance": true,
"radius": 5,
"max_neighbours": 100000,
"pe_dim": 1,
"global_attn_heads": 8,
"num_gaussians": 50,
"envelope_exponent": 5,
"int_emb_size": 64,
Expand Down
129 changes: 101 additions & 28 deletions examples/alexandria/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from mpi4py import MPI
import argparse

import random
import numpy as np

import random

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE

from torch_geometric.transforms import Distance, Spherical, LocalCartesian

Expand All @@ -27,6 +29,9 @@
from hydragnn.preprocess.graph_samples_checks_and_updates import (
RadiusGraph,
RadiusGraphPBC,
PBCDistance,
PBCLocalCartesian,
pbc_as_tensor,
)
from hydragnn.preprocess.load_data import split_dataset

Expand Down Expand Up @@ -67,12 +72,21 @@ def list_directories(path):
reversed_dict_periodic_table = {value: key for key, value in periodic_table.items()}

# transform_coordinates = Spherical(norm=False, cat=False)
# transform_coordinates = LocalCartesian(norm=False, cat=False)
transform_coordinates = Distance(norm=False, cat=False)
transform_coordinates = LocalCartesian(norm=False, cat=False)
# transform_coordinates = Distance(norm=False, cat=False)

transform_coordinates_pbc = PBCLocalCartesian(norm=False, cat=False)
# transform_coordinates_pbc = PBCDistance(norm=False, cat=False)

class Alexandria(AbstractBaseDataset):
def __init__(self, dirpath, var_config, energy_per_atom=True, dist=False):
def __init__(
self,
dirpath,
var_config,
graphgps_transform=None,
energy_per_atom=True,
dist=False,
):
super().__init__()

self.dist = dist
Expand All @@ -86,6 +100,11 @@ def __init__(self, dirpath, var_config, energy_per_atom=True, dist=False):
self.radius_graph = RadiusGraph(5.0, loop=False, max_num_neighbors=50)
self.radius_graph_pbc = RadiusGraphPBC(5.0, loop=False, max_num_neighbors=50)

self.graphgps_transform = graphgps_transform

# Threshold for atomic forces in eV/angstrom
self.forces_norm_threshold = 1000.0

list_dirs = list_directories(
os.path.join(dirpath, "compressed_data", "alexandria.icams.rub.de")
)
Expand Down Expand Up @@ -140,22 +159,24 @@ def get_magmoms_array_from_structure(structure):
assert pos.shape[0] > 0, "pos tensor does not have any atoms"
except:
print(f"Structure {entry_id} does not have positional sites", flush=True)
return data_object
natoms = torch.IntTensor([pos.shape[0]])

cell = None
try:
cell = torch.tensor(structure["lattice"]["matrix"]).to(torch.float32)
cell = torch.tensor(structure["lattice"]["matrix"], dtype=torch.float32).view(3,3)
except:
print(f"Structure {entry_id} does not have cell", flush=True)
return data_object

pbc = None
try:
pbc = structure["lattice"]["pbc"]
pbc = pbc_as_tensor(structure["lattice"]["pbc"])
except:
print(f"Structure {entry_id} does not have pbc", flush=True)
return data_object

# If either cell or pbc were not read, we set to defaults
if cell is None or pbc is None:
cell = torch.eye(3, dtype=torch.float32)
pbc = torch.tensor([False, False, False], dtype=torch.bool)

atomic_numbers = None
try:
Expand Down Expand Up @@ -241,16 +262,31 @@ def get_magmoms_array_from_structure(structure):
# print(f"Structure {entry_id} does not have e_above_hull")
# return data_object

x = torch.cat([atomic_numbers, pos, forces], dim=1)

# Calculate chemical composition
atomic_number_list = atomic_numbers.tolist()
assert len(atomic_number_list) == natoms
## 118: number of atoms in the periodic table
hist, _ = np.histogram(atomic_number_list, bins=range(1, 118 + 2))
chemical_composition = torch.tensor(hist).unsqueeze(1).to(torch.float32)

data_object = Data(
dataset_name="alexandria",
natoms=natoms,
pos=pos,
cell=cell,
pbc=pbc,
edge_index=None,
edge_attr=None,
atomic_numbers=atomic_numbers,
forces=forces,
chemical_composition=chemical_composition,
smiles_string=None,
# entry_id=entry_id,
natoms=natoms,
total_energy=total_energy_tensor,
total_energy_per_atom=total_energy_per_atom_tensor,
x=x,
energy=total_energy_tensor,
energy_per_atom=total_energy_per_atom_tensor,
forces=forces,
# formation_energy=torch.tensor(formation_energy).float(),
# formation_energy_per_atom=torch.tensor(formation_energy_per_atom).float(),
# energy_above_hull=energy_above_hull,
Expand All @@ -261,29 +297,45 @@ def get_magmoms_array_from_structure(structure):
)

if self.energy_per_atom:
data_object.y = data_object.total_energy_per_atom
data_object.y = data_object.energy_per_atom
else:
data_object.y = data_object.total_energy
data_object.y = data_object.energy

data_object.x = torch.cat(
[data_object.atomic_numbers, data_object.pos, data_object.forces], dim=1
)

if data_object.pbc is not None and data_object.cell is not None:
# Apply radius graph and build edge attributes accordingly
if data_object.pbc.any():
try:
data_object = self.radius_graph_pbc(data_object)
data_object = transform_coordinates_pbc(data_object)
except:
print(
f"Structure {entry_id} could not successfully apply pbc radius graph",
f"Structure {entry_id} could not successfully apply one or both of the pbc radius graph and positional transform",
flush=True,
)
data_object = self.radius_graph(data_object)
data_object = transform_coordinates(data_object)
else:
data_object = self.radius_graph(data_object)

data_object = transform_coordinates(data_object)

return data_object
data_object = transform_coordinates(data_object)

# Default edge_shifts for when radius_graph_pbc is not activated
if not hasattr(data_object, "edge_shifts"):
data_object.edge_shifts = torch.zeros((data_object.edge_index.size(1), 3), dtype=torch.float32)

# FIXME: PBC from bool --> int32 to be accepted by ADIOS
data_object.pbc = data_object.pbc.int()

# LPE
if self.graphgps_transform is not None:
data_object = self.graphgps_transform(data_object)

if self.check_forces_values(data_object.forces):
return data_object
else:
print(
f"L2-norm of force tensor exceeds threshold {self.forces_norm_threshold} - atomistic structure: {data}",
flush=True,
)
return None

def process_file_content(self, filepath):
"""
Expand Down Expand Up @@ -311,7 +363,8 @@ def process_file_content(self, filepath):
self.get_data_dict(entry)
for entry in iterate_tqdm(
data["entries"],
desc=f"Processing file {filepath}",
#desc=f"Processing file {filepath}",
desc=None,
verbosity_level=2,
)
]
Expand All @@ -332,6 +385,14 @@ def process_file_content(self, filepath):
except Exception as e:
print("An error occurred:", e, flush=True)

def check_forces_values(self, forces):

# Calculate the L2 norm for each row
norms = torch.norm(forces, p=2, dim=1)
# Check if all norms are less than the threshold

return torch.all(norms < self.forces_norm_threshold).item()

def len(self):
    """Return the number of graph samples currently stored in this dataset."""
    return len(self.dataset)

Expand All @@ -356,7 +417,7 @@ def get(self, idx):
"--energy_per_atom",
help="option to normalize energy by number of atoms",
type=bool,
default=True,
default=False,
)
parser.add_argument("--ddstore", action="store_true", help="ddstore dataset")
parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None)
Expand All @@ -365,6 +426,9 @@ def get(self, idx):
parser.add_argument("--batch_size", type=int, help="batch_size", default=None)
parser.add_argument("--everyone", action="store_true", help="gptimer")
parser.add_argument("--modelname", help="model name")
parser.add_argument(
"--compute_grad_energy", type=bool, help="compute_grad_energy", default=False
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--adios",
Expand Down Expand Up @@ -402,6 +466,13 @@ def get(self, idx):
var_config["node_feature_names"] = node_feature_names
var_config["node_feature_dims"] = node_feature_dims

# Transformation to create positional and structural laplacian encoders
graphgps_transform = AddLaplacianEigenvectorPE(
k=config["NeuralNetwork"]["Architecture"]["pe_dim"],
attr_name="pe",
is_undirected=True,
)

if args.batch_size is not None:
config["NeuralNetwork"]["Training"]["batch_size"] = args.batch_size

Expand Down Expand Up @@ -431,6 +502,7 @@ def get(self, idx):
total = Alexandria(
os.path.join(datadir),
var_config,
graphgps_transform=graphgps_transform,
energy_per_atom=args.energy_per_atom,
dist=True,
)
Expand Down Expand Up @@ -597,6 +669,7 @@ def get(self, idx):
log_name,
verbosity,
create_plots=False,
compute_grad_energy=args.compute_grad_energy,
)

hydragnn.utils.model.save_model(model, optimizer, log_name)
Expand Down
Loading
Loading