From d661d515c1685cc5c343002adb5d3fdf49b394a4 Mon Sep 17 00:00:00 2001 From: tmichela Date: Fri, 7 Feb 2025 11:14:26 +0100 Subject: [PATCH 1/2] fix: gui crashes when inspecting complex arrays --- damnit/gui/main_window.py | 22 ++++++++++++++++------ tests/test_gui.py | 11 +++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/damnit/gui/main_window.py b/damnit/gui/main_window.py index a188e50e..9ac98e98 100644 --- a/damnit/gui/main_window.py +++ b/damnit/gui/main_window.py @@ -691,7 +691,12 @@ def inspect_data(self, index): if data.ndim == 1: if isinstance(data, xr.DataArray): - canvas = Xarray1DPlotWindow(self, data, title=title) + try: + canvas = Xarray1DPlotWindow(self, data, title=title) + except Exception as exc: + QMessageBox.warning( + self, f"Can't inspect variable {quantity}", str(exc)) + return else: canvas = ScatterPlotWindow(self, x=[np.arange(len(data))], @@ -701,11 +706,16 @@ def inspect_data(self, index): title=title, ) elif data.ndim == 2 or (data.ndim == 3 and data.shape[-1] in (3, 4)): - canvas = ImagePlotWindow( - self, - image=data, - title=f"{variable.title} (run {run})", - ) + try: + canvas = ImagePlotWindow( + self, + image=data, + title=f"{variable.title} (run {run})", + ) + except Exception as exc: + QMessageBox.warning( + self, f"Can't inspect variable {quantity}", str(exc)) + return elif data.ndim == 0: # If this is a scalar value, then we can't plot it QMessageBox.warning(self, "Can't inspect variable", diff --git a/tests/test_gui.py b/tests/test_gui.py index 4d126cdf..dea1c5e3 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -708,6 +708,11 @@ def color_image(run): @Variable(title='2D data with summary', summary='mean') def mean_2d(run): return np.random.rand(512, 512) + + @Variable(title="1D Complex", summary='max') + def complex_1d(run): + import xarray as xr + return xr.DataArray(np.array([1+1j, 2+2j, 3+3j, 4+4j])) """ ctx_code = mock_ctx.code + "\n\n" + textwrap.dedent(const_array_code) (db_dir / "context.py").write_text(ctx_code) @@ -789,6 +794,12 @@ def get_index(title, row=0): win.inspect_data(mean_2d_index) warning.assert_not_called() + # xarray of complex data is not inspectable + complex_1d = get_index('1D Complex') + with patch.object(QMessageBox, "warning") as warning: + win.inspect_data(complex_1d) + warning.assert_called_once() + def test_open_dialog(mock_db, qtbot): db_dir, db = mock_db From 8ed030ffbdde8ec84206de424dc09ffd6fa1cc67 Mon Sep 17 00:00:00 2001 From: tmichela Date: Fri, 7 Feb 2025 11:54:19 +0100 Subject: [PATCH 2/2] test coverage test coverage --- tests/test_gui.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_gui.py b/tests/test_gui.py index dea1c5e3..15d0060d 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -709,8 +709,15 @@ def color_image(run): def mean_2d(run): return np.random.rand(512, 512) - @Variable(title="1D Complex", summary='max') - def complex_1d(run): + @Variable(title="2D Complex", summary='max') + def complex_2d(run): + return np.array([ + [1+1j, 2+2j, 3+3j, 4+4j], + [1+1j, 2+2j, 3+3j, 4+4j], + [1+1j, 2+2j, 3+3j, 4+4j]]) + + @Variable(title="1D Xarray Complex", summary='max') + def complex_xr_1d(run): import xarray as xr return xr.DataArray(np.array([1+1j, 2+2j, 3+3j, 4+4j])) """ @@ -795,9 +802,14 @@ def get_index(title, row=0): warning.assert_not_called() # xarray of complex data is not inspectable - complex_1d = get_index('1D Complex') + complex_2d = get_index('2D Complex') with patch.object(QMessageBox, "warning") as warning: - win.inspect_data(complex_1d) + win.inspect_data(complex_2d) + warning.assert_called_once() + + complex_xr_1d = get_index('1D Xarray Complex') + with patch.object(QMessageBox, "warning") as warning: + win.inspect_data(complex_xr_1d) warning.assert_called_once()