Skip to content

Commit

Permalink
Align examples pbc (ORNL#313)
Browse files Browse the repository at this point in the history
* revise open catalyst, consistent naming, and remove unnecessary assert

* take away unnecessary check

* revert unneeded naming

* Unified PBC naming and using torch autograd in LJ for simpler and more robust force calculations

* Aligning the OC examples with the OC dataset in the literature, which always uses PBC

* more robust radius

* correct link

* Unify naming of cell in other parts of HydraGNN such as the tests and preprocess

* More robust default radius and black formatting

* mptrj pbc

* Remove assert for dupe edges and ensure max_num_neighbors for computational feasibility

* Revise examples for error handling in extracting cell, extracting pbc, and creating radius graph

* comment

* Simpler implementation in numpy and robust device handling

* no need to do these changes now that we have robust error handling

* no need for radius change

* polish radius graph pbc and remove unnecessary oc neighbor changes

* black

* better import

* reverse unnecessary changes

* device and data type handling

* black

* clean up inference script and add r2

* better typing and device handling

* simplify checks

* Switch stack and epochs in json for well-performing default
  • Loading branch information
RylieWeaver authored Dec 28, 2024
1 parent 1419dc4 commit 4f8a21d
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 201 deletions.
4 changes: 2 additions & 2 deletions examples/LennardJones/LJ.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"periodic_boundary_conditions": true,
"global_attn_engine": "",
"global_attn_type": "",
"mpnn_type": "EGNN",
"mpnn_type": "DimeNet",
"radius": 5.0,
"max_neighbours": 5,
"int_emb_size": 32,
Expand Down Expand Up @@ -61,7 +61,7 @@
"output_names": ["graph_energy"]
},
"Training": {
"num_epoch": 15,
"num_epoch": 25,
"batch_size": 64,
"perc_train": 0.7,
"patience": 20,
Expand Down
169 changes: 57 additions & 112 deletions examples/LennardJones/LJ_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE
from torch_scatter import scatter

# torch.set_default_tensor_type(torch.DoubleTensor)
# torch.set_default_dtype(torch.float64)
Expand All @@ -36,6 +37,7 @@
from hydragnn.utils.datasets.abstractrawdataset import AbstractBaseDataset
from hydragnn.utils.distributed import nsplit
from hydragnn.preprocess.graph_samples_checks_and_updates import get_radius_graph_pbc
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths

# Angstrom unit
primitive_bravais_lattice_constant_x = 3.8
Expand All @@ -51,6 +53,7 @@

def create_dataset(path, config):
radius_cutoff = config["NeuralNetwork"]["Architecture"]["radius"]
max_num_neighbors = config["NeuralNetwork"]["Architecture"]["max_neighbours"]
number_configurations = (
config["Dataset"]["number_configurations"]
if "number_configurations" in config["Dataset"]
Expand All @@ -73,6 +76,7 @@ def create_dataset(path, config):
atom_types,
atomic_structure_handler=atomic_structure_handler,
radius_cutoff=radius_cutoff,
max_num_neighbors=max_num_neighbors,
relative_maximum_atomic_displacement=1e-1,
number_configurations=number_configurations,
)
Expand Down Expand Up @@ -167,7 +171,7 @@ def transform_input_to_data_object_base(self, filepath):
forces_pre_scaled = forces * forces_pre_scaling_factor

data = Data(
supercell_size=torch_supercell.to(torch.float32),
cell=torch_supercell.to(torch.float32),
num_nodes=num_nodes,
grad_energy_post_scaling_factor=grad_energy_post_scaling_factor,
forces_pre_scaling_factor=torch.tensor(forces_pre_scaling_factor).to(
Expand All @@ -182,11 +186,14 @@ def transform_input_to_data_object_base(self, filepath):
.unsqueeze(0)
.to(torch.float32),
energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
pbc=[
True,
True,
True,
], # LJ example always has periodic boundary conditions
pbc=torch.tensor(
[
True,
True,
True,
],
dtype=torch.bool,
), # LJ example always has periodic boundary conditions
)

# Create pbc edges and lengths
Expand Down Expand Up @@ -345,37 +352,28 @@ def create_configuration(
supercell_size_x = primitive_bravais_lattice_constant_x * uc_x
supercell_size_y = primitive_bravais_lattice_constant_y * uc_y
supercell_size_z = primitive_bravais_lattice_constant_z * uc_z
data.supercell_size = torch.diag(
data.cell = torch.diag(
torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
)
data.pbc = [True, True, True]
data.pbc = torch.tensor([True, True, True], dtype=torch.bool)
data.x = torch.cat([atom_types, positions], dim=1)

create_graph_connectivity_pbc = get_radius_graph_pbc(
radius_cutoff, max_num_neighbors
)
data = create_graph_connectivity_pbc(data)

atomic_descriptors = torch.cat(
(
atom_types,
positions,
),
1,
)

data.x = atomic_descriptors

data = atomic_structure_handler.compute(data)

total_energy = torch.sum(data.x[:, 4])
energy_per_atom = total_energy / number_nodes

total_energy_str = numpy.array2string(total_energy.detach().numpy())
energy_per_atom_str = numpy.array2string(energy_per_atom.detach().numpy())
total_energy_str = numpy.array2string(total_energy.detach().cpu().numpy())
energy_per_atom_str = numpy.array2string(energy_per_atom.detach().cpu().numpy())
filetxt = total_energy_str + "\n" + energy_per_atom_str

for index in range(0, 3):
numpy_row = data.supercell_size[index, :].detach().numpy()
numpy_row = data.cell[index, :].detach().numpy()
numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t")
filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]")

Expand All @@ -402,73 +400,47 @@ def __init__(
self.radius_cutoff = radius_cutoff
self.formula = formula

# Calculate the potential energy with torch gradient tracking, then simply use autograd to calculate the forces
def compute(self, data):
# Instantiate
assert data.pos.shape[0] == data.x.shape[0]

interatomic_potential = torch.zeros([data.pos.shape[0], 1])
interatomic_forces = torch.zeros([data.pos.shape[0], 3])

for node_id in range(data.pos.shape[0]):
neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[
0
].tolist()
neighbor_list = data.edge_index[1, neighbor_list_indices]

for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices):
neighbor_pos = data.pos[neighbor_id, :]
distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :]

# Adjust the neighbor position based on periodic boundary conditions (PBC)
## If the distance between the atoms is larger than the cutoff radius, the edge is because of PBC conditions
if torch.norm(distance_vector) > self.radius_cutoff:
## At this point, we know that the edge is due to PBC conditions, so we need to adjust the neighbor position. We also know
## that this connection MUST be the closest connection possible as a result of the asserted radius_cutoff < supercell_size earlier
## in the code. Because of this, we can simply adjust the neighbor position coordinate-wise to be closer,
## as done in the following lines of code. The logic goes that if abs(distance_vector[index]) is larger than half the supercell size,
## then there is a closer distance at +- supercell_size[index], and we adjust to that for each coordinate
if abs(distance_vector[0]) > data.supercell_size[0, 0] / 2:
if distance_vector[0] > 0:
neighbor_pos[0] -= data.supercell_size[0, 0]
else:
neighbor_pos[0] += data.supercell_size[0, 0]

if abs(distance_vector[1]) > data.supercell_size[1, 1] / 2:
if distance_vector[1] > 0:
neighbor_pos[1] -= data.supercell_size[1, 1]
else:
neighbor_pos[1] += data.supercell_size[1, 1]

if abs(distance_vector[2]) > data.supercell_size[2, 2] / 2:
if distance_vector[2] > 0:
neighbor_pos[2] -= data.supercell_size[2, 2]
else:
neighbor_pos[2] += data.supercell_size[2, 2]

# The distance vector may need to be updated after applying PBCs
distance_vector = data.pos[node_id, :] - neighbor_pos

# pair_distance = data.edge_attr[edge_id].item()
interatomic_potential[node_id] += self.formula.potential_energy(
distance_vector
)

derivative_x = self.formula.derivative_x(distance_vector)
derivative_y = self.formula.derivative_y(distance_vector)
derivative_z = self.formula.derivative_z(distance_vector)

interatomic_forces_contribution_x = -derivative_x
interatomic_forces_contribution_y = -derivative_y
interatomic_forces_contribution_z = -derivative_z

interatomic_forces[node_id, 0] += interatomic_forces_contribution_x
interatomic_forces[node_id, 1] += interatomic_forces_contribution_y
interatomic_forces[node_id, 2] += interatomic_forces_contribution_z

data.x = torch.cat(
(data.x, interatomic_potential, interatomic_forces),
1,
node_potential = torch.zeros([data.pos.shape[0], 1])
node_forces = torch.zeros([data.pos.shape[0], 3])

# Calculate
data.pos.requires_grad = True
edge_vec, edge_dist = get_edge_vectors_and_lengths(
positions=data.pos,
edge_index=data.edge_index,
shifts=data.edge_shifts,
normalize=False,
)

# Sum potential by edge, node, and total
edge_potential = self.formula.potential_energy(
edge_dist
) # Shape [num_edges, 1]
node_potential = scatter(
edge_potential,
data.edge_index[0],
dim=0,
dim_size=data.pos.shape[0],
reduce="add",
) # Shape [num_nodes, 1]
total_potential = torch.sum(node_potential, dim=0, keepdim=True) # Shape [1]

# Autograd to calculate forces
node_forces = -torch.autograd.grad(
total_potential,
data.pos,
grad_outputs=torch.ones_like(total_potential),
)[
0
] # Shape [num_nodes, 3]

# Append to data
data.x = torch.cat((data.x, node_potential, node_forces), dim=1)

return data


Expand All @@ -477,40 +449,13 @@ def __init__(self, epsilon, sigma):
self.epsilon = epsilon
self.sigma = sigma

def potential_energy(self, distance_vector):
pair_distance = torch.norm(distance_vector)
def potential_energy(self, pair_distance):
return (
4
* self.epsilon
* ((self.sigma / pair_distance) ** 12 - (self.sigma / pair_distance) ** 6)
)

def radial_derivative(self, distance_vector):
pair_distance = torch.norm(distance_vector)
return (
4
* self.epsilon
* (
-12 * (self.sigma / pair_distance) ** 12 * 1 / pair_distance
+ 6 * (self.sigma / pair_distance) ** 6 * 1 / pair_distance
)
)

def derivative_x(self, distance_vector):
pair_distance = torch.norm(distance_vector)
radial_derivative = self.radial_derivative(pair_distance)
return radial_derivative * (distance_vector[0].item()) / pair_distance

def derivative_y(self, distance_vector):
pair_distance = torch.norm(distance_vector)
radial_derivative = self.radial_derivative(pair_distance)
return radial_derivative * (distance_vector[1].item()) / pair_distance

def derivative_z(self, distance_vector):
pair_distance = torch.norm(distance_vector)
radial_derivative = self.radial_derivative(pair_distance)
return radial_derivative * (distance_vector[2].item()) / pair_distance


"""Etc"""

Expand Down
28 changes: 14 additions & 14 deletions examples/LennardJones/LJ_inference_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@

import hydragnn
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
from hydragnn.utils.distributed import get_device
from hydragnn.utils.distributed import get_device, setup_ddp
from hydragnn.utils.model import load_existing_model
from hydragnn.utils.datasets.pickledataset import SimplePickleDataset
from hydragnn.utils.input_config_parsing.config_utils import (
update_config,
)
from hydragnn.utils.print import setup_log
from hydragnn.models.create import create_model_config
from hydragnn.preprocess import create_dataloaders

Expand All @@ -42,13 +43,14 @@
from LJ_data import info

import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

plt.rcParams.update({"font.size": 16})


def get_log_name_config(config):
return (
config["NeuralNetwork"]["Architecture"]["model_type"]
config["NeuralNetwork"]["Architecture"]["mpnn_type"]
+ "-r-"
+ str(config["NeuralNetwork"]["Architecture"]["radius"])
+ "-ncl-"
Expand Down Expand Up @@ -132,10 +134,10 @@ def getcolordensity(xdata, ydata):
input_filename = os.path.join(dirpwd, args.inputfile)
with open(input_filename, "r") as f:
config = json.load(f)
hydragnn.utils.setup_log(get_log_name_config(config))
setup_log(get_log_name_config(config))
##################################################################################################################
# Always initialize for multi-rank training.
comm_size, rank = hydragnn.utils.setup_ddp()
comm_size, rank = setup_ddp()
##################################################################################################################
comm = MPI.COMM_WORLD

Expand Down Expand Up @@ -179,11 +181,6 @@ def getcolordensity(xdata, ydata):
load_existing_model(model, modelname, path="./logs/")
model.eval()

variable_index = 0
# for output_name, output_type, output_dim in zip(config["NeuralNetwork"]["Variables_of_interest"]["output_names"], config["NeuralNetwork"]["Variables_of_interest"]["type"], config["NeuralNetwork"]["Variables_of_interest"]["output_dim"]):

test_MAE = 0.0

num_samples = len(testset)
energy_true_list = []
energy_pred_list = []
Expand All @@ -196,9 +193,6 @@ def getcolordensity(xdata, ydata):
0
] # Note that this is sensitive to energy and forces prediction being single-task (current requirement)
energy_pred = torch.sum(node_energy_pred, dim=0).float()
test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset)
# predicted.backward(retain_graph=True)
# gradients = data.pos.grad
grads_energy = torch.autograd.grad(
outputs=energy_pred,
inputs=data.pos,
Expand All @@ -211,6 +205,14 @@ def getcolordensity(xdata, ydata):
forces_pred_list.extend((-grads_energy).flatten().tolist())
forces_true_list.extend(data.forces.flatten().tolist())

# Show R2 Metrics
print(
f"R2 energy: ", r2_score(np.array(energy_true_list), np.array(energy_pred_list))
)
print(
f"R2 forces: ", r2_score(np.array(forces_true_list), np.array(forces_pred_list))
)

hist2d_norm = getcolordensity(energy_true_list, energy_pred_list)

fig, ax = plt.subplots()
Expand All @@ -225,8 +227,6 @@ def getcolordensity(xdata, ydata):
plt.tight_layout()
plt.savefig(f"./energy_Scatterplot" + ".png", dpi=400)

print(f"Test MAE energy: ", test_MAE)

hist2d_norm = getcolordensity(forces_pred_list, forces_true_list)
fig, ax = plt.subplots()
plt.scatter(forces_pred_list, forces_true_list, s=8, c=hist2d_norm, vmin=0, vmax=1)
Expand Down
Loading

0 comments on commit 4f8a21d

Please sign in to comment.