Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve H5MD trajectory support #655

Merged
merged 9 commits into from
Feb 11, 2025
90 changes: 59 additions & 31 deletions MDANSE/Src/MDANSE/Trajectory/H5MDTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,34 @@ def __init__(self, h5_filename: Union[Path, str]):
self._h5_filename = Path(h5_filename)

self._h5_file = h5py.File(self._h5_filename, "r")

# Load the chemical system
try:
particle_types = self._h5_file["/particles/all/species"]
particle_lookup = h5py.check_enum_dtype(
self._h5_file["/particles/all/species"].dtype
)
if particle_lookup is None:
# Load the chemical system
try:
symbols = self._h5_file["/parameters/atom_symbols"]
except KeyError:
LOG.error(
f"No information about chemical elements in {self._h5_filename}"
)
return
else:
chemical_elements = [byte.decode() for byte in symbols]
else:
reverse_lookup = {item: key for key, item in particle_lookup.items()}
chemical_elements = [
byte.decode() for byte in self._h5_file["/parameters/atom_symbols"]
reverse_lookup[type_number] for type_number in particle_types
]
except KeyError:
chemical_elements = self._h5_file["/particles/all/species"]
self._chemical_system = ChemicalSystem(self._h5_filename.stem)
self._chemical_system.initialise_atoms(chemical_elements)
try:
self._chemical_system.initialise_atoms(chemical_elements)
except Exception:
LOG.error(
"It was not possible to read chemical element information from an H5MD file."
)
return

# Load all the unit cells
self._load_unit_cells()
Expand All @@ -70,10 +88,10 @@ def __init__(self, h5_filename: Union[Path, str]):
coords = self._h5_file["/particles/all/position/value"][0, :, :]
try:
pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"]
except:
except Exception:
conv_factor = 1.0
else:
if pos_unit == "Ang":
if pos_unit in ("Ang", "Angstrom"):
pos_unit = "ang"
conv_factor = measure(1.0, pos_unit).toval("nm")
coords *= conv_factor
Expand All @@ -92,6 +110,7 @@ def file_is_right(self, filename):
temp["h5md"]
except KeyError:
result = False
temp.close()
return result

def close(self):
Expand All @@ -109,13 +128,12 @@ def __getitem__(self, frame):
:rtype: dict of ndarray
"""

grp = self._h5_file["/particles/all/position/value"]
try:
pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"]
except:
except Exception:
conv_factor = 1.0
else:
if pos_unit == "Ang":
if pos_unit in ("Ang", "Angstrom"):
pos_unit = "ang"
conv_factor = measure(1.0, pos_unit).toval("nm")
configuration = {}
Expand All @@ -125,12 +143,12 @@ def __getitem__(self, frame):
try:
try:
vel_unit = self._h5_file["/particles/all/velocity/value"].attrs["unit"]
except:
except Exception:
vel_unit = "ang/fs"
configuration["velocities"] = self._h5_file[
"/particles/all/velocity/value"
][frame, :, :] * measure(1.0, vel_unit).toval("nm/ps")
except:
except Exception:
pass

configuration["time"] = self.time()[frame]
Expand Down Expand Up @@ -167,7 +185,7 @@ def charges(self, frame):
except KeyError:
LOG.debug(f"No charge information in trajectory {self._h5_filename}")
charge = np.zeros(self._chemical_system.number_of_atoms)
except:
except Exception:
try:
charge = self._h5_file["/particles/all/charge"][:]
except KeyError:
Expand All @@ -190,10 +208,10 @@ def coordinates(self, frame):
raise IndexError(f"Invalid frame number: {frame}")
try:
pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"]
except:
except Exception:
conv_factor = 1.0
else:
if pos_unit == "Ang":
if pos_unit in ("Ang", "Angstrom"):
pos_unit = "ang"
conv_factor = measure(1.0, pos_unit).toval("nm")

Expand Down Expand Up @@ -224,7 +242,7 @@ def configuration(self, frame):
if k not in self._variables_to_skip:
try:
variables[k] = self.variable(k)[frame, :, :].astype(np.float64)
except:
except Exception:
self._variables_to_skip.append(k)

coordinates = self.coordinates(frame)
Expand All @@ -243,10 +261,10 @@ def _load_unit_cells(self):
self._unit_cells = []
try:
box_unit = self._h5_file["/particles/all/box/edges/value"].attrs["unit"]
except:
conv_factor = 1.0
except (AttributeError, KeyError):
conv_factor = 0.1
else:
if box_unit == "Ang":
if box_unit == "Ang" or box_unit == "Angstrom":
MBartkowiakSTFC marked this conversation as resolved.
Show resolved Hide resolved
box_unit = "ang"
conv_factor = measure(1.0, box_unit).toval("nm")
try:
Expand All @@ -256,9 +274,16 @@ def _load_unit_cells(self):
else:
if len(cells.shape) > 1:
for cell in cells:
temp_array = np.array(
[[cell[0], 0.0, 0.0], [0.0, cell[1], 0.0], [0.0, 0.0, cell[2]]]
)
if cell.shape == (3, 3):
temp_array = np.array(cell)
else:
temp_array = np.array(
[
[cell[0], 0.0, 0.0],
[0.0, cell[1], 0.0],
[0.0, 0.0, cell[2]],
]
)
uc = UnitCell(temp_array)
self._unit_cells.append(uc)
else:
Expand All @@ -270,14 +295,17 @@ def _load_unit_cells(self):
def time(self):
try:
time_unit = self._h5_file["/particles/all/position/time"].attrs["unit"]
except:
except KeyError:
conv_factor = 1.0
else:
conv_factor = measure(1.0, time_unit).toval("ps")
try:
time = self._h5_file["/particles/all/position/time"] * conv_factor
except:
time = []
except TypeError:
try:
time = self._h5_file["/particles/all/position/time"][:] * conv_factor
except Exception:
time = []
return time

def unit_cell(self, frame):
Expand Down Expand Up @@ -366,10 +394,10 @@ def read_com_trajectory(
grp = self._h5_file["/particles/all/position/value"]
try:
pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"]
except:
except Exception:
conv_factor = 1.0
else:
if pos_unit == "Ang":
if pos_unit in ("Ang", "Angstrom"):
pos_unit = "ang"
conv_factor = measure(1.0, pos_unit).toval("nm")

Expand Down Expand Up @@ -464,10 +492,10 @@ def read_atomic_trajectory(
grp = self._h5_file["/particles/all/position/value"]
try:
pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"]
except:
except Exception:
conv_factor = 1.0
else:
if pos_unit == "Ang":
if pos_unit in ("Ang", "Angstrom"):
pos_unit = "ang"
conv_factor = measure(1.0, pos_unit).toval("nm")
coords = grp[first:last:step, index, :].astype(np.float64) * conv_factor
Expand Down