1
1
#
2
2
# Class for quick plotting of variables from models
3
3
#
4
+ from __future__ import annotations
4
5
import os
5
6
import numpy as np
6
7
import pybamm
@@ -479,24 +480,24 @@ def reset_axis(self):
479
480
): # pragma: no cover
480
481
raise ValueError (f"Axis limits cannot be NaN for variables '{ key } '" )
481
482
482
- def plot (self , t , dynamic = False ):
483
+ def plot (self , t : float | list [ float ] , dynamic : bool = False ):
483
484
"""Produces a quick plot with the internal states at time t.
484
485
485
486
Parameters
486
487
----------
487
- t : float
488
- Dimensional time (in 'time_units') at which to plot.
488
+ t : float or list of float
489
+ Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
489
490
dynamic : bool, optional
490
491
Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
491
492
If True, creates a dynamic plot with a slider.
492
493
"""
493
494
494
495
plt = import_optional_dependency ("matplotlib.pyplot" )
495
496
gridspec = import_optional_dependency ("matplotlib.gridspec" )
496
- cm = import_optional_dependency ("matplotlib" , "cm" )
497
- colors = import_optional_dependency ("matplotlib" , "colors" )
498
497
499
- t_in_seconds = t * self .time_scaling_factor
498
+ if not isinstance (t , list ):
499
+ t = [t ]
500
+
500
501
self .fig = plt .figure (figsize = self .figsize )
501
502
502
503
self .gridspec = gridspec .GridSpec (self .n_rows , self .n_cols )
@@ -508,6 +509,11 @@ def plot(self, t, dynamic=False):
508
509
# initialize empty handles, to be created only if the appropriate plots are made
509
510
solution_handles = []
510
511
512
+ # Generate distinct colors for each time point
513
+ time_colors = plt .cm .coolwarm (
514
+ np .linspace (0 , 1 , len (t ))
515
+ ) # Use a colormap for distinct colors
516
+
511
517
for k , (key , variable_lists ) in enumerate (self .variables .items ()):
512
518
ax = self .fig .add_subplot (self .gridspec [k ])
513
519
self .axes .add (key , ax )
@@ -518,19 +524,17 @@ def plot(self, t, dynamic=False):
518
524
ax .xaxis .set_major_locator (plt .MaxNLocator (3 ))
519
525
self .plots [key ] = defaultdict (dict )
520
526
variable_handles = []
521
- # Set labels for the first subplot only (avoid repetition)
527
+
522
528
if variable_lists [0 ][0 ].dimensions == 0 :
523
- # 0D plot: plot as a function of time, indicating time t with a line
529
+ # 0D plot: plot as a function of time, indicating multiple times with lines
524
530
ax .set_xlabel (f"Time [{ self .time_unit } ]" )
525
531
for i , variable_list in enumerate (variable_lists ):
526
532
for j , variable in enumerate (variable_list ):
527
- if len (variable_list ) == 1 :
528
- # single variable -> use linestyle to differentiate model
529
- linestyle = self .linestyles [i ]
530
- else :
531
- # multiple variables -> use linestyle to differentiate
532
- # variables (color differentiates models)
533
- linestyle = self .linestyles [j ]
533
+ linestyle = (
534
+ self .linestyles [i ]
535
+ if len (variable_list ) == 1
536
+ else self .linestyles [j ]
537
+ )
534
538
full_t = self .ts_seconds [i ]
535
539
(self .plots [key ][i ][j ],) = ax .plot (
536
540
full_t / self .time_scaling_factor ,
@@ -542,128 +546,104 @@ def plot(self, t, dynamic=False):
542
546
solution_handles .append (self .plots [key ][i ][0 ])
543
547
y_min , y_max = ax .get_ylim ()
544
548
ax .set_ylim (y_min , y_max )
545
- (self .time_lines [key ],) = ax .plot (
546
- [
547
- t_in_seconds / self .time_scaling_factor ,
548
- t_in_seconds / self .time_scaling_factor ,
549
- ],
550
- [y_min , y_max ],
551
- "k--" ,
552
- lw = 1.5 ,
553
- )
549
+
550
+ # Add vertical lines for each time in the list, using different colors for each time
551
+ for idx , t_single in enumerate (t ):
552
+ t_in_seconds = t_single * self .time_scaling_factor
553
+ (self .time_lines [key ],) = ax .plot (
554
+ [
555
+ t_in_seconds / self .time_scaling_factor ,
556
+ t_in_seconds / self .time_scaling_factor ,
557
+ ],
558
+ [y_min , y_max ],
559
+ "--" , # Dashed lines
560
+ lw = 1.5 ,
561
+ color = time_colors [idx ], # Different color for each time
562
+ label = f"t = { t_single :.2f} { self .time_unit } " ,
563
+ )
564
+ ax .legend ()
565
+
554
566
elif variable_lists [0 ][0 ].dimensions == 1 :
555
- # 1D plot: plot as a function of x at time t
556
- # Read dictionary of spatial variables
567
+ # 1D plot: plot as a function of x at different times
557
568
spatial_vars = self .spatial_variable_dict [key ]
558
569
spatial_var_name = next (iter (spatial_vars .keys ()))
559
- ax .set_xlabel (
560
- f"{ spatial_var_name } [{ self .spatial_unit } ]" ,
561
- )
562
- for i , variable_list in enumerate (variable_lists ):
563
- for j , variable in enumerate (variable_list ):
564
- if len (variable_list ) == 1 :
565
- # single variable -> use linestyle to differentiate model
566
- linestyle = self .linestyles [i ]
567
- else :
568
- # multiple variables -> use linestyle to differentiate
569
- # variables (color differentiates models)
570
- linestyle = self .linestyles [j ]
571
- (self .plots [key ][i ][j ],) = ax .plot (
572
- self .first_spatial_variable [key ],
573
- variable (t_in_seconds , ** spatial_vars ),
574
- color = self .colors [i ],
575
- linestyle = linestyle ,
576
- zorder = 10 ,
577
- )
578
- variable_handles .append (self .plots [key ][0 ][j ])
579
- solution_handles .append (self .plots [key ][i ][0 ])
580
- # add lines for boundaries between subdomains
581
- for boundary in variable_lists [0 ][0 ].internal_boundaries :
582
- boundary_scaled = boundary * self .spatial_factor
583
- ax .axvline (boundary_scaled , color = "0.5" , lw = 1 , zorder = 0 )
570
+ ax .set_xlabel (f"{ spatial_var_name } [{ self .spatial_unit } ]" )
571
+
572
+ for idx , t_single in enumerate (t ):
573
+ t_in_seconds = t_single * self .time_scaling_factor
574
+
575
+ for i , variable_list in enumerate (variable_lists ):
576
+ for j , variable in enumerate (variable_list ):
577
+ linestyle = (
578
+ self .linestyles [i ]
579
+ if len (variable_list ) == 1
580
+ else self .linestyles [j ]
581
+ )
582
+ (self .plots [key ][i ][j ],) = ax .plot (
583
+ self .first_spatial_variable [key ],
584
+ variable (t_in_seconds , ** spatial_vars ),
585
+ color = time_colors [idx ], # Different color for each time
586
+ linestyle = linestyle ,
587
+ label = f"t = { t_single :.2f} { self .time_unit } " , # Add time label
588
+ zorder = 10 ,
589
+ )
590
+ variable_handles .append (self .plots [key ][0 ][j ])
591
+ solution_handles .append (self .plots [key ][i ][0 ])
592
+
593
+ # Add a legend to indicate which plot corresponds to which time
594
+ ax .legend ()
595
+
584
596
elif variable_lists [0 ][0 ].dimensions == 2 :
585
- # Read dictionary of spatial variables
597
+ # 2D plot: superimpose plots at different times
586
598
spatial_vars = self .spatial_variable_dict [key ]
587
- # there can only be one entry in the variable list
588
599
variable = variable_lists [0 ][0 ]
589
- # different order based on whether the domains are x-r, x-z or y-z, etc
590
- if self .x_first_and_y_second [key ] is False :
591
- x_name = list (spatial_vars .keys ())[1 ][0 ]
592
- y_name = next (iter (spatial_vars .keys ()))[0 ]
593
- x = self .second_spatial_variable [key ]
594
- y = self .first_spatial_variable [key ]
595
- var = variable (t_in_seconds , ** spatial_vars )
596
- else :
597
- x_name = next (iter (spatial_vars .keys ()))[0 ]
598
- y_name = list (spatial_vars .keys ())[1 ][0 ]
600
+
601
+ for t_single in t :
602
+ t_in_seconds = t_single * self .time_scaling_factor
599
603
x = self .first_spatial_variable [key ]
600
604
y = self .second_spatial_variable [key ]
601
605
var = variable (t_in_seconds , ** spatial_vars ).T
602
- ax .set_xlabel (f"{ x_name } [{ self .spatial_unit } ]" )
603
- ax .set_ylabel (f"{ y_name } [{ self .spatial_unit } ]" )
604
- vmin , vmax = self .variable_limits [key ]
605
- # store the plot and the var data (for testing) as cant access
606
- # z data from QuadMesh or QuadContourSet object
607
- if self .is_y_z [key ] is True :
608
- self .plots [key ][0 ][0 ] = ax .pcolormesh (
609
- x ,
610
- y ,
611
- var ,
612
- vmin = vmin ,
613
- vmax = vmax ,
614
- shading = self .shading ,
606
+
607
+ ax .set_xlabel (
608
+ f"{ next (iter (spatial_vars .keys ()))[0 ]} [{ self .spatial_unit } ]"
615
609
)
616
- else :
617
- self .plots [key ][0 ][0 ] = ax .contourf (
618
- x , y , var , levels = 100 , vmin = vmin , vmax = vmax
610
+ ax .set_ylabel (
611
+ f"{ list (spatial_vars .keys ())[1 ][0 ]} [{ self .spatial_unit } ]"
619
612
)
620
- self .plots [key ][0 ][1 ] = var
621
- if vmin is None and vmax is None :
622
- vmin = ax_min (var )
623
- vmax = ax_max (var )
624
- self .colorbars [key ] = self .fig .colorbar (
625
- cm .ScalarMappable (colors .Normalize (vmin = vmin , vmax = vmax )),
626
- ax = ax ,
627
- )
628
- # Set either y label or legend entries
629
- if len (key ) == 1 :
630
- title = split_long_string (key [0 ])
631
- ax .set_title (title , fontsize = "medium" )
632
- else :
633
- ax .legend (
634
- variable_handles ,
635
- [split_long_string (s , 6 ) for s in key ],
636
- bbox_to_anchor = (0.5 , 1 ),
637
- loc = "lower center" ,
638
- )
613
+ vmin , vmax = self .variable_limits [key ]
614
+
615
+ # Use contourf and colorbars to represent the values
616
+ contour_plot = ax .contourf (
617
+ x , y , var , levels = 100 , vmin = vmin , vmax = vmax , cmap = "coolwarm"
618
+ )
619
+ self .plots [key ][0 ][0 ] = contour_plot
620
+ self .colorbars [key ] = self .fig .colorbar (contour_plot , ax = ax )
639
621
640
- # Set global legend
622
+ self .plots [key ][0 ][1 ] = var
623
+
624
+ ax .set_title (f"t = { t_single :.2f} { self .time_unit } " )
625
+
626
+ # Set global legend if there are multiple models
641
627
if len (self .labels ) > 1 :
642
628
fig_legend = self .fig .legend (
643
629
solution_handles , self .labels , loc = "lower right"
644
630
)
645
- # Get the position of the top of the legend in relative figure units
646
- # There may be a better way ...
647
- try :
648
- legend_top_inches = fig_legend .get_window_extent (
649
- renderer = self .fig .canvas .get_renderer ()
650
- ).get_points ()[1 , 1 ]
651
- fig_height_inches = (self .fig .get_size_inches () * self .fig .dpi )[1 ]
652
- legend_top = legend_top_inches / fig_height_inches
653
- except AttributeError : # pragma: no cover
654
- # When testing the examples we set the matplotlib backend to "Template"
655
- # which means that the above code doesn't work. Since this is just for
656
- # that particular test we can just skip it
657
- legend_top = 0
658
631
else :
659
- legend_top = 0
632
+ fig_legend = None
660
633
661
- # Fix layout
634
+ # Fix layout for sliders if dynamic
662
635
if dynamic :
663
636
slider_top = 0.05
664
637
else :
665
638
slider_top = 0
666
- bottom = max (legend_top , slider_top )
639
+ bottom = max (
640
+ fig_legend .get_window_extent (
641
+ renderer = self .fig .canvas .get_renderer ()
642
+ ).get_points ()[1 , 1 ]
643
+ if fig_legend
644
+ else 0 ,
645
+ slider_top ,
646
+ )
667
647
self .gridspec .tight_layout (self .fig , rect = [0 , bottom , 1 , 1 ])
668
648
669
649
def dynamic_plot (self , show_plot = True , step = None ):
0 commit comments