Skip to content

Commit

Permalink
handle threading properly in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pcolange committed Feb 24, 2025
1 parent b8a86bf commit 6d1d471
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 47 deletions.
7 changes: 6 additions & 1 deletion src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
QMenu,
)
from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QIcon, QFont
from PySide6.QtCore import Qt, QSize, QThreadPool
from PySide6.QtCore import Qt, QSize, QThreadPool, Signal

from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
from digest.similarity_analysis import SimilarityWorker, post_process
Expand Down Expand Up @@ -160,6 +160,9 @@ class ModelLoadError(Exception):


class DigestApp(QMainWindow):
"""Main application window for Digest."""

model_loaded = Signal() # Used for tests

class Page(IntEnum):
SPLASH = 0
Expand Down Expand Up @@ -389,6 +392,8 @@ def update_similarity_widget(
):
self.ui.saveBtn.setEnabled(True)

self.model_loaded.emit() # Used for tests

def load_model(self, file_path: str) -> None:
try:
file_path = os.path.normpath(file_path)
Expand Down
3 changes: 0 additions & 3 deletions src/digest/modelsummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@

class modelSummary(QWidget):

# def __init__(
# self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
# ):
def __init__(self, digest_model: DigestModel, parent=None):
super().__init__(parent)
self.ui = Ui_modelSummary()
Expand Down
124 changes: 81 additions & 43 deletions test/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# pylint: disable=no-name-in-module
from PySide6.QtTest import QTest
from PySide6.QtCore import Qt
from PySide6.QtCore import Qt, QEventLoop, QTimer, Signal
from PySide6.QtWidgets import QApplication

import digest.main
Expand All @@ -29,43 +29,67 @@ class DigestGuiTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.app = QApplication(sys.argv)
cls.app.setQuitOnLastWindowClosed(True) # Ensure proper cleanup
return super().setUpClass()

@classmethod
def tearDownClass(cls):
if isinstance(cls.app, QApplication):
cls.app.processEvents()
cls.app.closeAllWindows()
cls.app = None
cls.app.quit() # Explicitly quit the application
cls.app = None
return super().tearDownClass()

def setUp(self):
self.digest_app = digest.main.DigestApp()
if self.digest_app is None:
self.fail("Failed to initialize DigestApp")
self.digest_app.show()
self.initial_tab_count = self.digest_app.ui.tabWidget.count()
self.addCleanup(self.digest_app.close)
QApplication.processEvents() # Process initial events
self.addCleanup(self._cleanup)

def _cleanup(self):
"""Ensure proper cleanup of Qt resources"""
if self.digest_app:
# Close all windows first
for window in QApplication.topLevelWindows():
window.close()
QApplication.processEvents()

# Wait for any pending threads with a shorter timeout
if hasattr(self.digest_app, "thread_pool"):
self.digest_app.thread_pool.clear() # Cancel any pending tasks
self.digest_app.thread_pool.waitForDone(2000) # 2 second timeout

QApplication.processEvents()
self.digest_app.close()
self.digest_app = None
QApplication.processEvents()

def tearDown(self):
self.digest_app.close()
QApplication.processEvents() # Ensure all pending events are processed
self.digest_app = None
super().tearDown()
def _wait_for_signal(self, signal: Signal, timeout_ms=THREAD_TIMEOUT):
"""Wait for a signal to be emitted, with a timeout."""
loop = QEventLoop()
signal_emitted = []

def wait_all_threads(self, timeout_ms=None) -> bool:
"""Wait for all tasks in the thread pool to complete."""
timeout_ms = timeout_ms or self.THREAD_TIMEOUT
QApplication.processEvents() # Ensure pending events are processed
return self.digest_app.thread_pool.waitForDone(timeout_ms)
def on_signal_emitted():
signal_emitted.append(True)
loop.quit()

signal.connect(on_signal_emitted)
QTimer.singleShot(timeout_ms, loop.quit)
loop.exec()

return bool(signal_emitted)

def _mock_file_open(self, mock_dialog, filepath):
"""Helper to mock file open dialog"""
"""Helper to mock file open dialog and wait for load completion."""
mock_dialog.return_value = (filepath, "")
QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton)
self.assertTrue(self.wait_all_threads())

def _verify_tab_added(self):
"""Verify that exactly one new tab was added"""
# Process events to ensure UI updates
QApplication.processEvents()
self.assertEqual(
self.digest_app.ui.tabWidget.count(),
Expand All @@ -79,29 +103,43 @@ def _close_current_tab(self):
self.digest_app.closeTab(current_tab)

def test_open_valid_onnx(self):
"""Test that opening a valid ONNX file creates a new tab in the UI."""
"""Test opening a valid ONNX file."""
with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
self._mock_file_open(mock_dialog, self.ONNX_FILEPATH)
self._verify_tab_added()
# Wait for the signal *after* clicking the button
self.assertTrue(
self._wait_for_signal(self.digest_app.model_loaded),
"Model load did not complete within timeout (test_open_valid_onnx)",
)
self._verify_tab_added() # Verify tab *after* successful load
self._close_current_tab()

def test_open_valid_yaml(self):
"""Test that opening a valid YAML report file creates a new tab in the UI."""
"""Test opening a valid YAML report file."""
with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
self._mock_file_open(mock_dialog, self.YAML_FILEPATH)
self.assertTrue(
self._wait_for_signal(self.digest_app.model_loaded),
"Model load did not complete within timeout (test_open_valid_yaml)",
)
self._verify_tab_added()
self._close_current_tab()

def test_open_invalid_file(self):
"""Test opening an invalid file (no tab should be added)."""
with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
self._mock_file_open(mock_dialog, "invalid_file.txt")
mock_dialog.return_value = ("invalid_file.txt", "")
QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton)
# No need to wait for a signal here, as it won't be emitted
QApplication.processEvents() # Process events to update UI
self.assertEqual(
self.digest_app.ui.tabWidget.count(),
self.initial_tab_count,
"No new tab should be added for invalid file",
)

def test_save_reports(self):
"""Test saving reports after loading a model."""
with patch(
"PySide6.QtWidgets.QFileDialog.getOpenFileName"
) as mock_open_dialog, patch(
Expand All @@ -111,27 +149,23 @@ def test_save_reports(self):
with tempfile.TemporaryDirectory() as tmpdirname:
mock_save_dialog.return_value = tmpdirname
self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH)

# Process any pending events and wait for threads
QApplication.processEvents()

# Wait for the signal
self.assertTrue(
self.wait_all_threads(),
"Background tasks did not complete within the specified timeout",
self._wait_for_signal(self.digest_app.model_loaded),
"Model load did not complete within timeout (save reports)",
)

# Process events again after threads complete
QApplication.processEvents()

# Add debug information
self._print_debug_info()
QApplication.processEvents() # Ensure UI is updated

self.assertTrue(
self.digest_app.ui.saveBtn.isEnabled(),
"Save button should be enabled after loading file",
)
QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton)
QApplication.processEvents()

def test_save_tables(self):
"""Test saving node tables."""
with patch(
"PySide6.QtWidgets.QFileDialog.getOpenFileName"
) as mock_open_dialog, patch(
Expand All @@ -146,11 +180,12 @@ def test_save_tables(self):

self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH)

# Process events and wait for threads before accessing nodes window
QApplication.processEvents()
# Wait for the signal
self.assertTrue(
self.wait_all_threads(), "Threads did not complete in time"
self._wait_for_signal(self.digest_app.model_loaded),
"Model load did not complete within timeout (save tables)",
)
QApplication.processEvents()

self._save_nodes_list(output_file)
self._close_current_tab()
Expand All @@ -160,16 +195,19 @@ def _save_nodes_list(self, expected_output):
QTest.mouseClick(self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton)

# Get the node window and verify it
_, node_window = self.digest_app.nodes_window.popitem()
node_summary = node_window.main_window.centralWidget()
self.assertIsInstance(node_summary, NodeSummary)
if self.digest_app.nodes_window:
_, node_window = self.digest_app.nodes_window.popitem()
node_summary = node_window.main_window.centralWidget()
self.assertIsInstance(node_summary, NodeSummary)

if isinstance(node_summary, NodeSummary):
QTest.mouseClick(node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton)
self.assertTrue(
os.path.exists(expected_output),
f"Nodes csv file not found at {expected_output}",
)
if isinstance(node_summary, NodeSummary):
QTest.mouseClick(node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton)
self.assertTrue(
os.path.exists(expected_output),
f"Nodes csv file not found at {expected_output}",
)
else:
self.fail("Node summary window did not appear within timeout.")

def _print_debug_info(self):
"""Print debug information about the current UI state."""
Expand Down

0 comments on commit 6d1d471

Please sign in to comment.