diff --git a/force_field_models/inference/inference.py b/force_field_models/inference/inference.py index 5c6816f..87ce6b2 100644 --- a/force_field_models/inference/inference.py +++ b/force_field_models/inference/inference.py @@ -66,6 +66,7 @@ def predict_energies(model, data_loader): for batch in tqdm(data_loader): batch = batch.to(DEVICE) output = model(batch, True) + output = output / 627.5094740631 for atom_idx, mol_idx in enumerate(batch.batch): atom_id = batch.x[0][atom_idx].item()