diff --git a/matdeeplearn/tasks/task.py b/matdeeplearn/tasks/task.py index 6a48172e..16ea18f5 100644 --- a/matdeeplearn/tasks/task.py +++ b/matdeeplearn/tasks/task.py @@ -75,6 +75,7 @@ def run(self): # if isinstance(self.trainer.data_loader, list): self.trainer.predict( loader=self.trainer.data_loader, split="predict", results_dir=results_dir, labels=self.config["task"]["labels"], + vmap_pred = self.config["task"].get("vmap_pred", False) ) # else: # self.trainer.predict( diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index 06916021..374dabf9 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -1,5 +1,6 @@ import logging import time +import copy import numpy as np import torch @@ -232,9 +233,16 @@ def validate(self, split="val"): return metrics @torch.no_grad() - def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True): + def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True, vmap_pred = False): for mod in self.model: mod.eval() + if vmap_pred: + params, buffers = stack_module_state(self.model) + base_model = copy.deepcopy(self.model[0]) + base_model = base_model.to('meta') + # TODO: Allow to work with pos_grad and cell_grad + def fmodel(params, buffers, x): + return functional_call(base_model, (params, buffers), (x,))['output'] # assert isinstance(loader, torch.utils.data.dataloader.DataLoader) @@ -256,25 +264,30 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, loader_iter = iter(loader) for i in range(0, len(loader_iter)): batch = next(loader_iter).to(self.rank) - out_list = self._forward([batch]) - out = {} - out_stack={} - for key in out_list[0].keys(): - temp = [o[key] for o in out_list] - if temp[0] is not None: - out_stack[key] = torch.stack(temp) - out[key] = torch.mean(out_stack[key], dim=0) - out[key+"_std"] = torch.std(out_stack[key], dim=0) 
- else: - out[key] = None - out[key+"_std"] = None - + out = {}; out_stack={} + if not vmap_pred: + out_list = self._forward([batch]) + for key in out_list[0].keys(): + temp = [o[key] for o in out_list] + if temp[0] is not None: + out_stack[key] = torch.stack(temp) + out[key] = torch.mean(out_stack[key], dim=0) + out[key+"_std"] = torch.std(out_stack[key], dim=0) + else: + out[key] = None + out[key+"_std"] = None + batch_p = [o["output"].data.cpu().numpy() for o in out_list] + + else: + out_list = vmap(fmodel, in_dims = (0, 0, None))(params, buffers, batch) + out["output"] = torch.mean(out_list, dim = 0) + out["output_std"] = torch.std(out_list, dim = 0) + batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])] - batch_p = [o["output"].data.cpu().numpy() for o in out_list] - batch_p_mean = out["output"].cpu().numpy() + batch_p_mean = out["output"].cpu().numpy() + batch_stds = out["output_std"].cpu().numpy() batch_ids = batch.structure_id - batch_stds = out["output_std"].cpu().numpy() if labels == True: loss = self._compute_loss(out, batch)