Skip to content

Commit 9d59593

Browse files
authored
Merge pull request #4603 from kratman/rel/v24.11.1
Release v24.11.1
2 parents 5b1ef70 + 6ffdf1d commit 9d59593

File tree

8 files changed

+150
-121
lines changed

8 files changed

+150
-121
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
22

3+
# [v24.11.1](https://github.com/pybamm-team/PyBaMM/tree/v24.11.1) - 2024-11-22
4+
5+
## Features
6+
7+
- Modified `quick_plot.plot` to accept a list of times and generate superimposed graphs for specified time points. ([#4529](https://github.com/pybamm-team/PyBaMM/pull/4529))
8+
9+
## Bug Fixes
10+
11+
- Added some dependencies which were left out of the `pyproject.toml` file ([#4602](https://github.com/pybamm-team/PyBaMM/pull/4602))
12+
313
# [v24.11.0](https://github.com/pybamm-team/PyBaMM/tree/v24.11.0) - 2024-11-20
414

515
## Features

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ keywords:
2424
- "expression tree"
2525
- "python"
2626
- "symbolic differentiation"
27-
version: "24.11.0"
27+
version: "24.11.1"
2828
repository-code: "https://github.com/pybamm-team/PyBaMM"
2929
title: "Python Battery Mathematical Modelling (PyBaMM)"

noxfile.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,11 @@ def run_tests(session):
252252
set_environment_variables(PYBAMM_ENV, session=session)
253253
session.install("setuptools", silent=False)
254254
session.install("-e", ".[all,dev,jax]", silent=False)
255-
specific_test_files = session.posargs if session.posargs else []
256255
session.run(
257-
"python", "-m", "pytest", *specific_test_files, "-m", "unit or integration"
256+
"python",
257+
"-m",
258+
"pytest",
259+
*(session.posargs if session.posargs else ["-m", "unit or integration"]),
258260
)
259261

260262

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
1313

1414
[project]
1515
name = "pybamm"
16-
version = "24.11.0"
16+
version = "24.11.1"
1717
license = { file = "LICENSE.txt" }
1818
description = "Python Battery Mathematical Modelling"
1919
authors = [{name = "The PyBaMM Team", email = "[email protected]"}]
@@ -45,6 +45,8 @@ dependencies = [
4545
"pandas>=1.5.0",
4646
"pooch>=1.8.1",
4747
"posthog",
48+
"pyyaml",
49+
"platformdirs",
4850
]
4951

5052
[project.urls]

src/pybamm/plotting/quick_plot.py

Lines changed: 95 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Class for quick plotting of variables from models
33
#
4+
from __future__ import annotations
45
import os
56
import numpy as np
67
import pybamm
@@ -479,24 +480,24 @@ def reset_axis(self):
479480
): # pragma: no cover
480481
raise ValueError(f"Axis limits cannot be NaN for variables '{key}'")
481482

482-
def plot(self, t, dynamic=False):
483+
def plot(self, t: float | list[float], dynamic: bool = False):
483484
"""Produces a quick plot with the internal states at time t.
484485
485486
Parameters
486487
----------
487-
t : float
488-
Dimensional time (in 'time_units') at which to plot.
488+
t : float or list of float
489+
Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
489490
dynamic : bool, optional
490491
Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
491492
If True, creates a dynamic plot with a slider.
492493
"""
493494

494495
plt = import_optional_dependency("matplotlib.pyplot")
495496
gridspec = import_optional_dependency("matplotlib.gridspec")
496-
cm = import_optional_dependency("matplotlib", "cm")
497-
colors = import_optional_dependency("matplotlib", "colors")
498497

499-
t_in_seconds = t * self.time_scaling_factor
498+
if not isinstance(t, list):
499+
t = [t]
500+
500501
self.fig = plt.figure(figsize=self.figsize)
501502

502503
self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols)
@@ -508,6 +509,11 @@ def plot(self, t, dynamic=False):
508509
# initialize empty handles, to be created only if the appropriate plots are made
509510
solution_handles = []
510511

512+
# Generate distinct colors for each time point
513+
time_colors = plt.cm.coolwarm(
514+
np.linspace(0, 1, len(t))
515+
) # Use a colormap for distinct colors
516+
511517
for k, (key, variable_lists) in enumerate(self.variables.items()):
512518
ax = self.fig.add_subplot(self.gridspec[k])
513519
self.axes.add(key, ax)
@@ -518,19 +524,17 @@ def plot(self, t, dynamic=False):
518524
ax.xaxis.set_major_locator(plt.MaxNLocator(3))
519525
self.plots[key] = defaultdict(dict)
520526
variable_handles = []
521-
# Set labels for the first subplot only (avoid repetition)
527+
522528
if variable_lists[0][0].dimensions == 0:
523-
# 0D plot: plot as a function of time, indicating time t with a line
529+
# 0D plot: plot as a function of time, indicating multiple times with lines
524530
ax.set_xlabel(f"Time [{self.time_unit}]")
525531
for i, variable_list in enumerate(variable_lists):
526532
for j, variable in enumerate(variable_list):
527-
if len(variable_list) == 1:
528-
# single variable -> use linestyle to differentiate model
529-
linestyle = self.linestyles[i]
530-
else:
531-
# multiple variables -> use linestyle to differentiate
532-
# variables (color differentiates models)
533-
linestyle = self.linestyles[j]
533+
linestyle = (
534+
self.linestyles[i]
535+
if len(variable_list) == 1
536+
else self.linestyles[j]
537+
)
534538
full_t = self.ts_seconds[i]
535539
(self.plots[key][i][j],) = ax.plot(
536540
full_t / self.time_scaling_factor,
@@ -542,128 +546,104 @@ def plot(self, t, dynamic=False):
542546
solution_handles.append(self.plots[key][i][0])
543547
y_min, y_max = ax.get_ylim()
544548
ax.set_ylim(y_min, y_max)
545-
(self.time_lines[key],) = ax.plot(
546-
[
547-
t_in_seconds / self.time_scaling_factor,
548-
t_in_seconds / self.time_scaling_factor,
549-
],
550-
[y_min, y_max],
551-
"k--",
552-
lw=1.5,
553-
)
549+
550+
# Add vertical lines for each time in the list, using different colors for each time
551+
for idx, t_single in enumerate(t):
552+
t_in_seconds = t_single * self.time_scaling_factor
553+
(self.time_lines[key],) = ax.plot(
554+
[
555+
t_in_seconds / self.time_scaling_factor,
556+
t_in_seconds / self.time_scaling_factor,
557+
],
558+
[y_min, y_max],
559+
"--", # Dashed lines
560+
lw=1.5,
561+
color=time_colors[idx], # Different color for each time
562+
label=f"t = {t_single:.2f} {self.time_unit}",
563+
)
564+
ax.legend()
565+
554566
elif variable_lists[0][0].dimensions == 1:
555-
# 1D plot: plot as a function of x at time t
556-
# Read dictionary of spatial variables
567+
# 1D plot: plot as a function of x at different times
557568
spatial_vars = self.spatial_variable_dict[key]
558569
spatial_var_name = next(iter(spatial_vars.keys()))
559-
ax.set_xlabel(
560-
f"{spatial_var_name} [{self.spatial_unit}]",
561-
)
562-
for i, variable_list in enumerate(variable_lists):
563-
for j, variable in enumerate(variable_list):
564-
if len(variable_list) == 1:
565-
# single variable -> use linestyle to differentiate model
566-
linestyle = self.linestyles[i]
567-
else:
568-
# multiple variables -> use linestyle to differentiate
569-
# variables (color differentiates models)
570-
linestyle = self.linestyles[j]
571-
(self.plots[key][i][j],) = ax.plot(
572-
self.first_spatial_variable[key],
573-
variable(t_in_seconds, **spatial_vars),
574-
color=self.colors[i],
575-
linestyle=linestyle,
576-
zorder=10,
577-
)
578-
variable_handles.append(self.plots[key][0][j])
579-
solution_handles.append(self.plots[key][i][0])
580-
# add lines for boundaries between subdomains
581-
for boundary in variable_lists[0][0].internal_boundaries:
582-
boundary_scaled = boundary * self.spatial_factor
583-
ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0)
570+
ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]")
571+
572+
for idx, t_single in enumerate(t):
573+
t_in_seconds = t_single * self.time_scaling_factor
574+
575+
for i, variable_list in enumerate(variable_lists):
576+
for j, variable in enumerate(variable_list):
577+
linestyle = (
578+
self.linestyles[i]
579+
if len(variable_list) == 1
580+
else self.linestyles[j]
581+
)
582+
(self.plots[key][i][j],) = ax.plot(
583+
self.first_spatial_variable[key],
584+
variable(t_in_seconds, **spatial_vars),
585+
color=time_colors[idx], # Different color for each time
586+
linestyle=linestyle,
587+
label=f"t = {t_single:.2f} {self.time_unit}", # Add time label
588+
zorder=10,
589+
)
590+
variable_handles.append(self.plots[key][0][j])
591+
solution_handles.append(self.plots[key][i][0])
592+
593+
# Add a legend to indicate which plot corresponds to which time
594+
ax.legend()
595+
584596
elif variable_lists[0][0].dimensions == 2:
585-
# Read dictionary of spatial variables
597+
# 2D plot: superimpose plots at different times
586598
spatial_vars = self.spatial_variable_dict[key]
587-
# there can only be one entry in the variable list
588599
variable = variable_lists[0][0]
589-
# different order based on whether the domains are x-r, x-z or y-z, etc
590-
if self.x_first_and_y_second[key] is False:
591-
x_name = list(spatial_vars.keys())[1][0]
592-
y_name = next(iter(spatial_vars.keys()))[0]
593-
x = self.second_spatial_variable[key]
594-
y = self.first_spatial_variable[key]
595-
var = variable(t_in_seconds, **spatial_vars)
596-
else:
597-
x_name = next(iter(spatial_vars.keys()))[0]
598-
y_name = list(spatial_vars.keys())[1][0]
600+
601+
for t_single in t:
602+
t_in_seconds = t_single * self.time_scaling_factor
599603
x = self.first_spatial_variable[key]
600604
y = self.second_spatial_variable[key]
601605
var = variable(t_in_seconds, **spatial_vars).T
602-
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
603-
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
604-
vmin, vmax = self.variable_limits[key]
605-
# store the plot and the var data (for testing) as cant access
606-
# z data from QuadMesh or QuadContourSet object
607-
if self.is_y_z[key] is True:
608-
self.plots[key][0][0] = ax.pcolormesh(
609-
x,
610-
y,
611-
var,
612-
vmin=vmin,
613-
vmax=vmax,
614-
shading=self.shading,
606+
607+
ax.set_xlabel(
608+
f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]"
615609
)
616-
else:
617-
self.plots[key][0][0] = ax.contourf(
618-
x, y, var, levels=100, vmin=vmin, vmax=vmax
610+
ax.set_ylabel(
611+
f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]"
619612
)
620-
self.plots[key][0][1] = var
621-
if vmin is None and vmax is None:
622-
vmin = ax_min(var)
623-
vmax = ax_max(var)
624-
self.colorbars[key] = self.fig.colorbar(
625-
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)),
626-
ax=ax,
627-
)
628-
# Set either y label or legend entries
629-
if len(key) == 1:
630-
title = split_long_string(key[0])
631-
ax.set_title(title, fontsize="medium")
632-
else:
633-
ax.legend(
634-
variable_handles,
635-
[split_long_string(s, 6) for s in key],
636-
bbox_to_anchor=(0.5, 1),
637-
loc="lower center",
638-
)
613+
vmin, vmax = self.variable_limits[key]
614+
615+
# Use contourf and colorbars to represent the values
616+
contour_plot = ax.contourf(
617+
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
618+
)
619+
self.plots[key][0][0] = contour_plot
620+
self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax)
639621

640-
# Set global legend
622+
self.plots[key][0][1] = var
623+
624+
ax.set_title(f"t = {t_single:.2f} {self.time_unit}")
625+
626+
# Set global legend if there are multiple models
641627
if len(self.labels) > 1:
642628
fig_legend = self.fig.legend(
643629
solution_handles, self.labels, loc="lower right"
644630
)
645-
# Get the position of the top of the legend in relative figure units
646-
# There may be a better way ...
647-
try:
648-
legend_top_inches = fig_legend.get_window_extent(
649-
renderer=self.fig.canvas.get_renderer()
650-
).get_points()[1, 1]
651-
fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1]
652-
legend_top = legend_top_inches / fig_height_inches
653-
except AttributeError: # pragma: no cover
654-
# When testing the examples we set the matplotlib backend to "Template"
655-
# which means that the above code doesn't work. Since this is just for
656-
# that particular test we can just skip it
657-
legend_top = 0
658631
else:
659-
legend_top = 0
632+
fig_legend = None
660633

661-
# Fix layout
634+
# Fix layout for sliders if dynamic
662635
if dynamic:
663636
slider_top = 0.05
664637
else:
665638
slider_top = 0
666-
bottom = max(legend_top, slider_top)
639+
bottom = max(
640+
fig_legend.get_window_extent(
641+
renderer=self.fig.canvas.get_renderer()
642+
).get_points()[1, 1]
643+
if fig_legend
644+
else 0,
645+
slider_top,
646+
)
667647
self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1])
668648

669649
def dynamic_plot(self, show_plot=True, step=None):

src/pybamm/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "24.11.0"
1+
__version__ = "24.11.1"

tests/unit/test_plotting/test_quick_plot.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,41 @@ def test_simple_ode_model(self, solver):
252252
solution, ["a", "b broadcasted"], variable_limits="bad variable limits"
253253
)
254254

255+
# Test with a list of times
256+
# Test for a 0D variable
257+
quick_plot = pybamm.QuickPlot(solution, ["a"])
258+
quick_plot.plot(t=[0, 1, 2])
259+
assert len(quick_plot.plots[("a",)]) == 1
260+
261+
# Test for a 1D variable
262+
quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"])
263+
264+
time_list = [0.5, 1.5, 2.5]
265+
quick_plot.plot(time_list)
266+
267+
ax = quick_plot.fig.axes[0]
268+
lines = ax.get_lines()
269+
270+
variable_key = ("c broadcasted",)
271+
if variable_key in quick_plot.variables:
272+
num_variables = len(quick_plot.variables[variable_key][0])
273+
else:
274+
raise KeyError(
275+
f"'{variable_key}' is not in quick_plot.variables. Available keys: "
276+
+ str(quick_plot.variables.keys())
277+
)
278+
279+
expected_lines = len(time_list) * num_variables
280+
281+
assert (
282+
len(lines) == expected_lines
283+
), f"Expected {expected_lines} superimposed lines, but got {len(lines)}"
284+
285+
# Test for a 2D variable
286+
quick_plot = pybamm.QuickPlot(solution, ["2D variable"])
287+
quick_plot.plot(t=[0, 1, 2])
288+
assert len(quick_plot.plots[("2D variable",)]) == 1
289+
255290
# Test errors
256291
with pytest.raises(ValueError, match="Mismatching variable domains"):
257292
pybamm.QuickPlot(solution, [["a", "b broadcasted"]])

vcpkg.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "pybamm",
3-
"version-string": "24.11.0",
3+
"version-string": "24.11.1",
44
"dependencies": [
55
"casadi",
66
{

0 commit comments

Comments
 (0)