diff --git a/graph_structure_learning/graph_structure_learner.py b/graph_structure_learning/graph_structure_learner.py new file mode 100644 index 00000000..6074374f --- /dev/null +++ b/graph_structure_learning/graph_structure_learner.py @@ -0,0 +1,13 @@ +""" + Performs graph structure learning on the MSG as outlined in the + paper. +""" + + +class GraphStructureLearner: + """ + Implements GSL. + """ + + def __init__(self): + pass diff --git a/graph_structure_learning/molecular_similarity_graph.py b/graph_structure_learning/molecular_similarity_graph.py new file mode 100644 index 00000000..565c5548 --- /dev/null +++ b/graph_structure_learning/molecular_similarity_graph.py @@ -0,0 +1,24 @@ +""" + A class implementing the molecular similarity graph (MSG) by + taking data from the MoleculeProcessor class as well as the model + embeddings. Again, it should follow the steps in the notebook + and have methods that perform the following + - constructs an adjacency matrix from the tanimoto coefficients + subject to a cutoff, + - takes model embeddings and adjacency matrix and converts them + into PyTorch Geometric graph data. + - (optional) networkx visualization of molecular similarity graph +""" + +class MolecularSimilarityGraph: + def __init__(self, moleculeData): + pass + + def constructAdjacencyMatrix(self): + pass + + def toGraphData(self): + pass + + def visualize(self): + pass \ No newline at end of file diff --git a/graph_structure_learning/run_with_gsl.py b/graph_structure_learning/run_with_gsl.py new file mode 100644 index 00000000..a1ba5600 --- /dev/null +++ b/graph_structure_learning/run_with_gsl.py @@ -0,0 +1,8 @@ +""" + Run our GNN-GSL model. + - Much of this code can be taken from main.py, but + we will have to modify the running process so that it + takes the 'readout' output from the trainer and feeds + it into the GSL pipeline. +""" + diff --git a/graph_structure_learning/xyz_processor.py b/graph_structure_learning/xyz_processor.py new file mode 100644 index 00000000..808ff339 --- /dev/null +++ b/graph_structure_learning/xyz_processor.py @@ -0,0 +1,26 @@ +""" + Use Atomic Simulation Environment (ASE) to process the molecules + into individual .xyz files and then delete the files. + - Implement a 'MoleculeProcessor' object that takes a list of + JSON molecules on construction and has methods that perform the + exact same process as in our notebook. I.e., the following methods + must be implemented: + - converts the molecules into .xyz files, + - computes and stores all relevent metrics + (fingerprints, Tanimoto coeffs, etc.), + - deletes all .xyz files. +""" + + +class MoleculeProcessor: + def __init__(self, molList): + pass + + def toXYZ(self): + pass + + def computeMetrics(self): + pass + + def teardownXYZ(self): + pass diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/torchmd_et.py index a72a67d7..a33121d6 100644 --- a/matdeeplearn/models/torchmd_et.py +++ b/matdeeplearn/models/torchmd_et.py @@ -16,9 +16,9 @@ from matdeeplearn.models.torchmd_output_modules import Scalar, EquivariantScalar from matdeeplearn.common.registry import registry from matdeeplearn.preprocessor.helpers import node_rep_one_hot -@registry.register_model("torchmd_et") +@registry.register_model("torchmd_et") class TorchMD_ET(BaseModel): r"""The TorchMD equivariant Transformer architecture. @@ -60,7 +60,7 @@ class TorchMD_ET(BaseModel): def __init__( self, - node_dim, + node_dim, edge_dim, output_dim, hidden_channels=128, @@ -110,7 +110,8 @@ def __init__( self.distance_influence = distance_influence self.max_z = max_z self.pool = pool - assert pool_order in ['early', 'late'], f"{pool_order} is currently not supported" + assert pool_order in [ + 'early', 'late'], f"{pool_order} is currently not supported" self.pool_order = pool_order self.output_dim = output_dim cutoff_lower = 0 @@ -159,10 +160,13 @@ def __init__( self.post_lin_list = nn.ModuleList() for i in range(self.num_post_layers): if i == 0: - self.post_lin_list.append(nn.Linear(hidden_channels, post_hidden_channels)) + self.post_lin_list.append( + nn.Linear(hidden_channels, post_hidden_channels)) else: - self.post_lin_list.append(nn.Linear(post_hidden_channels, post_hidden_channels)) - self.post_lin_list.append(nn.Linear(post_hidden_channels, self.output_dim)) + self.post_lin_list.append( + nn.Linear(post_hidden_channels, post_hidden_channels)) + self.post_lin_list.append( + nn.Linear(post_hidden_channels, self.output_dim)) self.reset_parameters() @@ -174,82 +178,93 @@ def reset_parameters(self): for attn in self.attention_layers: attn.reset_parameters() self.out_norm.reset_parameters() - + @conditional_grad(torch.enable_grad()) def _forward(self, data): x = self.embedding(data.z) - #edge_index, edge_weight, edge_vec = self.distance(data.pos, data.batch) - #assert ( + # edge_index, edge_weight, edge_vec = self.distance(data.pos, data.batch) + # assert ( # edge_vec is not None - #), "Distance module did not return directional information" + # ), "Distance module did not return directional information" if self.otf_edge_index == True: - #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_attr = self.distance_expansion(data.edge_weight) - - #mask = data.edge_index[0] != data.edge_index[1] - #data.edge_vec[mask] = data.edge_vec[mask] / torch.norm(data.edge_vec[mask], dim=1).unsqueeze(1) - data.edge_vec = data.edge_vec / torch.norm(data.edge_vec, dim=1).unsqueeze(1) - + # data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph( + data, self.cutoff_radius, self.n_neighbors) + data.edge_attr = self.distance_expansion(data.edge_weight) + + # mask = data.edge_index[0] != data.edge_index[1] + # data.edge_vec[mask] = data.edge_vec[mask] / torch.norm(data.edge_vec[mask], dim=1).unsqueeze(1) + data.edge_vec = data.edge_vec / \ + torch.norm(data.edge_vec, dim=1).unsqueeze(1) + if self.otf_node_attr == True: - data.x = node_rep_one_hot(data.z).float() - + data.x = node_rep_one_hot(data.z).float() + if self.neighbor_embedding is not None: - x = self.neighbor_embedding(data.z, x, data.edge_index, data.edge_weight, data.edge_attr) + x = self.neighbor_embedding( + data.z, x, data.edge_index, data.edge_weight, data.edge_attr) vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) for attn in self.attention_layers: - dx, dvec = attn(x, vec, data.edge_index, data.edge_weight, data.edge_attr, data.edge_vec) + dx, dvec = attn(x, vec, data.edge_index, + data.edge_weight, data.edge_attr, data.edge_vec) x = x + dx vec = vec + dvec + # just output the embeddings => stop before the prediction layer x = self.out_norm(x) - - if self.prediction_level == "graph": - if self.pool_order == 'early': - x = getattr(torch_geometric.nn, self.pool)(x, data.batch) - for i in range(0, len(self.post_lin_list) - 1): - x = self.post_lin_list[i](x) - x = getattr(F, self.activation)(x) - x = self.post_lin_list[-1](x) - if self.pool_order == 'late': - x = getattr(torch_geometric.nn, self.pool)(x, data.batch) - #x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) - #x = self.pool.reduce(x, data.batch) - elif self.prediction_level == "node": - for i in range(0, len(self.post_lin_list) - 1): - x = self.post_lin_list[i](x) - x = getattr(F, self.activation)(x) - x = self.post_lin_list[-1](x) - + + # if self.prediction_level == "graph": + # if self.pool_order == 'early': + # x = getattr(torch_geometric.nn, self.pool)(x, data.batch) + # for i in range(0, len(self.post_lin_list) - 1): + # x = self.post_lin_list[i](x) + # x = getattr(F, self.activation)(x) + # x = self.post_lin_list[-1](x) + # if self.pool_order == 'late': + # x = getattr(torch_geometric.nn, self.pool)(x, data.batch) + # # x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) + # # x = self.pool.reduce(x, data.batch) + # elif self.prediction_level == "node": + # for i in range(0, len(self.post_lin_list) - 1): + # x = self.post_lin_list[i](x) + # x = getattr(F, self.activation)(x) + # x = self.post_lin_list[-1](x) + + # TODO: FIGURE OUT HOW TO ACCESS EMBEDDINGS; WE NEED THEM TO COMPUTE + # MOLECULAR FINGERPRINTS. + return x - + def forward(self, data): - + output = {} out = self._forward(data) - output["output"] = out + output["output"] = out - if self.gradient == True and out.requires_grad == True: - volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + # this is skipped reached since we're not getting the prediction (I think?) + # even if it is reached, we're probably fine lol. + if self.gradient == True and out.requires_grad == True: + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross( + data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) grad = torch.autograd.grad( - out, - [data.pos, data.displacement], - grad_outputs=torch.ones_like(out), - create_graph=self.training) + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) forces = -1 * grad[0] stress = grad[1] - stress = stress / volume.view(-1, 1, 1) + stress = stress / volume.view(-1, 1, 1) - output["pos_grad"] = forces - output["cell_grad"] = stress + output["pos_grad"] = forces + output["cell_grad"] = stress else: - output["pos_grad"] = None - output["cell_grad"] = None - - return output + output["pos_grad"] = None + output["cell_grad"] = None + + return output def __repr__(self): return ( @@ -267,6 +282,7 @@ def __repr__(self): f"cutoff_lower={self.cutoff_lower}, " f"self.cutoff_radius={self.self.cutoff_radius})" ) + @property def target_attr(self): return "y" @@ -285,7 +301,8 @@ def __init__( cutoff_upper, aggregation, ): - super(EquivariantMultiHeadAttention, self).__init__(aggr=aggregation, node_dim=0) + super(EquivariantMultiHeadAttention, self).__init__( + aggr=aggregation, node_dim=0) assert hidden_channels % num_heads == 0, ( f"The number of hidden channels ({hidden_channels}) " f"must be evenly divisible by the number of " @@ -307,7 +324,8 @@ def __init__( self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3) self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) - self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False) + self.vec_proj = nn.Linear( + hidden_channels, hidden_channels * 3, bias=False) self.dk_proj = None if distance_influence in ["keys", "both"]: @@ -343,21 +361,23 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim) v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim * 3) - vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1) + vec1, vec2, vec3 = torch.split( + self.vec_proj(vec), self.hidden_channels, dim=-1) vec = vec.reshape(-1, 3, self.num_heads, self.head_dim) vec_dot = (vec1 * vec2).sum(dim=1) dk = ( - self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim) + self.act(self.dk_proj(f_ij)).reshape(-1, + self.num_heads, self.head_dim) if self.dk_proj is not None else None ) dv = ( - self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim * 3) + self.act(self.dv_proj(f_ij)).reshape(-1, + self.num_heads, self.head_dim * 3) if self.dv_proj is not None else None ) - # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, d_ij: Tensor) x, vec = self.propagate( diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index 06916021..19733e50 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -36,6 +36,7 @@ def __init__( save_dir, checkpoint_path, use_amp, + use_gsl=False # if true, run with graph structure learning ): super().__init__( model, @@ -55,7 +56,7 @@ def __init__( output_frequency, model_save_frequency, save_dir, - checkpoint_path, + checkpoint_path, use_amp, ) @@ -68,7 +69,7 @@ def train(self): if str(self.rank) not in ("cpu", "cuda"): dist.barrier() - + end_epoch = ( self.max_checkpoint_epochs + start_epoch if self.max_checkpoint_epochs @@ -85,63 +86,76 @@ def train(self): logging.info( f"Running for {end_epoch - start_epoch} epochs on {type(self.model[0]).__name__} model" ) - - for epoch in range(start_epoch, end_epoch): + + for epoch in range(start_epoch, end_epoch): epoch_start_time = time.time() if self.train_sampler: self.train_sampler.set_epoch(epoch) # skip_steps = self.step % len(self.train_loader) train_loader_iter = [] for i in range(len(self.model)): - train_loader_iter.append(iter(self.data_loader[i]["train_loader"])) + train_loader_iter.append( + iter(self.data_loader[i]["train_loader"])) # metrics for every epoch _metrics = [{} for _ in range(len(self.model))] - - #for i in range(skip_steps, len(self.train_loader)): - pbar = tqdm(range(0, len(self.data_loader[0]["train_loader"])), disable=not self.batch_tqdm) - for i in pbar: - #self.epoch = epoch + (i + 1) / len(self.train_loader) - #self.step = epoch * len(self.train_loader) + i + 1 - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + + # for i in range(skip_steps, len(self.train_loader)): + pbar = tqdm( + range(0, len(self.data_loader[0]["train_loader"])), disable=not self.batch_tqdm) + for i in pbar: + # self.epoch = epoch + (i + 1) / len(self.train_loader) + # self.step = epoch * len(self.train_loader) + i + 1 + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) batch = [] for n, mod in enumerate(self.model): mod.train() batch.append(next(train_loader_iter[n]).to(self.rank)) # Get a batch of train data - # batch = next(train_loader_iter).to(self.rank) - # print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms)) - # Compute forward, loss, backward + # batch = next(train_loader_iter).to(self.rank) + # print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms)) + # Compute forward, loss, backward with autocast(enabled=self.use_amp): - out_list = self._forward(batch) - loss = self._compute_loss(out_list, batch) - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + out_list = self._forward(batch) + out = out_list[0] + # Perform a readout operation on the atomic node embeddings + # to obtain a representation of the entire molecule. + # TODO: We need to extract this vector somehow + readout = torch.exp(torch.mean(torch.log(out), dim=1)) + + loss = self._compute_loss(out_list, batch) + + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) grad_norm = [] for i in range(len(self.model)): grad_norm.append(self._backward(loss[i], i)) - pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(torch.mean(torch.stack(loss)).item(), torch.mean(torch.stack(grad_norm)).item())) + pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(torch.mean( + torch.stack(loss)).item(), torch.mean(torch.stack(grad_norm)).item())) # Compute metrics # TODO: revert _metrics to be empty per batch, so metrics are logged per batch, not per epoch - # keep option to log metrics per epoch + # keep option to log metrics per epoch for n in range(len(self.model)): - _metrics[n] = self._compute_metrics(out_list[n], batch[n], _metrics[n]) - self.metrics[n] = self.evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], _metrics[n]) + _metrics[n] = self._compute_metrics( + out_list[n], batch[n], _metrics[n]) + self.metrics[n] = self.evaluator.update( + "loss", loss[n].item(), out_list[n]["output"].shape[0], _metrics[n]) self.epoch = epoch + 1 if str(self.rank) not in ("cpu", "cuda"): dist.barrier() - # TODO: could add param to eval and save on increments instead of every time - - # Save current model - torch.cuda.empty_cache() + # TODO: could add param to eval and save on increments instead of every time + + # Save current model + torch.cuda.empty_cache() if str(self.rank) in ("0", "cpu", "cuda"): if self.model_save_frequency == 1: - self.save_model(checkpoint_file="checkpoint.pt", training_state=True) + self.save_model( + checkpoint_file="checkpoint.pt", training_state=True) # Evaluate on validation set if it exists if self.data_loader[0].get("val_loader"): - metric = self.validate("val") + metric = self.validate("val") else: metric = self.metrics @@ -159,48 +173,53 @@ def train(self): if metric[i][type(self.loss_fn).__name__]["metric"] < self.best_metric[i]: if self.output_frequency == 0: if self.model_save_frequency == 1: - self.update_best_model(metric[i], i, write_model=True, write_csv=False) + self.update_best_model( + metric[i], i, write_model=True, write_csv=False) else: - self.update_best_model(metric[i], i, write_model=False, write_csv=False) + self.update_best_model( + metric[i], i, write_model=False, write_csv=False) elif self.output_frequency == 1: if self.model_save_frequency == 1: - self.update_best_model(metric[i], i, write_model=True, write_csv=True) + self.update_best_model( + metric[i], i, write_model=True, write_csv=True) else: - self.update_best_model(metric[i], i, write_model=False, write_csv=True) - + self.update_best_model( + metric[i], i, write_model=False, write_csv=True) + self._scheduler_step() - - torch.cuda.empty_cache() - + torch.cuda.empty_cache() + if self.best_model_state: for i in range(len(self.model)): if str(self.rank) in "0": - self.model[i].module.load_state_dict(self.best_model_state[i]) + self.model[i].module.load_state_dict( + self.best_model_state[i]) elif str(self.rank) in ("cpu", "cuda"): self.model[i].load_state_dict(self.best_model_state[i]) - #if self.data_loader.get("test_loader"): + # if self.data_loader.get("test_loader"): # metric = self.validate("test") # test_loss = metric[type(self.loss_fn).__name__]["metric"] - #else: - # test_loss = "N/A" + # else: + # test_loss = "N/A" if self.model_save_frequency != -1: - self.save_model("best_checkpoint.pt", index=None, metric=metric, training_state=True) - logging.info("Final Losses: ") + self.save_model("best_checkpoint.pt", index=None, + metric=metric, training_state=True) + logging.info("Final Losses: ") if "train" in self.write_output: self.predict(self.data_loader[0]["train_loader"], "train") if "val" in self.write_output and self.data_loader[0].get("val_loader"): self.predict(self.data_loader[0]["val_loader"], "val") if "test" in self.write_output and self.data_loader[0].get("test_loader"): - self.predict(self.data_loader[0]["test_loader"], "test") - - return self.best_model_state - + self.predict(self.data_loader[0]["test_loader"], "test") + + return self.best_model_state, readout + @torch.no_grad() def validate(self, split="val"): for i in range(len(self.model)): self.model[i].eval() - + evaluator, metrics = Evaluator(), [{} for _ in range(len(self.model))] loader_iter = [] @@ -211,31 +230,33 @@ def validate(self, split="val"): loader_iter.append(iter(self.data_loader[i]["test_loader"])) elif split == "train": loader_iter.append(iter(self.data_loader[i]["train_loader"])) - + for i in range(0, len(loader_iter[0])): - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) batch = [] for i in range(len(self.model)): batch.append(next(loader_iter[i]).to(self.rank)) - + out_list = self._forward(batch) loss = self._compute_loss(out_list, batch) # Compute metrics - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) for n in range(len(self.model)): - metrics[n] = self._compute_metrics(out_list[n], batch[n], metrics[n]) - metrics[n] = evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], metrics[n]) + metrics[n] = self._compute_metrics( + out_list[n], batch[n], metrics[n]) + metrics[n] = evaluator.update( + "loss", loss[n].item(), out_list[n]["output"].shape[0], metrics[n]) del loss, batch, out_list - + torch.cuda.empty_cache() - + return metrics @torch.no_grad() - def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True): + def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True): for mod in self.model: mod.eval() - + # assert isinstance(loader, torch.utils.data.dataloader.DataLoader) # TODO: make this compatible with model ensemble @@ -243,7 +264,7 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, loader = get_dataloader( loader.dataset, batch_size=loader.batch_size, sampler=None ) - + evaluator, metrics = Evaluator(), {} predict, target = None, None ids = [] @@ -251,15 +272,15 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, target_pos_grad = None ids_cell_grad = [] target_cell_grad = None - node_level = False - - loader_iter = iter(loader) + node_level = False + + loader_iter = iter(loader) for i in range(0, len(loader_iter)): batch = next(loader_iter).to(self.rank) out_list = self._forward([batch]) - + out = {} - out_stack={} + out_stack = {} for key in out_list[0].keys(): temp = [o[key] for o in out_list] if temp[0] is not None: @@ -269,12 +290,11 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, else: out[key] = None out[key+"_std"] = None - - + batch_p = [o["output"].data.cpu().numpy() for o in out_list] - batch_p_mean = out["output"].cpu().numpy() + batch_p_mean = out["output"].cpu().numpy() batch_ids = batch.structure_id - batch_stds = out["output_std"].cpu().numpy() + batch_stds = out["output_std"].cpu().numpy() if labels == True: loss = self._compute_loss(out, batch) @@ -283,12 +303,13 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, "loss", loss.item(), out["output"].shape[0], metrics ) if str(self.rank) not in ("cpu", "cuda"): - batch_t = batch[self.model[0].module.target_attr].cpu().numpy() + batch_t = batch[self.model[0].module.target_attr].cpu( + ).numpy() else: batch_t = batch[self.model[0].target_attr].cpu().numpy() - - # Node level prediction - if batch_p[0].shape[0] > loader.batch_size: + + # Node level prediction + if batch_p[0].shape[0] > loader.batch_size: node_level = True node_ids = batch.z.cpu().numpy() structure_ids = np.repeat( @@ -303,137 +324,157 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, structure_ids_pos_grad = np.repeat( batch.structure_id, batch.n_atoms.cpu().numpy(), axis=0 ) - batch_ids_pos_grad = np.column_stack((structure_ids_pos_grad, node_ids_pos_grad)) - ids_pos_grad = batch_ids_pos_grad if i == 0 else np.row_stack((ids_pos_grad, batch_ids_pos_grad)) - predict_pos_grad = batch_p_pos_grad if i == 0 else np.concatenate((predict_pos_grad, batch_p_pos_grad), axis=0) - predict_pos_grad_std = batch_p_pos_grad_std if i == 0 else np.concatenate((predict_pos_grad_std, batch_p_pos_grad_std), axis=0) + batch_ids_pos_grad = np.column_stack( + (structure_ids_pos_grad, node_ids_pos_grad)) + ids_pos_grad = batch_ids_pos_grad if i == 0 else np.row_stack( + (ids_pos_grad, batch_ids_pos_grad)) + predict_pos_grad = batch_p_pos_grad if i == 0 else np.concatenate( + (predict_pos_grad, batch_p_pos_grad), axis=0) + predict_pos_grad_std = batch_p_pos_grad_std if i == 0 else np.concatenate( + (predict_pos_grad_std, batch_p_pos_grad_std), axis=0) if "forces" in batch: - batch_t_pos_grad = batch["forces"].cpu().numpy() - target_pos_grad = batch_t_pos_grad if i == 0 else np.concatenate((target_pos_grad, batch_t_pos_grad), axis=0) - - if out.get("cell_grad") != None: - batch_p_cell_grad = out["cell_grad"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() - batch_p_cell_grad_std = out["cell_grad_std"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() - batch_ids_cell_grad = batch.structure_id - ids_cell_grad = batch_ids_cell_grad if i == 0 else np.row_stack((ids_cell_grad, batch_ids_cell_grad)) - predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate((predict_cell_grad, batch_p_cell_grad), axis=0) - predict_cell_grad_std = batch_p_cell_grad_std if i == 0 else np.concatenate((predict_cell_grad_std, batch_p_cell_grad_std), axis=0) + batch_t_pos_grad = batch["forces"].cpu().numpy() + target_pos_grad = batch_t_pos_grad if i == 0 else np.concatenate( + (target_pos_grad, batch_t_pos_grad), axis=0) + + if out.get("cell_grad") != None: + batch_p_cell_grad = out["cell_grad"].data.view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + batch_p_cell_grad_std = out["cell_grad_std"].data.view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + batch_ids_cell_grad = batch.structure_id + ids_cell_grad = batch_ids_cell_grad if i == 0 else np.row_stack( + (ids_cell_grad, batch_ids_cell_grad)) + predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate( + (predict_cell_grad, batch_p_cell_grad), axis=0) + predict_cell_grad_std = batch_p_cell_grad_std if i == 0 else np.concatenate( + (predict_cell_grad_std, batch_p_cell_grad_std), axis=0) if "stress" in batch: - batch_t_cell_grad = batch["stress"].view(out["cell_grad"].data.size(0), -1).cpu().numpy() - target_cell_grad = batch_t_cell_grad if i == 0 else np.concatenate((target_cell_grad, batch_t_cell_grad), axis=0) - - ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids)) - predict_mean = batch_p_mean if i == 0 else np.concatenate((predict_mean, batch_p_mean), axis=0) - stds = batch_stds if i == 0 else np.row_stack((stds, batch_stds)) - if i == 0: - predict = [0 for _ in range(len(self.model))] + batch_t_cell_grad = batch["stress"].view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + target_cell_grad = batch_t_cell_grad if i == 0 else np.concatenate( + (target_cell_grad, batch_t_cell_grad), axis=0) + + ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids)) + predict_mean = batch_p_mean if i == 0 else np.concatenate( + (predict_mean, batch_p_mean), axis=0) + stds = batch_stds if i == 0 else np.row_stack((stds, batch_stds)) + if i == 0: + predict = [0 for _ in range(len(self.model))] for x in range(len(self.model)): - predict[x] = batch_p[x] if i == 0 else np.concatenate((predict[x], batch_p[x]), axis=0) + predict[x] = batch_p[x] if i == 0 else np.concatenate( + (predict[x], batch_p[x]), axis=0) if labels == True: - target = batch_t if i == 0 else np.concatenate((target, batch_t), axis=0) - + target = batch_t if i == 0 else np.concatenate( + (target, batch_t), axis=0) + if labels == True: - del loss, batch, out - else: - del batch, out - + del loss, batch, out + else: + del batch, out + if write_output == True: if labels == True: - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids, target, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, True, std=True, - ) + ) for x in range(len(self.model)): - mod = str(x) + mod = str(x) self.save_results( np.column_stack((ids, target, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, True, std=False, ) - else: + else: self.save_results( np.column_stack((ids, target, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, True, std=False, - ) + ) else: - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, False, std=True, - ) + ) for x in range(len(self.model)): mod = str(x) self.save_results( - np.column_stack((ids, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, False, std=False, + np.column_stack((ids, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, False, std=False, ) - else: + else: self.save_results( np.column_stack((ids, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, False, std=False, - ) - #if out.get("pos_grad") != None: + ) + # if out.get("pos_grad") != None: if len(ids_pos_grad) > 0: if isinstance(target_pos_grad, np.ndarray): - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad, predict_pos_grad_std)), results_dir, f"{split}_predictions_pos_grad.csv", True, True, std=True ) else: self.save_results( np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, True, std=False - ) + ) else: self.save_results( np.column_stack((ids_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, False, std=False ) - #if out.get("cell_grad") != None: + # if out.get("cell_grad") != None: if len(ids_cell_grad) > 0: if isinstance(target_cell_grad, np.ndarray): - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad, predict_cell_grad_std)), results_dir, f"{split}_predictions_cell_grad.csv", False, True, std=True ) else: self.save_results( np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, True, std=False - ) + ) else: self.save_results( np.column_stack((ids_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, False, std=False ) - + if labels == True: predict_loss = metrics[type(self.loss_fn).__name__]["metric"] - logging.info("Saved {:s} error: {:.5f}".format(split, predict_loss)) - if len(self.model) > 1: - predictions = {"ids":ids, "predict":predict_mean, "target":target, "std": stds} + logging.info("Saved {:s} error: {:.5f}".format( + split, predict_loss)) + if len(self.model) > 1: + predictions = {"ids": ids, "predict": predict_mean, + "target": target, "std": stds} else: - predictions = {"ids":ids, "predict":predict_mean, "target":target} + predictions = {"ids": ids, + "predict": predict_mean, "target": target} else: - if len(self.model) > 1: - predictions = {"ids":ids, "predict":predict_mean, "std": stds} + if len(self.model) > 1: + predictions = {"ids": ids, + "predict": predict_mean, "std": stds} else: - predictions = {"ids":ids, "predict":predict_mean} + predictions = {"ids": ids, "predict": predict_mean} torch.cuda.empty_cache() - + return predictions - - def predict_by_calculator(self, loader): + + def predict_by_calculator(self, loader): for x, mod in self.model: mod.eval() - + assert isinstance(loader, torch.utils.data.dataloader.DataLoader) - assert len(loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures." + assert len( + loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures." if str(self.rank) not in ("cpu", "cuda"): loader = get_dataloader( loader.dataset, batch_size=loader.batch_size, sampler=None ) - + results = [] loader_iter = iter(loader) for i in range(0, len(loader_iter)): - batch = next(loader_iter).to(self.rank) + batch = next(loader_iter).to(self.rank) out_list = self._forward(batch.to(self.rank)) out = {} - out_stack={} + out_stack = {} for key in out_list[0].keys(): temp = [o[key] for o in out_list] if temp[0] is not None: @@ -442,12 +483,15 @@ def predict_by_calculator(self, loader): else: out[key] = None - energy = None if out.get('output') is None else out.get('output').data.cpu().numpy() - stress = None if out.get('cell_grad') is None else out.get('cell_grad').view(-1, 3).data.cpu().numpy() - forces = None if out.get('pos_grad') is None else out.get('pos_grad').data.cpu().numpy() - + energy = None if out.get('output') is None else out.get( + 'output').data.cpu().numpy() + stress = None if out.get('cell_grad') is None else out.get( + 'cell_grad').view(-1, 3).data.cpu().numpy() + forces = None if out.get('pos_grad') is None else out.get( + 'pos_grad').data.cpu().numpy() + results = {'energy': energy, 'stress': stress, 'forces': forces} - + return results def _forward(self, batch_data): @@ -480,9 +524,8 @@ def _backward(self, loss, index=None): ) self.scaler.step(self.optimizer[index]) self.scaler.update() - - return grad_norm + return grad_norm def _compute_metrics(self, out, batch_data, metrics): # TODO: finish this method @@ -493,14 +536,15 @@ def _compute_metrics(self, out, batch_data, metrics): metrics = self.evaluator.eval( out, property_target, self.loss_fn, prev_metrics=metrics - ) + ) return metrics def _log_metrics(self, val_metrics=None): - train_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in self.metrics] + train_loss = [torch.tensor( + i[type(self.loss_fn).__name__]["metric"]) for i in self.metrics] train_loss = torch.mean(torch.stack(train_loss)).item() - lr = self.scheduler[0].lr + lr = self.scheduler[0].lr if not val_metrics: val_loss = "N/A" logging.info( @@ -513,7 +557,8 @@ def _log_metrics(self, val_metrics=None): ) ) else: - val_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in val_metrics] + val_loss = [torch.tensor( + i[type(self.loss_fn).__name__]["metric"]) for i in val_metrics] val_loss = torch.mean(torch.stack(val_loss)).item() lr = self.scheduler[0].lr logging.info( @@ -526,7 +571,6 @@ def _log_metrics(self, val_metrics=None): ) ) - def _load_task(self): """Initializes task-specific info. Implemented by derived classes.""" pass @@ -535,7 +579,8 @@ def _scheduler_step(self): for i in range(len(self.model)): if self.scheduler[i].scheduler_type == "ReduceLROnPlateau": self.scheduler[i].step( - metrics=self.metrics[i][type(self.loss_fn).__name__]["metric"] + metrics=self.metrics[i][type( + self.loss_fn).__name__]["metric"] ) else: self.scheduler[i].step()