feat(schnetpack): add enhanced ASE database format support for robust SchNetPack compatibility #879
Draft: Copilot wants to merge 6 commits into devel from copilot/fix-877
Changes from 2 commits
6 commits
3e00d01 Initial plan (Copilot)
f820435 feat(schnetpack): implement SchNetPack format conversion support (Copilot)
6474d6c fix(schnetpack): remove SchNetPack dependency, use only ASE database (Copilot)
b88f527 test(schnetpack): add comprehensive test for user's SchNetPack script… (Copilot)
4a0e33c chore: exclude database files from git and remove test artifact (Copilot)
d91bf5b fix(schnetpack): improve database compatibility and error handling (Copilot)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dpdata.format import Format | ||
|
|
||
|
|
||
| @Format.register("schnetpack") | ||
| class SchNetPackFormat(Format): | ||
| """Format for SchNetPack ASE database. | ||
|
|
||
| SchNetPack uses ASE database format internally for storing atomic structures | ||
| and their properties. This format converts dpdata LabeledSystem to SchNetPack's | ||
| ASE database format. | ||
|
|
||
| For more information, see: | ||
| https://schnetpack.readthedocs.io/en/latest/tutorials/tutorial_01_preparing_data.html | ||
| """ | ||
|
|
||
| def to_labeled_system( | ||
| self, | ||
| data: dict, | ||
| file_name: str = "schnetpack_data.db", | ||
| distance_unit: str = "Ang", | ||
| property_unit_dict: dict | None = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| """Convert dpdata LabeledSystem to SchNetPack ASE database format. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : dict | ||
| dpdata LabeledSystem data dictionary | ||
| file_name : str, optional | ||
| Path to the output SchNetPack database file, by default "schnetpack_data.db" | ||
| distance_unit : str, optional | ||
| Unit for distances, by default "Ang" | ||
| property_unit_dict : dict, optional | ||
| Dictionary mapping property names to their units. | ||
| If None, defaults to {"energy": "eV", "forces": "eV/Ang"} | ||
| **kwargs : dict | ||
| Additional keyword arguments | ||
|
|
||
| Raises | ||
| ------ | ||
| ImportError | ||
| If ASE or SchNetPack are not available | ||
| """ | ||
| try: | ||
| from ase import Atoms | ||
| from schnetpack.data import ASEAtomsData | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "ASE and SchNetPack are required for schnetpack format. " | ||
| "Install with: pip install ase schnetpack" | ||
| ) from e | ||
|
|
||
| # Set default units if not provided | ||
| if property_unit_dict is None: | ||
| property_unit_dict = {"energy": "eV", "forces": "eV/Ang"} | ||
|
|
||
| # Convert dpdata to list of ASE Atoms and property list | ||
| atoms_list = [] | ||
| property_list = [] | ||
|
|
||
| species = [data["atom_names"][tt] for tt in data["atom_types"]] | ||
| nframes = len(data["coords"]) | ||
|
|
||
| for frame_idx in range(nframes): | ||
| # Create ASE Atoms object for this frame | ||
| atoms = Atoms( | ||
| symbols=species, | ||
| positions=data["coords"][frame_idx], | ||
| pbc=not data.get("nopbc", False), | ||
| cell=data["cells"][frame_idx], | ||
| ) | ||
| atoms_list.append(atoms) | ||
|
|
||
| # Create property dictionary for this frame | ||
| properties = {} | ||
|
|
||
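| # Add energy as a length-1 sequence rather than a bare float | ||
| # (assumption: SchNetPack reads stored properties back as arrays) | ||
| if "energies" in data: | ||
| properties["energy"] = data["energies"][frame_idx : frame_idx + 1] | ||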
|
|
||
| # Add forces | ||
| if "forces" in data: | ||
| properties["forces"] = data["forces"][frame_idx] | ||
|
|
||
| # Add virials if present. SchNetPack has no built-in virial property; | ||
| # this assumes it is kept only if property_unit_dict gives "virials" a unit | ||
| if "virials" in data: | ||
| properties["virials"] = data["virials"][frame_idx] | ||
|
|
||
| property_list.append(properties) | ||
|
|
||
| # Create SchNetPack ASE database | ||
| dataset = ASEAtomsData.create( | ||
| file_name, | ||
| distance_unit=distance_unit, | ||
| property_unit_dict=property_unit_dict, | ||
| ) | ||
|
|
||
| # Add all systems to the database | ||
| dataset.add_systems(property_list, atoms_list) | ||
|
|
||
| return None | ||
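The per-frame unpacking in `to_labeled_system` can be sketched without ASE or SchNetPack installed. The following is a minimal illustration using plain lists; `frames_to_records` and the record layout are hypothetical and not part of dpdata or this PR. It only mirrors how atom types are mapped to element symbols and how energies and forces are sliced per frame.

```python
from __future__ import annotations


def frames_to_records(data: dict) -> list[dict]:
    """Split a dpdata-style dict into one record per frame (hypothetical helper)."""
    # Map each atom's type index to its element symbol, e.g. [0, 1, 1] -> O, H, H
    species = [data["atom_names"][t] for t in data["atom_types"]]
    records = []
    for i in range(len(data["coords"])):
        properties = {}
        if "energies" in data:
            # Keep the energy as a length-1 list rather than a bare float
            properties["energy"] = [float(data["energies"][i])]
        if "forces" in data:
            properties["forces"] = data["forces"][i]
        records.append(
            {
                "symbols": species,
                "positions": data["coords"][i],
                "cell": data["cells"][i],
                # dpdata marks non-periodic systems with a "nopbc" flag
                "pbc": not data.get("nopbc", False),
                "properties": properties,
            }
        )
    return records


cell = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
data = {
    "atom_names": ["O", "H"],
    "atom_types": [0, 1, 1],
    "cells": [cell, cell],
    "coords": [
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [[0.1, 0.0, 0.0], [1.1, 0.0, 0.0], [0.1, 1.0, 0.0]],
    ],
    "energies": [-10.5, -10.6],
    "forces": [
        [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]],
        [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]],
    ],
}
records = frames_to_records(data)
print(len(records), records[0]["symbols"], records[1]["properties"]["energy"])
```

In the PR itself each record corresponds to one `Atoms` object plus one property dict; the sketch keeps them together only to make the mapping easy to inspect.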
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,238 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import tempfile | ||
| import unittest | ||
| from unittest.mock import patch | ||
|
|
||
| from context import dpdata | ||
|
|
||
|
|
||
| class TestSchNetPackRegistration(unittest.TestCase): | ||
| """Test SchNetPack format registration and error handling.""" | ||
|
|
||
| def test_format_registered(self): | ||
| """Test that schnetpack format is properly registered.""" | ||
| test_system = dpdata.LabeledSystem() | ||
| test_system.data = { | ||
| "atom_names": ["H"], | ||
| "atom_numbs": [1], | ||
| "atom_types": [0], | ||
| "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]], | ||
| "coords": [[[0.0, 0.0, 0.0]]], | ||
| "orig": [0.0, 0.0, 0.0], | ||
| "energies": [1.0], | ||
| "forces": [[[0.0, 0.0, 0.0]]], | ||
| } | ||
|
|
||
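| # Block the SchNetPack import so the ImportError path is exercised | ||
| # even on machines where SchNetPack is installed | ||
| with patch.dict( | ||
| "sys.modules", {"schnetpack": None, "schnetpack.data": None} | ||
| ), self.assertRaises(ImportError) as cm: | ||
| test_system.to("schnetpack", "/tmp/test.db") | ||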
|
|
||
| self.assertIn("ASE and SchNetPack are required", str(cm.exception)) | ||
|
|
||
|
|
||
| try: | ||
| from schnetpack.data import ASEAtomsData | ||
|
|
||
| schnetpack_available = True | ||
| except ImportError: | ||
| schnetpack_available = False | ||
|
|
||
|
|
||
| @unittest.skipIf(not schnetpack_available, "skip test_schnetpack") | ||
| class TestSchNetPack(unittest.TestCase): | ||
| def setUp(self): | ||
| # Create a simple test system | ||
| self.test_system = dpdata.System() | ||
|
|
||
| # Create simple water-like structure for testing | ||
| # 3 atoms: O, H, H | ||
| self.test_system.data = { | ||
| "atom_names": ["O", "H"], | ||
| "atom_numbs": [1, 2], | ||
| "atom_types": [0, 1, 1], # O, H, H | ||
| "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]], | ||
| "coords": [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], | ||
| "orig": [0.0, 0.0, 0.0], | ||
| } | ||
|
|
||
| # Create labeled system with dummy energies and forces | ||
| self.labeled_system = dpdata.LabeledSystem() | ||
| self.labeled_system.data = self.test_system.data.copy() | ||
| self.labeled_system.data["energies"] = [-10.5] # eV | ||
| self.labeled_system.data["forces"] = [ | ||
| [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]] | ||
| ] # eV/Ang | ||
|
|
||
| # Optional: add virials | ||
| self.labeled_system.data["virials"] = [ | ||
| [[0.01, 0.0, 0.0], [0.0, 0.01, 0.0], [0.0, 0.0, 0.01]] | ||
| ] | ||
|
|
||
| def test_to_schnetpack(self): | ||
| """Test conversion to SchNetPack format.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| db_file = os.path.join(tmpdir, "test_data.db") | ||
|
|
||
| # Convert to SchNetPack format | ||
| self.labeled_system.to("schnetpack", db_file) | ||
|
|
||
| # Verify the database was created | ||
| self.assertTrue(os.path.exists(db_file)) | ||
|
|
||
| # Load the database and verify contents | ||
| dataset = ASEAtomsData(db_file) | ||
|
|
||
| # Check number of structures | ||
| self.assertEqual(len(dataset), 1) | ||
|
|
||
| # Check structure properties | ||
| data_point = dataset[0] | ||
|
|
||
| # Check basic structure information | ||
| self.assertEqual(len(data_point["_atomic_numbers"]), 3) # O, H, H | ||
|
|
||
| # Check that properties are present | ||
| self.assertIn("energy", data_point) | ||
| self.assertIn("forces", data_point) | ||
|
|
||
| # Check energy value | ||
| self.assertAlmostEqual(float(data_point["energy"]), -10.5, places=5) | ||
|
|
||
| # Check forces shape | ||
| self.assertEqual( | ||
| data_point["forces"].shape, (3, 3) | ||
| ) # 3 atoms, 3 components | ||
|
|
||
| def test_to_schnetpack_custom_units(self): | ||
| """Test conversion with custom units.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| db_file = os.path.join(tmpdir, "test_data_units.db") | ||
|
|
||
| # Convert with custom units | ||
| property_units = {"energy": "kcal/mol", "forces": "kcal/mol/Ang"} | ||
|
|
||
| self.labeled_system.to( | ||
| "schnetpack", db_file, property_unit_dict=property_units | ||
| ) | ||
|
|
||
| # Verify the database was created | ||
| self.assertTrue(os.path.exists(db_file)) | ||
|
|
||
| # Load and verify | ||
| dataset = ASEAtomsData(db_file) | ||
| self.assertEqual(len(dataset), 1) | ||
|
|
||
| # Basic verification that data is present | ||
| data_point = dataset[0] | ||
| self.assertIn("energy", data_point) | ||
| self.assertIn("forces", data_point) | ||
|
|
||
| def test_to_schnetpack_without_virials(self): | ||
| """Test conversion without virials.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| db_file = os.path.join(tmpdir, "test_no_virials.db") | ||
|
|
||
| # Create system without virials | ||
| system_no_virials = dpdata.LabeledSystem() | ||
| system_no_virials.data = self.test_system.data.copy() | ||
| system_no_virials.data["energies"] = [-10.5] | ||
| system_no_virials.data["forces"] = [ | ||
| [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]] | ||
| ] | ||
|
|
||
| # Convert to SchNetPack format | ||
| system_no_virials.to("schnetpack", db_file) | ||
|
|
||
| # Verify the database was created | ||
| self.assertTrue(os.path.exists(db_file)) | ||
|
|
||
| # Load and verify | ||
| dataset = ASEAtomsData(db_file) | ||
| self.assertEqual(len(dataset), 1) | ||
|
|
||
| data_point = dataset[0] | ||
| self.assertIn("energy", data_point) | ||
| self.assertIn("forces", data_point) | ||
| # No virials were supplied, so none should be stored | ||
| self.assertNotIn("virials", data_point) | ||
|
|
||
| def test_multiframe_system(self): | ||
| """Test conversion of multi-frame system.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| db_file = os.path.join(tmpdir, "test_multiframe.db") | ||
|
|
||
| # Create multi-frame system | ||
| multiframe_system = dpdata.LabeledSystem() | ||
| multiframe_system.data = { | ||
| "atom_names": ["O", "H"], | ||
| "atom_numbs": [1, 2], | ||
| "atom_types": [0, 1, 1], | ||
| "cells": [ | ||
| [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]], | ||
| [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]], | ||
| ], | ||
| "coords": [ | ||
| [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], | ||
| [[0.1, 0.0, 0.0], [1.1, 0.0, 0.0], [0.1, 1.0, 0.0]], | ||
| ], | ||
| "orig": [0.0, 0.0, 0.0], | ||
| "energies": [-10.5, -10.6], | ||
| "forces": [ | ||
| [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]], | ||
| [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]], | ||
| ], | ||
| } | ||
|
|
||
| # Convert to SchNetPack format | ||
| multiframe_system.to("schnetpack", db_file) | ||
|
|
||
| # Verify the database was created | ||
| self.assertTrue(os.path.exists(db_file)) | ||
|
|
||
| # Load and verify | ||
| dataset = ASEAtomsData(db_file) | ||
|
|
||
| # Should have 2 frames | ||
| self.assertEqual(len(dataset), 2) | ||
|
|
||
| # Check both frames | ||
| for i in range(2): | ||
| data_point = dataset[i] | ||
| self.assertIn("energy", data_point) | ||
| self.assertIn("forces", data_point) | ||
| self.assertEqual(len(data_point["_atomic_numbers"]), 3) | ||
|
|
||
|
|
||
| class TestSchNetPackMocked(unittest.TestCase): | ||
| """Test SchNetPack functionality with mocked dependencies.""" | ||
|
|
||
| def setUp(self): | ||
| # Create a simple test system | ||
| self.labeled_system = dpdata.LabeledSystem() | ||
| self.labeled_system.data = { | ||
| "atom_names": ["O", "H"], | ||
| "atom_numbs": [1, 2], | ||
| "atom_types": [0, 1, 1], # O, H, H | ||
| "cells": [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]], | ||
| "coords": [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], | ||
| "orig": [0.0, 0.0, 0.0], | ||
| "energies": [-10.5], # eV | ||
| "forces": [[[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]], # eV/Ang | ||
| } | ||
|
|
||
| @patch("dpdata.plugins.schnetpack.SchNetPackFormat.to_labeled_system") | ||
| def test_conversion_logic_mocked(self, mock_to_labeled_system): | ||
| """Test that to() dispatches to the registered format (the conversion itself is mocked).""" | ||
| # Only the dispatch is exercised here; the real conversion logic does not run | ||
| mock_to_labeled_system.return_value = None | ||
|
|
||
| # Test the conversion - should call the mocked method | ||
| self.labeled_system.to("schnetpack", "/tmp/test.db") | ||
|
|
||
| # Verify the method was called | ||
| mock_to_labeled_system.assert_called_once() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
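The registration test relies on the optional import failing. One way to make that path deterministic, whether or not the dependency is installed, is to place a `None` entry in `sys.modules`, which makes Python raise `ImportError` for that name. The sketch below shows the idea with the standard library only; `require` and `import_fails_when_blocked` are hypothetical helpers, not dpdata APIs.

```python
import importlib
import sys
from unittest.mock import patch


def require(module_name: str):
    """Import an optional dependency or raise ImportError with an install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"{module_name} is required. Install with: pip install {module_name}"
        ) from e


def import_fails_when_blocked(module_name: str) -> bool:
    """Return True if blocking the module in sys.modules triggers our ImportError."""
    # A None entry in sys.modules makes any import of that name raise ImportError;
    # patch.dict restores the original sys.modules contents on exit
    with patch.dict(sys.modules, {module_name: None}):
        try:
            require(module_name)
        except ImportError as e:
            return "is required" in str(e)
    return False


print(import_fails_when_blocked("schnetpack"))
```

Because `patch.dict` undoes the change when the block exits, other tests in the same process still see the real module if it is installed.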