Skip to content

Commit

Permalink
Merge branch 'save_mesh_series' into 'main'
Browse files Browse the repository at this point in the history
add save_pvd method to mesh series

See merge request ogs/tools/ogstools!234
  • Loading branch information
TobiasMeisel committed Jan 23, 2025
2 parents 573a9e0 + 8ba2495 commit ba922e8
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 20 deletions.
1 change: 1 addition & 0 deletions docs/releases/ogstools-0.5.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
- MeshSeries get extract() method to select points or cells via ids
- MeshSeries can be sliced to get new MeshSeries with the selected subset of timesteps
- MeshSeries gets a modify function that applies arbitrary function to all timestep - meshes.
- MeshSeries gets a save function (only for pvd implemented yet)
- difference() between two meshes is now possible even with different topologies
- Project write_input, path can be specified
- MeshSeries gets scale() method to scale spatially or temporally
Expand Down
7 changes: 3 additions & 4 deletions ogstools/meshlib/ip_mesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from tempfile import mkdtemp
from typing import TypeVar

import numpy as np
Expand Down Expand Up @@ -197,10 +196,10 @@ def to_ip_point_cloud(mesh: Mesh) -> pv.UnstructuredGrid:
for key in bad_keys:
if key in _mesh.field_data:
_mesh.field_data.remove(key)
tempdir = mkdtemp(prefix="to_ip_point_cloud")
input_file = Path(tempdir) / "ipDataToPointCloud_input.vtu"
parentpath = Path() if mesh.filepath is None else mesh.filepath.parent
input_file = parentpath / "ipDataToPointCloud_input.vtu"
_mesh.save(input_file)
output_file = Path(tempdir) / "ip_mesh.vtu"
output_file = parentpath / "ip_mesh.vtu"
ogs.cli.ipDataToPointCloud(i=str(input_file), o=str(output_file))
return pv.XMLUnstructuredGridReader(output_file).read()

Expand Down
3 changes: 3 additions & 0 deletions ogstools/meshlib/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Mesh(pv.UnstructuredGrid):
Contains additional data and functions mainly for postprocessing.
"""

filepath: Path | None = None

# pylint: disable=C0116
@copy_method_signature(data_processing.difference)
def difference(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -94,6 +96,7 @@ def read(cls, filepath: str | Path) -> Mesh:
else:
mesh = cls(pv.read(filepath))

mesh.filepath = Path(filepath).with_suffix(".vtu")
return mesh

@classmethod
Expand Down
129 changes: 113 additions & 16 deletions ogstools/meshlib/mesh_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from copy import copy, deepcopy
from functools import partial
from pathlib import Path
from typing import Any, Literal, overload
from typing import Any, Literal, cast, overload

import meshio
import numpy as np
import pyvista as pv
from lxml import etree as ET
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.interpolate import (
Expand Down Expand Up @@ -148,8 +149,8 @@ def __len__(self) -> int:
return len(self.timesteps)

def __iter__(self) -> Iterator[Mesh]:
for t in self.timesteps:
yield self.mesh(t)
for i in np.arange(len(self.timevalues), dtype=int):
yield self.mesh(i)

def __str__(self) -> str:
if self._data_type == "vtu":
Expand Down Expand Up @@ -208,14 +209,14 @@ def ip_tesselated(self) -> MeshSeries:
ip_mesh = self.mesh(0).to_ip_mesh()
ip_pt_cloud = self.mesh(0).to_ip_point_cloud()
ordering = ip_mesh.find_containing_cell(ip_pt_cloud.points)
for ts in self.timesteps:
for i in np.arange(len(self.timevalues), dtype=int):
ip_data = {
key: self.mesh(ts).field_data[key][np.argsort(ordering)]
key: self.mesh(i).field_data[key][np.argsort(ordering)]
for key in ip_mesh.cell_data
}
ip_mesh.cell_data.update(ip_data)
ip_ms._mesh_cache[
self.timevalues[ts]
self.timevalues[i]
] = ip_mesh.copy() # pylint: disable=protected-access
ip_ms._timevalues = self._timevalues # pylint: disable=protected-access
return ip_ms
Expand Down Expand Up @@ -245,6 +246,10 @@ def mesh(self, timestep: int, lazy_eval: bool = True) -> Mesh:
mesh = Mesh(self.mesh_func(pv_mesh))
if lazy_eval:
self._mesh_cache[timevalue] = mesh
if self._data_type == "pvd":
mesh.filepath = Path(self.timestep_files[data_timestep])
else:
mesh.filepath = Path(self.filepath)
return mesh

def rawdata_file(self) -> Path | None:
Expand Down Expand Up @@ -290,7 +295,13 @@ def timevalues(self) -> np.ndarray:

@property
def timesteps(self) -> list:
"""Return the timesteps of the timeseries data."""
"""
Return the OGS simulation timesteps of the timeseries data.
Not to be confused with timevalues which returns a list of
times usually given in time units.
"""

# TODO: read time steps from fn string if available
return np.arange(len(self.timevalues), dtype=int)

def _xdmf_values(self, variable_name: str) -> np.ndarray:
Expand Down Expand Up @@ -579,16 +590,15 @@ def plot_probe(
def animate(
self,
variable: Variable,
timesteps: Sequence | None = None,
timevalues: Sequence | None = None,
plot_func: Callable[[plt.Axes, float], None] = lambda *_: None,
**kwargs: Any,
) -> FuncAnimation:
"""
Create an animation for a variable with given timesteps.
Create an animation for a variable with given timevalues.
:param variable: the field to be visualized on all timesteps
:param timesteps: if sequence of int: the timesteps to animate
if sequence of float: the timevalues to animate
:param variable: the field to be visualized on all timevalues
:param timevalues: the timevalues to animate
:param plot_func: A function which expects to read a matplotlib Axes
and the time value of the current frame. Useful to
customize the plot in the animation.
Expand All @@ -598,7 +608,7 @@ def animate(
plot.setup.layout = "tight"
plot.setup.combined_colorbar = True

ts = self.timesteps if timesteps is None else timesteps
ts = self.timevalues if timevalues is None else timevalues

fig = plot.contourf(self.mesh(0, lazy_eval=False), variable)
assert isinstance(fig, plt.Figure)
Expand All @@ -609,14 +619,14 @@ def animate(
def init() -> None:
pass

def animate_func(i: int | float, fig: plt.Figure) -> None:
def animate_func(tv: float, fig: plt.Figure) -> None:
fig.axes[-1].remove() # remove colorbar
for ax in np.ravel(np.asarray(fig.axes)):
ax.clear()
mesh = self[i] if isinstance(i, int) else self.read_interp(i, True)
mesh = self.read_interp(tv, True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plot_func(fig.axes[0], i)
plot_func(fig.axes[0], tv)
plot.contourplots.draw_plot(
mesh, variable, fig=fig, axes=fig.axes[0], **kwargs
) # type: ignore[assignment]
Expand Down Expand Up @@ -834,3 +844,90 @@ def extract(
),
}
return self.transform(func[preference])

def _rename_vtufiles(self, new_pvd_fn: Path, fns: list[Path]) -> list:
fns_new: list[Path] = []
for filename in fns:
filepathparts_at_timestep = list(filename.parts)
filepathparts_at_timestep[-1] = filepathparts_at_timestep[
-1
].replace(
Path(self.filepath).name.split(".")[0],
new_pvd_fn.name.split(".")[0],
)
fns_new.append(Path(*filepathparts_at_timestep))
return fns_new

def _save_vtu(self, new_pvd_fn: Path, fns: list[Path]) -> None:
for i, timestep in enumerate(self.timesteps):
if ".vtu" in fns[i].name:
pv.save_meshio(
Path(new_pvd_fn.parent, fns[i].name), self.mesh(i)
)
elif ".xdmf" in fns[i].name:
newname = fns[i].name.replace(
".xdmf", f"_ts_{timestep}_t_{self.timevalues[i]}.vtu"
)
pv.save_meshio(Path(new_pvd_fn.parent, newname), self.mesh(i))
else:
s = "File type not supported."
raise RuntimeError(s)

def _save_pvd(self, new_pvd_fn: Path, fns: list[Path]) -> None:
root = ET.Element("VTKFile")
root.attrib["type"] = "Collection"
root.attrib["version"] = "0.1"
root.attrib["byte_order"] = "LittleEndian"
root.attrib["compressor"] = "vtkZLibDataCompressor"
collection = ET.SubElement(root, "Collection")
for i, timestep in enumerate(self.timevalues):
timestepselement = ET.SubElement(collection, "DataSet")
timestepselement.attrib["timestep"] = str(timestep)
timestepselement.attrib["group"] = ""
timestepselement.attrib["part"] = "0"
if ".xdmf" in fns[i].name:
newname = fns[i].name.replace(
".xdmf", f"_ts_{self.timesteps[i]}_t_{timestep}.vtu"
)
timestepselement.attrib["file"] = newname
elif ".vtu" in fns[i].name:
timestepselement.attrib["file"] = fns[i].name
else:
s = "File type not supported."
raise RuntimeError(s)
tree = ET.ElementTree(root)
tree.write(
new_pvd_fn,
encoding="ISO-8859-1",
xml_declaration=True,
pretty_print=True,
)

def _check_path(self, filename: Path | None) -> Path:
if not isinstance(filename, Path):
s = "filename is empty"
raise RuntimeError(s)
assert isinstance(filename, Path)
return cast(Path, filename)

def save(self, filename: str, deep: bool = True) -> None:
"""
Save mesh series to disk.
:param filename: Filename to save the series to. Extension specifies
the file type. Currently only PVD is supported.
:param deep: Specifies whether VTU/H5 files should be written.
"""
fn = Path(filename)
fns = [
self._check_path(self.mesh(t).filepath)
for t in np.arange(len(self.timevalues), dtype=int)
]
if ".pvd" in fn.name:
if deep is True:
fns = self._rename_vtufiles(fn, fns)
self._save_vtu(fn, fns)
self._save_pvd(fn, fns)
else:
s = "Currently the save method is implemented for PVD/VTU only."
raise RuntimeError(s)
58 changes: 58 additions & 0 deletions tests/test_meshlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pkg_resources
import pytest
import pyvista as pv
from lxml import etree as ET

import ogstools as ot
from ogstools import examples
Expand Down Expand Up @@ -472,3 +473,60 @@ def test_copy_shallow(self):
ms_shallowcopy = ms.copy(deep=False)
ms.test_var = False
assert not ms_shallowcopy.test_var

def test_save_pvd_mesh_series(self):
temp = Path(mkdtemp())
file_name = "test.pvd"

ms = examples.load_meshseries_HT_2D_PVD()
ms.save(Path(temp, file_name), deep=True)
ms_test = ot.MeshSeries(Path(temp, file_name))
assert len(ms.timevalues) == len(ms_test.timevalues)
assert np.abs(ms.timevalues[1] - ms_test.timevalues[1]) < 1e-14
for var in ["temperature", "darcy_velocity", "pressure"]:
val_ref = np.sum(ms.aggregate_over_domain(var, np.max))
val_test = np.sum(ms_test.aggregate_over_domain(var, np.max))
assert np.abs(val_ref - val_test) < 1e-14

for m in ms_test:
assert "test" in m.filepath.name

ms.save(Path(temp, file_name), deep=False)
tree = ET.parse(Path(temp, file_name))
num_slices = len(ms.timevalues)
num_slices_test = len(tree.findall("./Collection/DataSet"))
assert num_slices == num_slices_test
pvd_entries = tree.findall("./Collection/DataSet")
for i in range(num_slices):
assert ms[i].filepath.name == pvd_entries[i].attrib["file"]
ts = float(pvd_entries[i].attrib["timestep"])
assert np.abs(ms.timevalues[i] - ts) < 1e-14

def test_save_xdmf_mesh_series(self):
temp = Path(mkdtemp())
file_name = "test.pvd"

ms = examples.load_meshseries_CT_2D_XDMF()
ms.save(Path(temp, file_name), deep=True)
ms_test = ot.MeshSeries(Path(temp, file_name))
assert len(ms.timevalues) == len(ms_test.timevalues)
assert np.abs(ms.timevalues[1] - ms_test.timevalues[1]) < 1e-14
assert (
np.abs(
np.sum(ms.aggregate_over_domain("Si", np.max))
- np.sum(ms_test.aggregate_over_domain("Si", np.max))
)
< 1e-14
)
for m in ms_test:
assert "test" in m.filepath.name

ms.save(Path(temp, file_name), deep=False)
tree = ET.parse(Path(temp, file_name))
num_slices = len(ms.timevalues)
pvd_entries = tree.findall("./Collection/DataSet")
num_slices_test = len(pvd_entries)
assert num_slices == num_slices_test
for i in range(num_slices):
ts = float(pvd_entries[i].attrib["timestep"])
assert np.abs(ms.timevalues[i] - ts) < 1e-14

0 comments on commit ba922e8

Please sign in to comment.