Skip to content

Commit

Permalink
Merge branch 'lineplot_API' into 'main'
Browse files Browse the repository at this point in the history
[plot] lineplot API same as matplotlib.pyplot.plot

See merge request ogs/tools/ogstools!241
  • Loading branch information
TobiasMeisel committed Jan 23, 2025
2 parents 1ccdf3c + f2dbec1 commit 573a9e0
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/examples/howto_plot/plot_timeslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
color = str(0.8 * timevalue / mesh_series.timevalues[-1])
label = f"{timevalue:.1f} a"
fig = ot.plot.line(
sample, "z", si, ax=ax, label=label, color=color, fontsize=20
sample, si, "z", ax=ax, label=label, color=color, fontsize=20
)
# %% [markdown]
# As the above kind of plot is getting cluttered for lots of timesteps we
Expand Down
24 changes: 11 additions & 13 deletions docs/examples/howto_postprocessing/plot_sample_mesh_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,18 @@
# %% [markdown]
# Simple case: straight line
# ==========================
# We use the ``pyvista`` function ``sample_over_line`` and use two points to define
# the line and get a Mesh with the sampled data. Let's plot the Mesh and the
# line together.
# We use the ``pyvista`` function ``sample_over_line`` and use two points to
# define the line and get a Mesh with the sampled data. Let's plot the Mesh and
# the line together.

# %%
sample = mesh.sample_over_line([25, -460, 0], [100, -800, 0])
fig = ot.plot.contourf(mesh, ot.variables.temperature)
fig = ot.plot.line(
mesh=sample, y_var="y", x_var="x", ax=fig.axes[0], linestyle="--"
)
fig = mesh.plot_contourf(ot.variables.temperature)
fig = ot.plot.line(sample, ax=fig.axes[0], linestyle="--")

# %% [markdown]
# Now we plot the temperature data. The spatial coordinate for the x-axis is
# automatically detected here by not passing ``x_var`` explicitly.
# automatically detected here.

# %%
fig = ot.plot.line(sample, ot.variables.temperature)
Expand All @@ -67,8 +65,8 @@
center=[150, -650, 0],
)
fig, axs = plt.subplots(ncols=2, figsize=[26, 10])
ot.plot.contourf(mesh, ot.variables.displacement["x"], fig=fig, ax=axs[1])
ot.plot.line(sample, "y", "x", axs[1], linewidth="8", color="red")
mesh.plot_contourf(ot.variables.displacement["x"], fig=fig, ax=axs[1])
ot.plot.line(sample, ax=axs[1], linewidth="8", color="red")
ot.plot.line(sample, ot.variables.displacement["x"], ax=axs[0])
fig.tight_layout()

Expand Down Expand Up @@ -132,11 +130,11 @@
# %%
fig, axs = plt.subplots(ncols=2, figsize=[26, 10])
u_x = ot.variables.displacement["x"]
ot.plot.contourf(mesh, u_x, fig=fig, ax=axs[1])
mesh.plot_contourf(u_x, fig=fig, ax=axs[1])
for i, sample in enumerate([sample_1, sample_2, sample_3, sample_4]):
c = f"C{i}" # cycle through default color cycle
ot.plot.line(sample, "y", "x", ax=axs[1], linestyle="--", color=c)
ot.plot.line(sample, "y", u_x, ax=axs[0], label=f"sample {i + 1}", color=c)
ot.plot.line(sample, ax=axs[1], linestyle="--", color=c)
ot.plot.line(sample, u_x, "y", ax=axs[0], label=f"sample {i + 1}", color=c)
fig.tight_layout()

# %% [markdown]
Expand Down
60 changes: 41 additions & 19 deletions ogstools/plot/lineplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,71 @@


def line(
mesh: pv.UnstructuredGrid,
y_var: str | Variable,
x_var: str | Variable | None = None,
mesh: pv.DataSet,
var1: str | Variable | None = None,
var2: str | Variable | None = None,
ax: plt.Axes | None = None,
**kwargs: Any,
) -> plt.Figure | None:
"""Plot some data of a (1D) mesh.
You can pass "x", "y" or "z" to either of x_var or y_var to specify which
spatial dimension should be used for the corresponding axis. You can also
pass two data variables for a phase plot.
pass two data variables for a phase plot. if no value is given, automatic
detection of spatial axis is tried.
>>> line(mesh) # z=const: y over x, y=const: z over x, x=const: z over y
>>> line(mesh, ot.variables.temperature) # temperature over x, y or z
>>> line(mesh, "y", "temperature") # temperature over y
>>> line(mesh, ot.variables.pressure, "y") # y over pressure
>>> line(mesh, "pressure", "temperature") # temperature over pressure
:param mesh: The mesh which contains the data to plot
:param y_var: The variable to use for the y-axis
:param x_var: The variable to use for the x-axis, if None automatic
detection of spatial axis is tried.
:param ax: The matplotlib axis to use for plotting, if None creates a
new figure.
:param var1: Variable for the x-axis if var2 is given else for y-axis.
:param var2: Variable for the y-axis if var1 is given.
:param ax: The matplotlib axis to use for plotting, if None a new
figure will be created.
:Keyword Arguments:
- figsize: figure size
- figsize: figure size (default=[16, 10])
- color: color of the line
- linewidth: width of the line
- linestyle: style of the line
- label: label in the legend
- grid: if True, show grid
- all other kwargs get passed to matplotlib's plot function
"""
if isinstance(var1, plt.Axes) or isinstance(var2, plt.Axes):
msg = "Please provide ax as keyword argument only!"
raise TypeError(msg)

figsize = kwargs.pop("figsize", [16, 10])
ax_ = plt.subplots(figsize=figsize)[1] if ax is None else ax

if x_var is None:
non_flat_axis = np.argwhere(
np.invert(np.all(np.isclose(mesh.points, mesh.points[0]), axis=0))
).ravel()
x_var = "xyz"[non_flat_axis[0]]

x_var = get_preset(x_var, mesh).magnitude
y_var = get_preset(y_var, mesh).magnitude
axes_idx = np.argwhere(
np.invert(np.all(np.isclose(mesh.points, mesh.points[0]), axis=0))
).ravel()
if len(axes_idx) == 0:
axes_idx = [0, 1]

match var1, var2:
case None, None:
if len(axes_idx) == 1:
axes_idx = [0, axes_idx[0] if axes_idx[0] != 0 else 1]
x_var = get_preset("xyz"[axes_idx[0]], mesh).magnitude
y_var = get_preset("xyz"[axes_idx[1]], mesh).magnitude
case var1, None:
x_var = get_preset("xyz"[axes_idx[0]], mesh).magnitude
y_var = get_preset(var1, mesh).magnitude # type: ignore[arg-type]
case None, var2:
x_var = get_preset("xyz"[axes_idx[0]], mesh).magnitude
y_var = get_preset(var2, mesh).magnitude # type: ignore[arg-type]
case var1, var2:
x_var = get_preset(var1, mesh).magnitude # type: ignore[arg-type]
y_var = get_preset(var2, mesh).magnitude # type: ignore[arg-type]

kwargs.setdefault("color", y_var.color)
pure_spatial = y_var.data_name in "xyz" and x_var.data_name in "xyz"
lw_scale = 5 if pure_spatial else 3
lw_scale = 4 if pure_spatial else 2.5
kwargs.setdefault("linewidth", setup.linewidth * lw_scale)
fontsize = kwargs.pop("fontsize", setup.fontsize)
show_grid = kwargs.pop("grid", True) and not pure_spatial
Expand Down
42 changes: 36 additions & 6 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,42 @@ def test_xdmf_with_slices(self):

def test_lineplot(self):
"""Test creation of a linesplot from sampled profile data"""
mesh = examples.load_meshseries_HT_2D_XDMF().mesh(-1)
profile_HT = mesh.sample_over_line([4, 2, 0], [4, 18, 0])
ot.plot.setup.set_units(spatial="km", time="a")
fig = ot.plot.line(profile_HT, "pressure")
ot.plot.line(profile_HT, ot.variables.pressure, "x", ax=fig.axes[0])
fig = ot.plot.line(
profile_HT, "y", "x", figsize=[5, 5], color="g", linewidth=1,
mesh = examples.load_meshseries_THM_2D_PVD().mesh(-1)
mesh.points[:, 2] = 0.0
x1, x2, y1, y2 = mesh.bounds[:4]
xc, yc, z = mesh.center
sample_x = mesh.sample_over_line([x1, yc, z], [x2, yc, z])
sample_y = mesh.sample_over_line([xc, y1, z], [xc, y2, z])
sample_xy = mesh.sample_over_line([x1, y1, z], [x2, y2, z])
sample_xz = mesh.rotate_x(90).sample_over_line([x1, 0, y1], [x2, 0, y2])
sample_yz = mesh.rotate_y(90).sample_over_line([0, y1, x1], [0, y2, x2])

def check(*args, x_l: str, y_l: str) -> None:
fig = ot.plot.line(*args, figsize=[4, 3])
assert fig.axes[0].get_xlabel().split(" ")[0] == x_l
assert fig.axes[0].get_ylabel().split(" ")[0] == y_l

check(sample_x, ot.variables.temperature, x_l="x", y_l="temperature")
check(sample_x, x_l="x", y_l="y")
check(sample_y, ot.variables.temperature, x_l="y", y_l="temperature")
check(sample_y, x_l="x", y_l="y")
check(sample_xy, ot.variables.temperature, x_l="x", y_l="temperature")
check(sample_xy, x_l="x", y_l="y")
check(sample_xz, ot.variables.temperature, x_l="x", y_l="temperature")
check(sample_xz, x_l="x", y_l="z")
check(sample_yz, ot.variables.temperature, x_l="y", y_l="temperature")
check(sample_yz, x_l="y", y_l="z")
check(sample_yz, "z", "y", x_l="z", y_l="y")
check(sample_x, "x", "temperature", x_l="x", y_l="temperature")
check(sample_y, "temperature", "y", x_l="temperature", y_l="y")
check(sample_xy, ot.variables.displacement, ot.variables.temperature,
x_l="displacement", y_l="temperature") # fmt: skip
_, ax = plt.subplots(figsize=[4, 3])
ot.plot.line(sample_y, ot.variables.pressure, "x", ax=ax)
_ = ot.plot.line(
sample_y, "y", "x", figsize=[5, 5], color="g", linewidth=1,
ls="--", label="test", grid=True,
) # fmt: skip
with pytest.raises(TypeError):
ot.plot.line(sample_y, ax)

0 comments on commit 573a9e0

Please sign in to comment.