diff --git a/docs/releases/ogstools-0.5.0.md b/docs/releases/ogstools-0.5.0.md index a0bff9d..8f21dab 100644 --- a/docs/releases/ogstools-0.5.0.md +++ b/docs/releases/ogstools-0.5.0.md @@ -37,6 +37,7 @@ - MeshSeries get extract() method to select points or cells via ids - MeshSeries can be sliced to get new MeshSeries with the selected subset of timesteps - MeshSeries gets a modify function that applies arbitrary function to all timestep - meshes. +- MeshSeries gets a save function (only for pvd implemented yet) - difference() between two meshes is now possible even with different topologies - Project write_input, path can be specified - MeshSeries gets scale() method to scale spatially or temporally diff --git a/ogstools/meshlib/ip_mesh.py b/ogstools/meshlib/ip_mesh.py index 12a9ca6..cf24ab6 100644 --- a/ogstools/meshlib/ip_mesh.py +++ b/ogstools/meshlib/ip_mesh.py @@ -1,5 +1,4 @@ from pathlib import Path -from tempfile import mkdtemp from typing import TypeVar import numpy as np @@ -197,10 +196,10 @@ def to_ip_point_cloud(mesh: Mesh) -> pv.UnstructuredGrid: for key in bad_keys: if key in _mesh.field_data: _mesh.field_data.remove(key) - tempdir = mkdtemp(prefix="to_ip_point_cloud") - input_file = Path(tempdir) / "ipDataToPointCloud_input.vtu" + parentpath = Path() if mesh.filepath is None else mesh.filepath.parent + input_file = parentpath / "ipDataToPointCloud_input.vtu" _mesh.save(input_file) - output_file = Path(tempdir) / "ip_mesh.vtu" + output_file = parentpath / "ip_mesh.vtu" ogs.cli.ipDataToPointCloud(i=str(input_file), o=str(output_file)) return pv.XMLUnstructuredGridReader(output_file).read() diff --git a/ogstools/meshlib/mesh.py b/ogstools/meshlib/mesh.py index 8a617c6..3cc720a 100644 --- a/ogstools/meshlib/mesh.py +++ b/ogstools/meshlib/mesh.py @@ -27,6 +27,8 @@ class Mesh(pv.UnstructuredGrid): Contains additional data and functions mainly for postprocessing. """ + filepath: Path | None = None + # pylint: disable=C0116 @copy_method_signature(data_processing.difference) def difference(self, *args: Any, **kwargs: Any) -> Any: @@ -94,6 +96,7 @@ def read(cls, filepath: str | Path) -> Mesh: else: mesh = cls(pv.read(filepath)) + mesh.filepath = Path(filepath).with_suffix(".vtu") return mesh @classmethod diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index 9c2c787..19d8fd4 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -12,11 +12,12 @@ from copy import copy, deepcopy from functools import partial from pathlib import Path -from typing import Any, Literal, overload +from typing import Any, Literal, cast, overload import meshio import numpy as np import pyvista as pv +from lxml import etree as ET from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation from scipy.interpolate import ( @@ -148,8 +149,8 @@ def __len__(self) -> int: return len(self.timesteps) def __iter__(self) -> Iterator[Mesh]: - for t in self.timesteps: - yield self.mesh(t) + for i in np.arange(len(self.timevalues), dtype=int): + yield self.mesh(i) def __str__(self) -> str: if self._data_type == "vtu": @@ -208,14 +209,14 @@ def ip_tesselated(self) -> MeshSeries: ip_mesh = self.mesh(0).to_ip_mesh() ip_pt_cloud = self.mesh(0).to_ip_point_cloud() ordering = ip_mesh.find_containing_cell(ip_pt_cloud.points) - for ts in self.timesteps: + for i in np.arange(len(self.timevalues), dtype=int): ip_data = { - key: self.mesh(ts).field_data[key][np.argsort(ordering)] + key: self.mesh(i).field_data[key][np.argsort(ordering)] for key in ip_mesh.cell_data } ip_mesh.cell_data.update(ip_data) ip_ms._mesh_cache[ - self.timevalues[ts] + self.timevalues[i] ] = ip_mesh.copy() # pylint: disable=protected-access ip_ms._timevalues = self._timevalues # pylint: disable=protected-access return ip_ms @@ -245,6 +246,10 @@ def mesh(self, timestep: int, lazy_eval: bool = True) -> Mesh: mesh = Mesh(self.mesh_func(pv_mesh)) if lazy_eval: self._mesh_cache[timevalue] = mesh + if self._data_type == "pvd": + mesh.filepath = Path(self.timestep_files[data_timestep]) + else: + mesh.filepath = Path(self.filepath) return mesh def rawdata_file(self) -> Path | None: @@ -290,7 +295,13 @@ def timevalues(self) -> np.ndarray: @property def timesteps(self) -> list: - """Return the timesteps of the timeseries data.""" + """ + Return the OGS simulation timesteps of the timeseries data. + Not to be confused with timevalues which returns a list of + times usually given in time units. + """ + + # TODO: read time steps from fn string if available return np.arange(len(self.timevalues), dtype=int) def _xdmf_values(self, variable_name: str) -> np.ndarray: @@ -579,16 +590,15 @@ def plot_probe( def animate( self, variable: Variable, - timesteps: Sequence | None = None, + timevalues: Sequence | None = None, plot_func: Callable[[plt.Axes, float], None] = lambda *_: None, **kwargs: Any, ) -> FuncAnimation: """ - Create an animation for a variable with given timesteps. + Create an animation for a variable with given timevalues. - :param variable: the field to be visualized on all timesteps - :param timesteps: if sequence of int: the timesteps to animate - if sequence of float: the timevalues to animate + :param variable: the field to be visualized on all timevalues + :param timevalues: the timevalues to animate :param plot_func: A function which expects to read a matplotlib Axes and the time value of the current frame. Useful to customize the plot in the animation. @@ -598,7 +608,7 @@ def animate( plot.setup.layout = "tight" plot.setup.combined_colorbar = True - ts = self.timesteps if timesteps is None else timesteps + ts = self.timevalues if timevalues is None else timevalues fig = plot.contourf(self.mesh(0, lazy_eval=False), variable) assert isinstance(fig, plt.Figure) @@ -609,14 +619,14 @@ def animate( def init() -> None: pass - def animate_func(i: int | float, fig: plt.Figure) -> None: + def animate_func(tv: float, fig: plt.Figure) -> None: fig.axes[-1].remove() # remove colorbar for ax in np.ravel(np.asarray(fig.axes)): ax.clear() - mesh = self[i] if isinstance(i, int) else self.read_interp(i, True) + mesh = self.read_interp(tv, True) with warnings.catch_warnings(): warnings.simplefilter("ignore") - plot_func(fig.axes[0], i) + plot_func(fig.axes[0], tv) plot.contourplots.draw_plot( mesh, variable, fig=fig, axes=fig.axes[0], **kwargs ) # type: ignore[assignment] @@ -834,3 +844,90 @@ def extract( ), } return self.transform(func[preference]) + + def _rename_vtufiles(self, new_pvd_fn: Path, fns: list[Path]) -> list: + fns_new: list[Path] = [] + for filename in fns: + filepathparts_at_timestep = list(filename.parts) + filepathparts_at_timestep[-1] = filepathparts_at_timestep[ + -1 + ].replace( + Path(self.filepath).name.split(".")[0], + new_pvd_fn.name.split(".")[0], + ) + fns_new.append(Path(*filepathparts_at_timestep)) + return fns_new + + def _save_vtu(self, new_pvd_fn: Path, fns: list[Path]) -> None: + for i, timestep in enumerate(self.timesteps): + if ".vtu" in fns[i].name: + pv.save_meshio( + Path(new_pvd_fn.parent, fns[i].name), self.mesh(i) + ) + elif ".xdmf" in fns[i].name: + newname = fns[i].name.replace( + ".xdmf", f"_ts_{timestep}_t_{self.timevalues[i]}.vtu" + ) + pv.save_meshio(Path(new_pvd_fn.parent, newname), self.mesh(i)) + else: + s = "File type not supported." + raise RuntimeError(s) + + def _save_pvd(self, new_pvd_fn: Path, fns: list[Path]) -> None: + root = ET.Element("VTKFile") + root.attrib["type"] = "Collection" + root.attrib["version"] = "0.1" + root.attrib["byte_order"] = "LittleEndian" + root.attrib["compressor"] = "vtkZLibDataCompressor" + collection = ET.SubElement(root, "Collection") + for i, timestep in enumerate(self.timevalues): + timestepselement = ET.SubElement(collection, "DataSet") + timestepselement.attrib["timestep"] = str(timestep) + timestepselement.attrib["group"] = "" + timestepselement.attrib["part"] = "0" + if ".xdmf" in fns[i].name: + newname = fns[i].name.replace( + ".xdmf", f"_ts_{self.timesteps[i]}_t_{timestep}.vtu" + ) + timestepselement.attrib["file"] = newname + elif ".vtu" in fns[i].name: + timestepselement.attrib["file"] = fns[i].name + else: + s = "File type not supported." + raise RuntimeError(s) + tree = ET.ElementTree(root) + tree.write( + new_pvd_fn, + encoding="ISO-8859-1", + xml_declaration=True, + pretty_print=True, + ) + + def _check_path(self, filename: Path | None) -> Path: + if not isinstance(filename, Path): + s = "filename is empty" + raise RuntimeError(s) + assert isinstance(filename, Path) + return cast(Path, filename) + + def save(self, filename: str, deep: bool = True) -> None: + """ + Save mesh series to disk. + + :param filename: Filename to save the series to. Extension specifies + the file type. Currently only PVD is supported. + :param deep: Specifies whether VTU/H5 files should be written. + """ + fn = Path(filename) + fns = [ + self._check_path(self.mesh(t).filepath) + for t in np.arange(len(self.timevalues), dtype=int) + ] + if ".pvd" in fn.name: + if deep is True: + fns = self._rename_vtufiles(fn, fns) + self._save_vtu(fn, fns) + self._save_pvd(fn, fns) + else: + s = "Currently the save method is implemented for PVD/VTU only." + raise RuntimeError(s) diff --git a/tests/test_meshlib.py b/tests/test_meshlib.py index 8208ed9..f064a72 100644 --- a/tests/test_meshlib.py +++ b/tests/test_meshlib.py @@ -7,6 +7,7 @@ import pkg_resources import pytest import pyvista as pv +from lxml import etree as ET import ogstools as ot from ogstools import examples @@ -472,3 +473,60 @@ def test_copy_shallow(self): ms_shallowcopy = ms.copy(deep=False) ms.test_var = False assert not ms_shallowcopy.test_var + + def test_save_pvd_mesh_series(self): + temp = Path(mkdtemp()) + file_name = "test.pvd" + + ms = examples.load_meshseries_HT_2D_PVD() + ms.save(Path(temp, file_name), deep=True) + ms_test = ot.MeshSeries(Path(temp, file_name)) + assert len(ms.timevalues) == len(ms_test.timevalues) + assert np.abs(ms.timevalues[1] - ms_test.timevalues[1]) < 1e-14 + for var in ["temperature", "darcy_velocity", "pressure"]: + val_ref = np.sum(ms.aggregate_over_domain(var, np.max)) + val_test = np.sum(ms_test.aggregate_over_domain(var, np.max)) + assert np.abs(val_ref - val_test) < 1e-14 + + for m in ms_test: + assert "test" in m.filepath.name + + ms.save(Path(temp, file_name), deep=False) + tree = ET.parse(Path(temp, file_name)) + num_slices = len(ms.timevalues) + num_slices_test = len(tree.findall("./Collection/DataSet")) + assert num_slices == num_slices_test + pvd_entries = tree.findall("./Collection/DataSet") + for i in range(num_slices): + assert ms[i].filepath.name == pvd_entries[i].attrib["file"] + ts = float(pvd_entries[i].attrib["timestep"]) + assert np.abs(ms.timevalues[i] - ts) < 1e-14 + + def test_save_xdmf_mesh_series(self): + temp = Path(mkdtemp()) + file_name = "test.pvd" + + ms = examples.load_meshseries_CT_2D_XDMF() + ms.save(Path(temp, file_name), deep=True) + ms_test = ot.MeshSeries(Path(temp, file_name)) + assert len(ms.timevalues) == len(ms_test.timevalues) + assert np.abs(ms.timevalues[1] - ms_test.timevalues[1]) < 1e-14 + assert ( + np.abs( + np.sum(ms.aggregate_over_domain("Si", np.max)) + - np.sum(ms_test.aggregate_over_domain("Si", np.max)) + ) + < 1e-14 + ) + for m in ms_test: + assert "test" in m.filepath.name + + ms.save(Path(temp, file_name), deep=False) + tree = ET.parse(Path(temp, file_name)) + num_slices = len(ms.timevalues) + pvd_entries = tree.findall("./Collection/DataSet") + num_slices_test = len(pvd_entries) + assert num_slices == num_slices_test + for i in range(num_slices): + ts = float(pvd_entries[i].attrib["timestep"]) + assert np.abs(ms.timevalues[i] - ts) < 1e-14