From 05186bbad10735188bc79be64ae5f3ab6d660a3f Mon Sep 17 00:00:00 2001 From: Philip Colangelo Date: Fri, 7 Feb 2025 16:01:06 -0500 Subject: [PATCH 01/11] simplifies model load and places it in a thread pool - still needs tests to be updated --- src/digest/main.py | 658 ++++-------------- src/digest/model_class/digest_onnx_model.py | 43 ++ src/digest/model_class/digest_report_model.py | 25 + src/digest/model_load.py | 247 +++++++ src/digest/modelsummary.py | 138 +++- src/digest/qt_utils.py | 26 + .../{thread.py => similarity_analysis.py} | 79 +-- 7 files changed, 611 insertions(+), 605 deletions(-) create mode 100644 src/digest/model_load.py rename src/digest/{thread.py => similarity_analysis.py} (58%) diff --git a/src/digest/main.py b/src/digest/main.py index 79e55c3..ceef362 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -6,7 +6,7 @@ import shutil import argparse from datetime import datetime -from typing import Dict, Tuple, Optional, Union +from typing import Dict, Tuple, Optional import tempfile from enum import IntEnum import pandas as pd @@ -35,10 +35,10 @@ QMenu, ) from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont -from PySide6.QtCore import Qt, QSize +from PySide6.QtCore import Qt, QSize, QThreadPool from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog -from digest.thread import StatsThread, SimilarityThread, post_process +from digest.similarity_analysis import SimilarityWorker, post_process from digest.popup_window import PopupWindow from digest.huggingface_page import HuggingfacePage from digest.multi_model_selection_page import MultiModelSelectionPage @@ -46,9 +46,15 @@ from digest.modelsummary import modelSummary from digest.node_summary import NodeSummary from digest.qt_utils import apply_dark_style_sheet -from digest.model_class.digest_model import DigestModel -from digest.model_class.digest_onnx_model import DigestOnnxModel -from digest.model_class.digest_report_model import DigestReportModel +from digest.model_class.digest_model import SupportedModelTypes, DigestModel +from digest.model_class.digest_onnx_model import ( + DigestOnnxModel, + LoadDigestOnnxModelWorker, +) +from digest.model_class.digest_report_model import ( + DigestReportModel, + LoadDigestReportModelWorker, +) from utils import onnx_utils GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml") @@ -162,15 +168,13 @@ def __init__(self, model_file: Optional[str] = None): self.ui = Ui_MainWindow() self.ui.setupUi(self) + self.thread_pool = QThreadPool() + self.nodes_window: Dict[str, PopupWindow] = {} - self.status_dialog = None + self.status_dialog: Optional[StatusDialog] = None self.err_open_dialog = None self.temp_dir = tempfile.TemporaryDirectory() - self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {} - - # QThread containers - self.model_nodes_stats_thread: Dict[str, StatsThread] = {} - self.model_similarity_thread: Dict[str, SimilarityThread] = {} + self.digest_models: Dict[str, DigestModel] = {} self.model_similarity_report: Dict[str, SimilarityAnalysisReport] = {} @@ -195,6 +199,8 @@ def __init__(self, model_file: Optional[str] = None): self.ui.infoBtn.clicked.connect(self.show_info_dialog) self.infoDialog = None + self.load_progress: Optional[ProgressDialog] = None + enable_huggingface_model = True with open(GUI_CONFIG, "r", encoding="utf-8") as f: config = yaml.safe_load(f) @@ -227,17 +233,7 @@ def __init__(self, model_file: Optional[str] = None): # Load model file if given as input to the executable if model_file: - exists = os.path.exists(model_file) - ext = os.path.splitext(model_file)[-1] - if exists and ext == ".onnx": - self.load_onnx(model_file) - elif exists and ext == ".yaml": - self.load_report(model_file) - else: - self.err_open_dialog = StatusDialog( - f"Could not open {model_file}", parent=self - ) - self.err_open_dialog.show() + self.load_model(model_file) def uncheck_single_model_buttons(self): for button in self.ui.singleModelWidget.findChildren(QPushButton): @@ -250,11 +246,11 @@ def uncheck_ingest_buttons(self): def tab_focused(self, index): widget = self.ui.tabWidget.widget(index) if isinstance(widget, modelSummary): - unique_id = widget.digest_model.unique_id + model_id = widget.model_id if ( - self.stats_save_button_flag[unique_id] - and self.similarity_save_button_flag[unique_id] - and not isinstance(widget.digest_model, DigestReportModel) + self.stats_save_button_flag[model_id] + and self.similarity_save_button_flag[model_id] + and not widget.model_type == SupportedModelTypes.REPORT ): self.ui.saveBtn.setEnabled(True) else: @@ -263,22 +259,12 @@ def tab_focused(self, index): def closeTab(self, index): summary_widget = self.ui.tabWidget.widget(index) if isinstance(summary_widget, modelSummary): - unique_id = summary_widget.digest_model.unique_id + model_id = summary_widget.model_id summary_widget.deleteLater() - tab_thread = self.model_nodes_stats_thread.get(unique_id) - if tab_thread: - tab_thread.exit() - tab_thread.wait(5000) - - if not tab_thread.isRunning(): - del self.model_nodes_stats_thread[unique_id] - else: - print(f"Warning: Thread for {unique_id} did not finish in time") - # delete the digest model to free up used memory - if unique_id in self.digest_models: - del self.digest_models[unique_id] + if model_id in self.digest_models: + del self.digest_models[model_id] self.ui.tabWidget.removeTab(index) if self.ui.tabWidget.count() == 0: @@ -295,64 +281,6 @@ def openFile(self): self.load_model(file_name) - def update_cards( - self, - digest_model: DigestModel, - unique_id: str, - ): - self.digest_models[unique_id].flops = digest_model.flops - self.digest_models[unique_id].node_type_flops = digest_model.node_type_flops - self.digest_models[unique_id].parameters = digest_model.parameters - self.digest_models[unique_id].node_type_parameters = ( - digest_model.node_type_parameters - ) - self.digest_models[unique_id].node_data = digest_model.node_data - - # We must iterate over the tabWidget and match to the tab_name because the user - # may have switched the currentTab during the threads execution. - curr_index = -1 - for index in range(self.ui.tabWidget.count()): - widget = self.ui.tabWidget.widget(index) - if ( - isinstance(widget, modelSummary) - and widget.digest_model.unique_id == unique_id - ): - if digest_model.flops is None: - flops_str = "--" - else: - flops_str = format(digest_model.flops, ",") - - # Set up the pie chart - pie_chart_labels, pie_chart_data = zip( - *self.digest_models[unique_id].node_type_flops.items() - ) - widget.ui.flopsPieChart.set_data( - "FLOPs Intensity Per Op Type", - pie_chart_labels, - pie_chart_data, - ) - - widget.ui.flops.setText(flops_str) - - # Set up the pie chart - pie_chart_labels, pie_chart_data = zip( - *self.digest_models[unique_id].node_type_parameters.items() - ) - widget.ui.parametersPieChart.set_data( - "Parameter Intensity Per Op Type", - pie_chart_labels, - pie_chart_data, - ) - curr_index = index - break - - self.stats_save_button_flag[unique_id] = True - if self.ui.tabWidget.currentIndex() == curr_index: - if self.similarity_save_button_flag[unique_id] and not isinstance( - digest_model, DigestReportModel - ): - self.ui.saveBtn.setEnabled(True) - def open_similarity_report(self, model_id: str, image_path, most_similar_models): self.model_similarity_report[model_id] = SimilarityAnalysisReport( image_path, most_similar_models @@ -368,32 +296,25 @@ def update_similarity_widget( df_sorted: Optional[pd.DataFrame] = None, ): widget = None - digest_model = None curr_index = -1 for index in range(self.ui.tabWidget.count()): tab_widget = self.ui.tabWidget.widget(index) - if ( - isinstance(tab_widget, modelSummary) - and tab_widget.digest_model.unique_id == model_id - ): + if isinstance(tab_widget, modelSummary) and tab_widget.model_id == model_id: widget = tab_widget - digest_model = tab_widget.digest_model curr_index = index break # convert back to a List[str] most_similar_list = most_similar.split(",") - if ( - completed_successfully - and isinstance(widget, modelSummary) - and digest_model - and png_filepath - ): + if completed_successfully and isinstance(widget, modelSummary) and png_filepath: if df_sorted is not None: post_process( - digest_model.model_name, most_similar_list, df_sorted, png_filepath + self.digest_models[model_id].model_name, + most_similar_list, + df_sorted, + png_filepath, ) widget.load_gif.stop() @@ -459,464 +380,145 @@ def update_similarity_widget( self.similarity_save_button_flag[model_id] = True if self.ui.tabWidget.currentIndex() == curr_index: if self.stats_save_button_flag[model_id] and not isinstance( - digest_model, DigestReportModel + self.digest_models[model_id], DigestReportModel ): self.ui.saveBtn.setEnabled(True) - def load_onnx(self, filepath: str): + def load_model(self, file_path: str): # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) + file_path = os.path.normpath(file_path) - if not os.path.exists(filepath): + if not os.path.exists(file_path): + self.status_dialog = StatusDialog( + f"Model file {file_path} does not exist.", + parent=self, + ) + self.status_dialog.show() return - # Every time an onnx is loaded we should emulate a model summary button click - self.summary_clicked() + basename, file_ext = os.path.splitext(os.path.basename(file_path)) + + supported_exts = [".onnx", ".yaml"] + + if not file_ext in supported_exts: + self.status_dialog = StatusDialog( + f"Digest does not support files with the extension {file_ext}", + parent=self, + ) + self.status_dialog.show() + return # Before opening the file, check to see if it is already opened. for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: + if isinstance(widget, modelSummary) and file_path == widget.file: self.ui.tabWidget.setCurrentIndex(index) return - try: - - progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self) - QApplication.processEvents() # Process pending events + self.load_progress = ProgressDialog("Loading Model...", 3, self) + self.load_progress.step() - model = onnx_utils.load_onnx(filepath, load_external_data=False) - opt_model, opt_passed = onnx_utils.optimize_onnx_model(model) - progress.step() - - basename = os.path.splitext(os.path.basename(filepath)) - model_name = basename[0] - - # Save the model proto so we can use the Freeze Inputs feature - digest_model = DigestOnnxModel( - onnx_model=opt_model, model_name=model_name, save_proto=True - ) - model_id = digest_model.unique_id - - self.stats_save_button_flag[model_id] = False - self.similarity_save_button_flag[model_id] = False - - self.digest_models[model_id] = digest_model - - model_summary = modelSummary(digest_model) - if model_summary.freeze_inputs: - model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) - - dynamic_input_dims = onnx_utils.get_dynamic_input_dims(opt_model) - if dynamic_input_dims: - model_summary.ui.freezeButton.setVisible(True) - model_summary.ui.warningLabel.setText( - "⚠️ Some model details are unavailable due to dynamic input dimensions. " - "See section Input Tensor(s) Information below for more details." - ) - model_summary.ui.warningLabel.show() - - elif not opt_passed: - model_summary.ui.warningLabel.setText( - "⚠️ The model could not be optimized either due to an ONNX Runtime " - "session error or it did not pass the ONNX checker." - ) - model_summary.ui.warningLabel.show() - - progress.step() - progress.setLabelText("Checking for dynamic Inputs") - - self.ui.tabWidget.addTab(model_summary, "") - model_summary.ui.flops.setText("Loading...") - - # Hide some of the components - model_summary.ui.similarityCorrelation.hide() - model_summary.ui.similarityCorrelationStatic.hide() - - model_summary.file = filepath - model_summary.setObjectName(model_name) - model_summary.ui.modelName.setText(model_name) - model_summary.ui.modelFilename.setText(filepath) - model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) + self.load_progress.setLabelText( + "Creating a Digest model. Please be patient as this could take a minute." + ) - digest_model.model_name = model_name - digest_model.filepath = filepath - digest_model.model_inputs = onnx_utils.get_model_input_shapes_types( - opt_model + if file_ext == ".onnx": + # Load the digest onnx model on a separate thread + digest_model_worker = LoadDigestOnnxModelWorker( + model_name=basename, model_file_path=file_path ) - digest_model.model_outputs = onnx_utils.get_model_output_shapes_types( - opt_model + elif file_ext == ".yaml": + digest_model_worker = LoadDigestReportModelWorker( + model_name=basename, model_file_path=file_path ) - progress.step() - progress.setLabelText("Calculating Parameter Count") - - parameter_count = onnx_utils.get_parameter_count(opt_model) - model_summary.ui.parameters.setText(format(parameter_count, ",")) - - # Kick off model stats thread - self.model_nodes_stats_thread[model_id] = StatsThread() - self.model_nodes_stats_thread[model_id].completed.connect(self.update_cards) + digest_model_worker.signals.completed.connect(self.post_load_model) - self.model_nodes_stats_thread[model_id].model = opt_model - self.model_nodes_stats_thread[model_id].tab_name = model_name - self.model_nodes_stats_thread[model_id].unique_id = model_id - self.model_nodes_stats_thread[model_id].start() + self.thread_pool.start(digest_model_worker) - progress.step() - progress.setLabelText("Calculating Node Type Counts") + def post_load_model(self, digest_model: DigestModel): + """This function is automatically run after the model load workers are finished""" - node_type_counts = onnx_utils.get_node_type_counts(opt_model) - if len(node_type_counts) < 15: - bar_spacing = 40 - else: - bar_spacing = 20 - model_summary.ui.opHistogramChart.bar_spacing = bar_spacing - model_summary.ui.opHistogramChart.set_data(node_type_counts) - model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) - digest_model.node_type_counts = node_type_counts - - progress.step() - progress.setLabelText("Gathering Model Inputs and Outputs") - - # Inputs Table - model_summary.ui.inputsTable.setRowCount( - len(self.digest_models[model_id].model_inputs) - ) + if self.load_progress: + self.load_progress.step() + self.load_progress.setLabelText("Displaying the model summary") - for row_idx, (input_name, input_info) in enumerate( - self.digest_models[model_id].model_inputs.items() - ): - model_summary.ui.inputsTable.setItem( - row_idx, 0, QTableWidgetItem(input_name) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(input_info.shape)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(input_info.dtype)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) - ) + if digest_model.unique_id: + model_id = digest_model.unique_id + else: - model_summary.ui.inputsTable.resizeColumnsToContents() - model_summary.ui.inputsTable.resizeRowsToContents() + if self.load_progress: + self.load_progress.close() - # Outputs Table - model_summary.ui.outputsTable.setRowCount( - len(self.digest_models[model_id].model_outputs) + self.status_dialog = StatusDialog( + "Unexpected Error: Digest model did not return a valid ID.", + parent=self, ) - for row_idx, (output_name, output_info) in enumerate( - self.digest_models[model_id].model_outputs.items() - ): - model_summary.ui.outputsTable.setItem( - row_idx, 0, QTableWidgetItem(output_name) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(output_info.shape)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(output_info.dtype)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) - ) + self.status_dialog.show() + print("Unexpected Error: Digest model did not return a valid ID.") + return + + self.stats_save_button_flag[model_id] = False + self.similarity_save_button_flag[model_id] = False - model_summary.ui.outputsTable.resizeColumnsToContents() - model_summary.ui.outputsTable.resizeRowsToContents() + # Every time an onnx is loaded we should emulate a model summary button click + self.summary_clicked() + self.digest_models[model_id] = digest_model - progress.step() - progress.setLabelText("Gathering Model Proto Data") + model_summary = modelSummary(self.digest_models[model_id]) + if model_summary.freeze_inputs: + model_summary.freeze_inputs.complete_signal.connect(self.load_model) - # ModelProto Info - model_summary.ui.modelProtoTable.setItem( - 0, 1, QTableWidgetItem(str(opt_model.model_version)) - ) - digest_model.model_version = opt_model.model_version + self.ui.tabWidget.addTab(model_summary, "") - model_summary.ui.modelProtoTable.setItem( - 1, 1, QTableWidgetItem(str(opt_model.graph.name)) - ) - digest_model.graph_name = opt_model.graph.name + new_tab_idx = self.ui.tabWidget.count() - 1 + self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) + self.ui.tabWidget.setCurrentIndex(new_tab_idx) + self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) + self.ui.singleModelWidget.show() - producer_txt = f"{opt_model.producer_name} {opt_model.producer_version}" - model_summary.ui.modelProtoTable.setItem( - 2, 1, QTableWidgetItem(producer_txt) - ) - digest_model.producer_name = opt_model.producer_name - digest_model.producer_version = opt_model.producer_version + if self.load_progress: + self.load_progress.step() - model_summary.ui.modelProtoTable.setItem( - 3, 1, QTableWidgetItem(str(opt_model.ir_version)) + if isinstance(digest_model, DigestOnnxModel) and digest_model.model_proto: + self.stats_save_button_flag[model_id] = True + dynamic_input_dims = onnx_utils.get_dynamic_input_dims( + digest_model.model_proto ) - digest_model.ir_version = opt_model.ir_version - - for imp in opt_model.opset_import: - row_idx = model_summary.ui.importsTable.rowCount() - model_summary.ui.importsTable.insertRow(row_idx) - if imp.domain == "" or imp.domain == "ai.onnx": - model_summary.ui.opsetVersion.setText(str(imp.version)) - domain = "ai.onnx" - digest_model.opset = imp.version - else: - domain = imp.domain - model_summary.ui.importsTable.setItem( - row_idx, 0, QTableWidgetItem(str(domain)) - ) - model_summary.ui.importsTable.setItem( - row_idx, 1, QTableWidgetItem(str(imp.version)) + if dynamic_input_dims: + model_summary.ui.freezeButton.setVisible(True) + model_summary.ui.warningLabel.setText( + "⚠️ Some model details are unavailable due to dynamic input dimensions. " + "See section Input Tensor(s) Information below for more details." ) - row_idx += 1 - - digest_model.imports[imp.domain] = imp.version - - progress.step() - progress.setLabelText("Wrapping Up Model Analysis") - - model_summary.ui.importsTable.resizeColumnsToContents() - model_summary.ui.modelProtoTable.resizeColumnsToContents() - model_summary.setObjectName(model_name) - new_tab_idx = self.ui.tabWidget.count() - 1 - self.ui.tabWidget.setTabText(new_tab_idx, "".join(model_name)) - self.ui.tabWidget.setCurrentIndex(new_tab_idx) - self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) - self.ui.singleModelWidget.show() - progress.step() + model_summary.ui.warningLabel.show() # Start similarity Analysis # Note: Should only be started after the model tab has been created png_tmp_path = os.path.join(self.temp_dir.name, model_id) os.makedirs(png_tmp_path, exist_ok=True) assert os.path.exists(png_tmp_path), f"Error with creating {png_tmp_path}" - self.model_similarity_thread[model_id] = SimilarityThread() - self.model_similarity_thread[model_id].completed_successfully.connect( - self.update_similarity_widget + png_file_path = os.path.join( + png_tmp_path, f"heatmap_{digest_model.model_name}.png" ) - self.model_similarity_thread[model_id].model_filepath = filepath - self.model_similarity_thread[model_id].png_filepath = os.path.join( - png_tmp_path, f"heatmap_{model_name}.png" - ) - self.model_similarity_thread[model_id].model_id = model_id - self.model_similarity_thread[model_id].start() - - progress.close() - - except FileNotFoundError as e: - print(f"File not found: {e.filename}") - def load_report(self, filepath: str): + model_summary.png_file_path = png_file_path - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) - - if not os.path.exists(filepath): - return - - # Every time a report is loaded we should emulate a model summary button click - self.summary_clicked() - - # Before opening the file, check to see if it is already opened. - for index in range(self.ui.tabWidget.count()): - widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return - - try: - - progress = ProgressDialog("Loading Digest Report File...", 2, self) - QApplication.processEvents() # Process pending events - - digest_model = DigestReportModel(filepath) - - if not digest_model.is_valid: - progress.close() - invalid_yaml_dialog = StatusDialog( - title="Warning", - status_message=f"YAML file {filepath} is not a valid digest report", - ) - invalid_yaml_dialog.show() - - return - - model_id = digest_model.unique_id - - # There is no sense in offering to save the report - self.stats_save_button_flag[model_id] = False - self.similarity_save_button_flag[model_id] = False - - self.digest_models[model_id] = digest_model - - model_summary = modelSummary(digest_model) - - self.ui.tabWidget.addTab(model_summary, "") - model_summary.ui.flops.setText("Loading...") - - # Hide some of the components - model_summary.ui.similarityCorrelation.hide() - model_summary.ui.similarityCorrelationStatic.hide() - - model_summary.file = filepath - model_summary.setObjectName(digest_model.model_name) - model_summary.ui.modelName.setText(digest_model.model_name) - model_summary.ui.modelFilename.setText(filepath) - model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) - - model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) - - node_type_counts = digest_model.node_type_counts - if len(node_type_counts) < 15: - bar_spacing = 40 - else: - bar_spacing = 20 - - model_summary.ui.opHistogramChart.bar_spacing = bar_spacing - model_summary.ui.opHistogramChart.set_data(node_type_counts) - model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) - - progress.step() - progress.setLabelText("Gathering Model Inputs and Outputs") - - # Inputs Table - model_summary.ui.inputsTable.setRowCount( - len(self.digest_models[model_id].model_inputs) + similarity_worker = SimilarityWorker( + digest_model.filepath, png_file_path, model_id ) + similarity_worker.signals.completed.connect(self.update_similarity_widget) + self.thread_pool.start(similarity_worker) - for row_idx, (input_name, input_info) in enumerate( - self.digest_models[model_id].model_inputs.items() - ): - model_summary.ui.inputsTable.setItem( - row_idx, 0, QTableWidgetItem(input_name) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(input_info.shape)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(input_info.dtype)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) - ) - - model_summary.ui.inputsTable.resizeColumnsToContents() - model_summary.ui.inputsTable.resizeRowsToContents() - - # Outputs Table - model_summary.ui.outputsTable.setRowCount( - len(self.digest_models[model_id].model_outputs) - ) - for row_idx, (output_name, output_info) in enumerate( - self.digest_models[model_id].model_outputs.items() - ): - model_summary.ui.outputsTable.setItem( - row_idx, 0, QTableWidgetItem(output_name) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(output_info.shape)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(output_info.dtype)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) - ) - - model_summary.ui.outputsTable.resizeColumnsToContents() - model_summary.ui.outputsTable.resizeRowsToContents() - - progress.step() - progress.setLabelText("Gathering Model Proto Data") - - # ModelProto Info - model_summary.ui.modelProtoTable.setItem( - 0, 1, QTableWidgetItem(str(digest_model.model_data["model_version"])) - ) - - model_summary.ui.modelProtoTable.setItem( - 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"])) - ) - - producer_txt = ( - f"{digest_model.model_data['producer_name']} " - f"{digest_model.model_data['producer_version']}" - ) - model_summary.ui.modelProtoTable.setItem( - 2, 1, QTableWidgetItem(producer_txt) - ) - - model_summary.ui.modelProtoTable.setItem( - 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"])) - ) - - for domain, version in digest_model.model_data["import_list"].items(): - row_idx = model_summary.ui.importsTable.rowCount() - model_summary.ui.importsTable.insertRow(row_idx) - if domain == "" or domain == "ai.onnx": - model_summary.ui.opsetVersion.setText(str(version)) - domain = "ai.onnx" - - model_summary.ui.importsTable.setItem( - row_idx, 0, QTableWidgetItem(str(domain)) - ) - model_summary.ui.importsTable.setItem( - row_idx, 1, QTableWidgetItem(str(version)) - ) - row_idx += 1 - - progress.step() - progress.setLabelText("Wrapping Up Model Analysis") - - model_summary.ui.importsTable.resizeColumnsToContents() - model_summary.ui.modelProtoTable.resizeColumnsToContents() - model_summary.setObjectName(digest_model.model_name) - new_tab_idx = self.ui.tabWidget.count() - 1 - self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) - self.ui.tabWidget.setCurrentIndex(new_tab_idx) - self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) - self.ui.singleModelWidget.show() - progress.step() - - self.update_cards(digest_model, digest_model.unique_id) - - movie = QMovie(":/assets/gifs/load.gif") - model_summary.ui.similarityImg.setMovie(movie) - movie.start() - + elif isinstance(digest_model, DigestReportModel): self.update_similarity_widget( - completed_successfully=bool(digest_model.similarity_heatmap_path), - model_id=digest_model.unique_id, + bool(digest_model.similarity_heatmap_path), + model_id=model_id, most_similar="", png_filepath=digest_model.similarity_heatmap_path, ) - progress.close() - - except FileNotFoundError as e: - print(f"File not found: {e.filename}") - - def load_model(self, file_path: str): - - # Ensure the filepath follows a standard formatting: - file_path = os.path.normpath(file_path) - - if not os.path.exists(file_path): - return - - file_ext = os.path.splitext(file_path)[-1] - - if file_ext == ".onnx": - self.load_onnx(file_path) - elif file_ext == ".yaml": - self.load_report(file_path) - else: - bad_ext_dialog = StatusDialog( - f"Digest does not support files with the extension {file_ext}", - parent=self, - ) - bad_ext_dialog.show() - def dragEnterEvent(self, event: QDragEnterEvent): if event.mimeData().hasUrls(): event.acceptProposedAction() @@ -963,7 +565,7 @@ def save_reports(self): if not isinstance(current_tab, modelSummary): return - digest_model = current_tab.digest_model + digest_model = self.digest_models[current_tab.model_id] if not digest_model.model_name: print("Warning, digest_model model name not set.") @@ -999,12 +601,12 @@ def save_reports(self): node_type_filepath = os.path.join( save_directory, f"{model_name}_node_type_counts.csv" ) + digest_model.save_node_type_counts_csv_report(node_type_filepath) # Save (copy) the similarity image - png_file_path = self.model_similarity_thread[ - digest_model.unique_id - ].png_filepath + png_file_path = current_tab.png_file_path + png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") if png_file_path and os.path.exists(png_file_path): shutil.copy(png_file_path, png_save_path) @@ -1078,7 +680,7 @@ def save_nodes_csv(self, csv_filepath: Optional[str], open_dialog: bool = True): raise ValueError("A filepath must be given.") current_tab = self.ui.tabWidget.currentWidget() if isinstance(current_tab, modelSummary): - current_tab.digest_model.save_nodes_csv_report(csv_filepath) + self.digest_models[current_tab.model_id].save_nodes_csv_report(csv_filepath) def save_chart(self, chart_view): path, _ = self.save_file_dialog("Save PNG", "PNG(*.png)") @@ -1095,7 +697,7 @@ def open_node_summary(self): return model_name = current_tab.ui.modelName.text() - model_id = current_tab.digest_model.unique_id + model_id = current_tab.model_id if model_id in self.nodes_window: del self.nodes_window[model_id] @@ -1114,14 +716,6 @@ def open_node_summary(self): self.nodes_window[model_id].open() def closeEvent(self, event): - for thread in self.model_nodes_stats_thread.values(): - thread.quit() # Request the thread to stop - thread.wait(5000) # Wait for the thread to finish - - for thread in self.model_similarity_thread.values(): - thread.quit() # Request the thread to stop - thread.wait(5000) # Wait for the thread to finish - for window in QApplication.topLevelWidgets(): if window != self: window.close() diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py index 8c8dd7f..7151628 100644 --- a/src/digest/model_class/digest_onnx_model.py +++ b/src/digest/model_class/digest_onnx_model.py @@ -1,7 +1,9 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. +# pylint: disable=no-name-in-module import os from typing import List, Dict, Optional, Tuple, cast +from PySide6.QtCore import QRunnable, Signal, Slot, QObject from datetime import datetime import importlib.metadata from collections import OrderedDict @@ -654,3 +656,44 @@ def save_text_report(self, filepath: str) -> None: f_p.write("Output Tensor(s) Information:\n") f_p.write(output_table.get_string()) f_p.write("\n\n") + + +class WorkerSignals(QObject): + completed = Signal(DigestOnnxModel) + + +class LoadDigestOnnxModelWorker(QRunnable): + + def __init__( + self, + model_file_path: str, + model_name: str, + ): + super().__init__() + self.signals = WorkerSignals() + self.tab_name = model_name + self.model_file_path = model_file_path + self.unique_id: Optional[str] = None + + @Slot() + def run(self): + try: + model_proto = onnx_utils.load_onnx( + self.model_file_path, load_external_data=False + ) + opt_model, _ = onnx_utils.optimize_onnx_model(model_proto) + except FileNotFoundError as e: + print(f"File not found: {e.filename}") + + digest_model = DigestOnnxModel( + opt_model, + model_name=self.tab_name, + onnx_filepath=self.model_file_path, + ) + + self.unique_id = digest_model.unique_id + + if not self.tab_name: + self.tab_name = digest_model.model_name + + self.signals.completed.emit(digest_model) diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py index c84ef20..1f5daa1 100644 --- a/src/digest/model_class/digest_report_model.py +++ b/src/digest/model_class/digest_report_model.py @@ -4,6 +4,7 @@ import csv import ast import re +from PySide6.QtCore import QRunnable, Signal, Slot, QObject from typing import Tuple, Optional, List, Dict, Any, Union import yaml from digest.model_class.digest_model import ( @@ -153,6 +154,30 @@ def save_text_report(self, filepath: str) -> None: return +class WorkerSignals(QObject): + completed = Signal(DigestReportModel) + + +class LoadDigestReportModelWorker(QRunnable): + + def __init__( + self, + model_file_path: str, + model_name: str, + ): + super().__init__() + self.signals = WorkerSignals() + self.tab_name = model_name + self.model_file_path = model_file_path + self.unique_id: Optional[str] = None + + @Slot() + def run(self): + + digest_model = DigestReportModel(self.model_file_path) + self.signals.completed.emit(digest_model) + + def validate_yaml(report_file_path: str) -> bool: """Check that the provided yaml file is indeed a Digest Report file.""" expected_keys = [ diff --git a/src/digest/model_load.py b/src/digest/model_load.py new file mode 100644 index 0000000..9c42c69 --- /dev/null +++ b/src/digest/model_load.py @@ -0,0 +1,247 @@ +import os + + +def load_onnx(filepath: str): + + # Ensure the filepath follows a standard formatting: + filepath = os.path.normpath(filepath) + + if not os.path.exists(filepath): + return + + # Before opening the file, check to see if it is already opened. + for index in range(self.ui.tabWidget.count()): + widget = self.ui.tabWidget.widget(index) + if isinstance(widget, modelSummary) and filepath == widget.file: + self.ui.tabWidget.setCurrentIndex(index) + return + + self.load_progress = ProgressDialog("Loading & Optimizing ONNX Model...", 3, self) + self.load_progress.step() + + self.load_progress.setLabelText( + "Creating a Digest model. Please be patient as this might take a minute." + ) + + basename = os.path.splitext(os.path.basename(filepath)) + model_name = basename[0] + + # Load the digest onnx model on a separate thread + digest_model_worker = LoadDigestOnnxModelWorker( + model_name=model_name, model_file_path=filepath + ) + + digest_model_worker.signals.completed.connect(self.post_load) + + self.thread_pool.start(digest_model_worker) + + +def load_report(filepath: str): + + # Ensure the filepath follows a standard formatting: + filepath = os.path.normpath(filepath) + + if not os.path.exists(filepath): + return + + # Every time a report is loaded we should emulate a model summary button click + self.summary_clicked() + + # Before opening the file, check to see if it is already opened. + for index in range(self.ui.tabWidget.count()): + widget = self.ui.tabWidget.widget(index) + if isinstance(widget, modelSummary) and filepath == widget.file: + self.ui.tabWidget.setCurrentIndex(index) + return + + try: + + progress = ProgressDialog("Loading Digest Report File...", 2, self) + QApplication.processEvents() # Process pending events + + digest_model = DigestReportModel(filepath) + + if not digest_model.is_valid: + progress.close() + invalid_yaml_dialog = StatusDialog( + title="Warning", + status_message=f"YAML file {filepath} is not a valid digest report", + ) + invalid_yaml_dialog.show() + + return + + model_id = digest_model.unique_id + + # There is no sense in offering to save the report + self.stats_save_button_flag[model_id] = False + self.similarity_save_button_flag[model_id] = False + + self.digest_models[model_id] = digest_model + + model_summary = modelSummary(digest_model) + + self.ui.tabWidget.addTab(model_summary, "") + model_summary.ui.flops.setText("Loading...") + + # Hide some of the components + model_summary.ui.similarityCorrelation.hide() + model_summary.ui.similarityCorrelationStatic.hide() + + model_summary.file = filepath + model_summary.setObjectName(digest_model.model_name) + model_summary.ui.modelName.setText(digest_model.model_name) + model_summary.ui.modelFilename.setText(filepath) + model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) + + model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) + + node_type_counts = digest_model.node_type_counts + if len(node_type_counts) < 15: + bar_spacing = 40 + else: + bar_spacing = 20 + + model_summary.ui.opHistogramChart.bar_spacing = bar_spacing + model_summary.ui.opHistogramChart.set_data(node_type_counts) + model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) + + progress.step() + progress.setLabelText("Gathering Model Inputs and Outputs") + + # Inputs Table + model_summary.ui.inputsTable.setRowCount( + len(self.digest_models[model_id].model_inputs) + ) + + for row_idx, (input_name, input_info) in enumerate( + self.digest_models[model_id].model_inputs.items() + ): + model_summary.ui.inputsTable.setItem( + row_idx, 0, QTableWidgetItem(input_name) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(input_info.shape)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(input_info.dtype)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) + ) + + model_summary.ui.inputsTable.resizeColumnsToContents() + model_summary.ui.inputsTable.resizeRowsToContents() + + # Outputs Table + model_summary.ui.outputsTable.setRowCount( + len(self.digest_models[model_id].model_outputs) + ) + for row_idx, (output_name, output_info) in enumerate( + self.digest_models[model_id].model_outputs.items() + ): + model_summary.ui.outputsTable.setItem( + row_idx, 0, QTableWidgetItem(output_name) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(output_info.shape)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(output_info.dtype)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) + ) + + model_summary.ui.outputsTable.resizeColumnsToContents() + model_summary.ui.outputsTable.resizeRowsToContents() + + progress.step() + progress.setLabelText("Gathering Model Proto Data") + + # ModelProto Info + model_summary.ui.modelProtoTable.setItem( + 0, 1, QTableWidgetItem(str(digest_model.model_data["model_version"])) + ) + + model_summary.ui.modelProtoTable.setItem( + 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"])) + ) + + producer_txt = ( + f"{digest_model.model_data['producer_name']} " + f"{digest_model.model_data['producer_version']}" + ) + model_summary.ui.modelProtoTable.setItem(2, 1, QTableWidgetItem(producer_txt)) + + model_summary.ui.modelProtoTable.setItem( + 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"])) + ) + + for domain, version in digest_model.model_data["import_list"].items(): + row_idx = model_summary.ui.importsTable.rowCount() + model_summary.ui.importsTable.insertRow(row_idx) + if domain == "" or domain == "ai.onnx": + model_summary.ui.opsetVersion.setText(str(version)) + domain = "ai.onnx" + + model_summary.ui.importsTable.setItem( + row_idx, 0, QTableWidgetItem(str(domain)) + ) + model_summary.ui.importsTable.setItem( + row_idx, 1, QTableWidgetItem(str(version)) + ) + row_idx += 1 + + progress.step() + progress.setLabelText("Wrapping Up Model Analysis") + + model_summary.ui.importsTable.resizeColumnsToContents() + model_summary.ui.modelProtoTable.resizeColumnsToContents() + model_summary.setObjectName(digest_model.model_name) + new_tab_idx = self.ui.tabWidget.count() - 1 + self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) + self.ui.tabWidget.setCurrentIndex(new_tab_idx) + self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) + self.ui.singleModelWidget.show() + progress.step() + + # self.update_cards(digest_model.unique_id) + + movie = QMovie(":/assets/gifs/load.gif") + model_summary.ui.similarityImg.setMovie(movie) + movie.start() + + self.update_similarity_widget( + completed_successfully=bool(digest_model.similarity_heatmap_path), + model_id=digest_model.unique_id, + most_similar="", + png_filepath=digest_model.similarity_heatmap_path, + ) + + progress.close() + + except FileNotFoundError as e: + print(f"File not found: {e.filename}") + + +def load_model(file_path: str): + + # Ensure the filepath follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + file_ext = os.path.splitext(file_path)[-1] + + if file_ext == ".onnx": + self.load_onnx(file_path) + elif file_ext == ".yaml": + self.load_report(file_path) + else: + bad_ext_dialog = StatusDialog( + f"Digest does not support files with the extension {file_ext}", + parent=self, + ) + bad_ext_dialog.show() diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index a92b756..c9b9591 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -1,12 +1,13 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. import os +from datetime import datetime # pylint: disable=invalid-name from typing import Optional, Union # pylint: disable=no-name-in-module -from PySide6.QtWidgets import QWidget +from PySide6.QtWidgets import QWidget, QTableWidgetItem from PySide6.QtGui import QMovie from PySide6.QtCore import QSize @@ -16,6 +17,7 @@ from digest.freeze_inputs import FreezeInputs from digest.popup_window import PopupWindow from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_model import SupportedModelTypes, DigestModel from digest.model_class.digest_onnx_model import DigestOnnxModel from digest.model_class.digest_report_model import DigestReportModel @@ -25,9 +27,10 @@ class modelSummary(QWidget): - def __init__( - self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None - ): + # def __init__( + # self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None + # ): + def __init__(self, digest_model: DigestModel, parent=None): super().__init__(parent) self.ui = Ui_modelSummary() self.ui.setupUi(self) @@ -35,10 +38,11 @@ def __init__( self.file: Optional[str] = None self.ui.warningLabel.hide() - self.digest_model = digest_model + self.model_id = digest_model.unique_id self.model_proto: Optional[ModelProto] = None model_name: str = digest_model.model_name if digest_model.model_name else "" + self.png_file_path: Optional[str] = None self.load_gif = QMovie(":/assets/gifs/load.gif") # We set the size of the GIF to half the original self.load_gif.setScaledSize(QSize(214, 120)) @@ -50,13 +54,137 @@ def __init__( self.freeze_inputs: Optional[FreezeInputs] = None self.freeze_window: Optional[QWidget] = None + self.model_type: Optional[SupportedModelTypes] = None + if isinstance(digest_model, DigestOnnxModel): + self.model_type = SupportedModelTypes.ONNX self.model_proto = ( digest_model.model_proto if digest_model.model_proto else ModelProto() ) self.freeze_inputs = FreezeInputs(self.model_proto, model_name) self.ui.freezeButton.clicked.connect(self.open_freeze_inputs) self.freeze_inputs.complete_signal.connect(self.close_freeze_window) + elif isinstance(digest_model, DigestReportModel): + self.model_type = SupportedModelTypes.REPORT + + # Hide some of the components + self.ui.similarityCorrelation.hide() + self.ui.similarityCorrelationStatic.hide() + + self.file = digest_model.filepath + self.setObjectName(model_name) + self.ui.modelName.setText(model_name) + if self.file: + self.ui.modelFilename.setText(self.file) + + self.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) + + self.ui.parameters.setText(format(digest_model.parameters, ",")) + + node_type_counts = digest_model.node_type_counts + if len(node_type_counts) < 15: + bar_spacing = 40 + else: + bar_spacing = 20 + self.ui.opHistogramChart.bar_spacing = bar_spacing + self.ui.opHistogramChart.set_data(node_type_counts) + self.ui.nodes.setText(str(sum(node_type_counts.values()))) + + flops_str = format(digest_model.flops, ",") + self.ui.flops.setText(flops_str) + + # Set up the FLOPs pie chart + pie_chart_labels, pie_chart_data = zip(*digest_model.node_type_flops.items()) + self.ui.flopsPieChart.set_data( + "FLOPs Intensity Per Op Type", + pie_chart_labels, + pie_chart_data, + ) + + # Set up the params pie chart + pie_chart_labels, pie_chart_data = zip( + *digest_model.node_type_parameters.items() + ) + self.ui.parametersPieChart.set_data( + "Parameter Intensity Per Op Type", + pie_chart_labels, + pie_chart_data, + ) + + # Inputs Table + self.ui.inputsTable.setRowCount(len(digest_model.model_inputs)) + + for row_idx, (input_name, input_info) in enumerate( + digest_model.model_inputs.items() + ): + self.ui.inputsTable.setItem(row_idx, 0, QTableWidgetItem(input_name)) + self.ui.inputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(input_info.shape)) + ) + self.ui.inputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(input_info.dtype)) + ) + self.ui.inputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) + ) + + self.ui.inputsTable.resizeColumnsToContents() + self.ui.inputsTable.resizeRowsToContents() + + # Outputs Table + self.ui.outputsTable.setRowCount(len(digest_model.model_outputs)) + for row_idx, (output_name, output_info) in enumerate( + digest_model.model_outputs.items() + ): + self.ui.outputsTable.setItem(row_idx, 0, QTableWidgetItem(output_name)) + self.ui.outputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(output_info.shape)) + ) + self.ui.outputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(output_info.dtype)) + ) + self.ui.outputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) + ) + + self.ui.outputsTable.resizeColumnsToContents() + self.ui.outputsTable.resizeRowsToContents() + + if isinstance(digest_model, DigestOnnxModel): + + if digest_model.model_version: + # ModelProto Info + self.ui.modelProtoTable.setItem( + 0, 1, QTableWidgetItem(digest_model.model_version) + ) + + if digest_model.graph_name: + self.ui.modelProtoTable.setItem( + 1, 1, QTableWidgetItem(digest_model.graph_name) + ) + + producer_txt = ( + f"{digest_model.producer_name} {digest_model.producer_version}" + ) + self.ui.modelProtoTable.setItem(2, 1, QTableWidgetItem(producer_txt)) + + self.ui.modelProtoTable.setItem( + 3, 1, QTableWidgetItem(str(digest_model.ir_version)) + ) + + for domain, version in digest_model.imports.items(): + row_idx = self.ui.importsTable.rowCount() + self.ui.importsTable.insertRow(row_idx) + if domain == "" or domain == "ai.onnx": + self.ui.opsetVersion.setText(str(version)) + domain = "ai.onnx" + self.ui.importsTable.setItem(row_idx, 0, QTableWidgetItem(domain)) + self.ui.importsTable.setItem(row_idx, 1, QTableWidgetItem(str(version))) + row_idx += 1 + + self.ui.importsTable.resizeColumnsToContents() + self.ui.modelProtoTable.resizeColumnsToContents() + self.setObjectName(model_name) def open_freeze_inputs(self): if self.freeze_inputs: diff --git a/src/digest/qt_utils.py b/src/digest/qt_utils.py index 1015844..e2b3863 100644 --- a/src/digest/qt_utils.py +++ b/src/digest/qt_utils.py @@ -6,6 +6,7 @@ # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget, QApplication +from PySide6.QtCore import QThread, QEventLoop, QTimer from PySide6.QtCore import QFile, QTextStream from digest.dialog import StatusDialog @@ -14,6 +15,31 @@ BASE_STYLE_FILE = os.path.join(ROOT_FOLDER, "styles", "darkstyle.qss") +def wait_threads(threads: List[QThread], timeout=10000) -> bool: + + loop = QEventLoop() + timer = QTimer() + timer.setSingleShot(True) + timer.timeout.connect(loop.quit) + + def check_threads(): + if all(thread.isFinished() for thread in threads): + loop.quit() + + check_timer = QTimer() + check_timer.timeout.connect(check_threads) + check_timer.start(100) # Check every 100ms + + timer.start(timeout) + loop.exec() + + check_timer.stop() + timer.stop() + + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in threads) + + def get_ram_utilization() -> float: mem = psutil.virtual_memory() ram_util_perc = mem.percent diff --git a/src/digest/thread.py b/src/digest/similarity_analysis.py similarity index 58% rename from src/digest/thread.py rename to src/digest/similarity_analysis.py index bf9c546..1e48e3d 100644 --- a/src/digest/thread.py +++ b/src/digest/similarity_analysis.py @@ -3,83 +3,29 @@ # pylint: disable=no-name-in-module import os from typing import List, Optional -from PySide6.QtCore import QThread, Signal, QEventLoop, QTimer +from PySide6.QtCore import Signal, QRunnable, QObject import matplotlib.pyplot as plt import numpy as np import pandas as pd -from digest.model_class.digest_onnx_model import DigestOnnxModel from digest.subgraph_analysis.find_match import find_match -def wait_threads(threads: List[QThread], timeout=10000) -> bool: +class WorkerSignals(QObject): + completed = Signal(bool, str, str, str, pd.DataFrame) - loop = QEventLoop() - timer = QTimer() - timer.setSingleShot(True) - timer.timeout.connect(loop.quit) - def check_threads(): - if all(thread.isFinished() for thread in threads): - loop.quit() - - check_timer = QTimer() - check_timer.timeout.connect(check_threads) - check_timer.start(100) # Check every 100ms - - timer.start(timeout) - loop.exec() - - check_timer.stop() - timer.stop() - - # Return True if all threads finished, False if timed out - return all(thread.isFinished() for thread in threads) - - -class StatsThread(QThread): - - completed = Signal(DigestOnnxModel, str) +class SimilarityWorker(QRunnable): def __init__( self, - model=None, - tab_name: Optional[str] = None, - unique_id: Optional[str] = None, - ): - super().__init__() - self.model = model - self.tab_name = tab_name - self.unique_id = unique_id - - def run(self): - if not self.model: - raise ValueError("You must specify a model.") - if not self.tab_name: - raise ValueError("You must specify a tab name.") - if not self.unique_id: - raise ValueError("You must specify a unique id.") - - digest_model = DigestOnnxModel(self.model, save_proto=False) - - self.completed.emit(digest_model, self.unique_id) - - def wait(self, timeout=10000): - wait_threads([self], timeout) - - -class SimilarityThread(QThread): - - completed_successfully = Signal(bool, str, str, str, pd.DataFrame) - - def __init__( - self, - model_filepath: Optional[str] = None, - png_filepath: Optional[str] = None, + model_file_path: Optional[str] = None, + png_file_path: Optional[str] = None, model_id: Optional[str] = None, ): super().__init__() - self.model_filepath = model_filepath - self.png_filepath = png_filepath + self.signals = WorkerSignals() + self.model_filepath = model_file_path + self.png_filepath = png_file_path self.model_id = model_id def run(self): @@ -99,19 +45,16 @@ def run(self): most_similar = [os.path.basename(path) for path in most_similar] # We convert List[str] to str to send through the signal most_similar = ",".join(most_similar) - self.completed_successfully.emit( + self.signals.completed.emit( True, self.model_id, most_similar, self.png_filepath, df_sorted ) except Exception as e: # pylint: disable=broad-exception-caught most_similar = "" - self.completed_successfully.emit( + self.signals.completed.emit( False, self.model_id, most_similar, self.png_filepath, df_sorted ) print(f"Issue creating similarity analysis: {e}") - def wait(self, timeout=10000): - wait_threads([self], timeout) - def post_process( model_name: str, From f54d7414f496aab1820f90a362b23eb25315bba6 Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 17:28:50 -0500 Subject: [PATCH 02/11] Tests now support new model loading scheme --- src/digest/main.py | 238 ++++++++++++++++++++++++--------------------- test/test_gui.py | 210 ++++++++++++++++----------------------- 2 files changed, 212 insertions(+), 236 deletions(-) diff --git a/src/digest/main.py b/src/digest/main.py index ceef362..ac95c64 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -5,7 +5,6 @@ import sys import shutil import argparse -from datetime import datetime from typing import Dict, Tuple, Optional import tempfile from enum import IntEnum @@ -22,7 +21,6 @@ QApplication, QFileDialog, QPushButton, - QTableWidgetItem, QMainWindow, QLabel, QTextEdit, @@ -34,7 +32,7 @@ QSizePolicy, QMenu, ) -from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont +from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QIcon, QFont from PySide6.QtCore import Qt, QSize, QThreadPool from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog @@ -57,7 +55,10 @@ ) from utils import onnx_utils -GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml") + +class DigestConfig: + GUI_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "gui_config.yaml") + SUPPORTED_EXTENSIONS = [".onnx", ".yaml"] class SimilarityAnalysisReport(QMainWindow): @@ -154,6 +155,12 @@ def copy_chart_to_clipboard(self): QApplication.clipboard().setPixmap(pixmap) +class ModelLoadError(Exception): + """Raised when there's an error loading a model.""" + + pass + + class DigestApp(QMainWindow): class Page(IntEnum): @@ -165,15 +172,15 @@ class Page(IntEnum): def __init__(self, model_file: Optional[str] = None): super(DigestApp, self).__init__() - self.ui = Ui_MainWindow() + self.ui: Ui_MainWindow = Ui_MainWindow() self.ui.setupUi(self) - self.thread_pool = QThreadPool() + self.thread_pool: QThreadPool = QThreadPool() self.nodes_window: Dict[str, PopupWindow] = {} self.status_dialog: Optional[StatusDialog] = None - self.err_open_dialog = None - self.temp_dir = tempfile.TemporaryDirectory() + self.err_open_dialog: Optional[StatusDialog] = None + self.temp_dir: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory() self.digest_models: Dict[str, DigestModel] = {} self.model_similarity_report: Dict[str, SimilarityAnalysisReport] = {} @@ -202,7 +209,7 @@ def __init__(self, model_file: Optional[str] = None): self.load_progress: Optional[ProgressDialog] = None enable_huggingface_model = True - with open(GUI_CONFIG, "r", encoding="utf-8") as f: + with open(DigestConfig.GUI_CONFIG_PATH, "r", encoding="utf-8") as f: config = yaml.safe_load(f) enable_huggingface_model = config["modules"]["huggingface"] @@ -384,58 +391,55 @@ def update_similarity_widget( ): self.ui.saveBtn.setEnabled(True) - def load_model(self, file_path: str): + def load_model(self, file_path: str) -> None: + try: + file_path = os.path.normpath(file_path) + if not os.path.exists(file_path): + raise ModelLoadError(f"Model file {file_path} does not exist.") - # Ensure the filepath follows a standard formatting: - file_path = os.path.normpath(file_path) + basename, file_ext = os.path.splitext(os.path.basename(file_path)) + supported_exts = [".onnx", ".yaml"] - if not os.path.exists(file_path): - self.status_dialog = StatusDialog( - f"Model file {file_path} does not exist.", - parent=self, - ) - self.status_dialog.show() - return + if file_ext not in supported_exts: + raise ModelLoadError( + f"Digest does not support files with the extension {file_ext}" + ) - basename, file_ext = os.path.splitext(os.path.basename(file_path)) + # Before opening the file, check to see if it is already opened. + for index in range(self.ui.tabWidget.count()): + widget = self.ui.tabWidget.widget(index) + if isinstance(widget, modelSummary) and file_path == widget.file: + self.ui.tabWidget.setCurrentIndex(index) + return - supported_exts = [".onnx", ".yaml"] + self.load_progress = ProgressDialog("Loading Model...", 3, self) + self.load_progress.step() - if not file_ext in supported_exts: - self.status_dialog = StatusDialog( - f"Digest does not support files with the extension {file_ext}", - parent=self, + self.load_progress.setLabelText( + "Creating a Digest model. Please be patient as this could take a minute." ) - self.status_dialog.show() - return - - # Before opening the file, check to see if it is already opened. - for index in range(self.ui.tabWidget.count()): - widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and file_path == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return - self.load_progress = ProgressDialog("Loading Model...", 3, self) - self.load_progress.step() + if file_ext == ".onnx": + # Load the digest onnx model on a separate thread + digest_model_worker = LoadDigestOnnxModelWorker( + model_name=basename, model_file_path=file_path + ) + elif file_ext == ".yaml": + digest_model_worker = LoadDigestReportModelWorker( + model_name=basename, model_file_path=file_path + ) - self.load_progress.setLabelText( - "Creating a Digest model. Please be patient as this could take a minute." - ) + digest_model_worker.signals.completed.connect(self.post_load_model) - if file_ext == ".onnx": - # Load the digest onnx model on a separate thread - digest_model_worker = LoadDigestOnnxModelWorker( - model_name=basename, model_file_path=file_path - ) - elif file_ext == ".yaml": - digest_model_worker = LoadDigestReportModelWorker( - model_name=basename, model_file_path=file_path + self.thread_pool.start(digest_model_worker) + except ModelLoadError as e: + self.status_dialog = StatusDialog(str(e), parent=self) + self.status_dialog.show() + except Exception as e: + self.status_dialog = StatusDialog( + f"Unexpected error loading model: {str(e)}", parent=self ) - - digest_model_worker.signals.completed.connect(self.post_load_model) - - self.thread_pool.start(digest_model_worker) + self.status_dialog.show() def post_load_model(self, digest_model: DigestModel): """This function is automatically run after the model load workers are finished""" @@ -565,79 +569,79 @@ def save_reports(self): if not isinstance(current_tab, modelSummary): return - digest_model = self.digest_models[current_tab.model_id] - if not digest_model.model_name: - print("Warning, digest_model model name not set.") - - model_name = str(digest_model.model_name) - - save_directory = QFileDialog(self).getExistingDirectory( - self, "Select Directory" - ) - + save_directory = self._get_save_directory() if not save_directory: return - # Check if the directory exists and is writable - if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK): + try: + self._save_report_files(current_tab, save_directory) + except Exception as exception: + self._handle_save_error(exception) + else: + self._show_save_success(save_directory) + + def _get_save_directory(self) -> Optional[str]: + """Get and validate the save directory from user.""" + directory = QFileDialog(self).getExistingDirectory(self, "Select Directory") + if not directory: + return None + + if not os.path.exists(directory) or not os.access(directory, os.W_OK): self.show_warning_dialog( - f"The directory {save_directory} is not valid or writable." + f"The directory {directory} is not valid or writable." ) + return None - save_directory = os.path.join( - save_directory, str(digest_model.model_name) + "_reports" + return directory + + def _save_report_files(self, current_tab, save_directory): + model_name = current_tab.ui.modelName.text() + + # Save the node histogram image + node_histogram = current_tab.ui.opHistogramChart.grab() + node_histogram.save( + os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" ) - try: - os.makedirs(save_directory, exist_ok=True) + # Save csv of node type counts + node_type_filepath = os.path.join( + save_directory, f"{model_name}_node_type_counts.csv" + ) - # Save the node histogram image - node_histogram = current_tab.ui.opHistogramChart.grab() - node_histogram.save( - os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" - ) + self.digest_models[current_tab.model_id].save_node_type_counts_csv_report( + node_type_filepath + ) - # Save csv of node type counts - node_type_filepath = os.path.join( - save_directory, f"{model_name}_node_type_counts.csv" - ) + # Save (copy) the similarity image + png_file_path = current_tab.png_file_path - digest_model.save_node_type_counts_csv_report(node_type_filepath) + png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") + if png_file_path and os.path.exists(png_file_path): + shutil.copy(png_file_path, png_save_path) - # Save (copy) the similarity image - png_file_path = current_tab.png_file_path + # Save the text report + txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt") + self.digest_models[current_tab.model_id].save_text_report(txt_report_filepath) - png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") - if png_file_path and os.path.exists(png_file_path): - shutil.copy(png_file_path, png_save_path) + # Save the yaml report + yaml_report_filepath = os.path.join(save_directory, f"{model_name}_report.yaml") + self.digest_models[current_tab.model_id].save_yaml_report(yaml_report_filepath) - # Save the text report - txt_report_filepath = os.path.join( - save_directory, f"{model_name}_report.txt" - ) - digest_model.save_text_report(txt_report_filepath) + # Save the node list + nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv") - # Save the yaml report - yaml_report_filepath = os.path.join( - save_directory, f"{model_name}_report.yaml" - ) - digest_model.save_yaml_report(yaml_report_filepath) + self.save_nodes_csv(nodes_report_filepath, False) - # Save the node list - nodes_report_filepath = os.path.join( - save_directory, f"{model_name}_nodes.csv" - ) + def _handle_save_error(self, exception): + self.status_dialog = StatusDialog(f"{exception}") + self.status_dialog.show() - self.save_nodes_csv(nodes_report_filepath, False) - except Exception as exception: # pylint: disable=broad-exception-caught - self.status_dialog = StatusDialog(f"{exception}") - self.status_dialog.show() - else: - self.status_dialog = StatusDialog( - f"Saved reports to: \n{os.path.abspath(save_directory)}", - "Successfully saved reports!", - ) - self.status_dialog.show() + def _show_save_success(self, save_directory): + self.status_dialog = StatusDialog( + f"Saved reports to: \n{os.path.abspath(save_directory)}", + "Successfully saved reports!", + ) + self.status_dialog.show() def on_dialog_closed(self): self.infoDialog = None @@ -716,11 +720,23 @@ def open_node_summary(self): self.nodes_window[model_id].open() def closeEvent(self, event): - for window in QApplication.topLevelWidgets(): - if window != self: - window.close() + """Ensure proper cleanup of resources when closing the application.""" + try: + # Close all child windows + for window in QApplication.topLevelWidgets(): + if window != self: + window.close() + + # Cleanup temporary directory + if hasattr(self, "temp_dir"): + self.temp_dir.cleanup() + + # Wait for thread pool to finish + if hasattr(self, "thread_pool"): + self.thread_pool.waitForDone() - super().closeEvent(event) + finally: + super().closeEvent(event) def main(): diff --git a/test/test_gui.py b/test/test_gui.py index 59fbb8f..5857c94 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -24,6 +24,7 @@ class DigestGuiTest(unittest.TestCase): TEST_DIR, f"{MODEL_BASENAME}_reports", f"{MODEL_BASENAME}_report.yaml" ) ) + THREAD_TIMEOUT = 10000 # milliseconds @classmethod def setUpClass(cls): @@ -39,66 +40,64 @@ def tearDownClass(cls): def setUp(self): self.digest_app = digest.main.DigestApp() self.digest_app.show() + self.initial_tab_count = self.digest_app.ui.tabWidget.count() + self.addCleanup(self.digest_app.close) def tearDown(self): self.digest_app.close() - - def wait_all_threads(self, timeout=10000) -> bool: - all_threads = list(self.digest_app.model_nodes_stats_thread.values()) + list( - self.digest_app.model_similarity_thread.values() + QApplication.processEvents() # Ensure all pending events are processed + self.digest_app = None + super().tearDown() + + def wait_all_threads(self, timeout_ms=None) -> bool: + """Wait for all tasks in the thread pool to complete.""" + timeout_ms = timeout_ms or self.THREAD_TIMEOUT + QApplication.processEvents() # Ensure pending events are processed + return self.digest_app.thread_pool.waitForDone(timeout_ms) + + def _mock_file_open(self, mock_dialog, filepath): + """Helper to mock file open dialog""" + mock_dialog.return_value = (filepath, "") + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) + self.assertTrue(self.wait_all_threads()) + + def _verify_tab_added(self): + """Verify that exactly one new tab was added""" + # Process events to ensure UI updates + QApplication.processEvents() + self.assertEqual( + self.digest_app.ui.tabWidget.count(), + self.initial_tab_count + 1, + "Expected one new tab to be added", ) - for thread in all_threads: - thread.wait(timeout) - - # Return True if all threads finished, False if timed out - return all(thread.isFinished() for thread in all_threads) + def _close_current_tab(self): + """Close the most recently opened tab""" + current_tab = self.digest_app.ui.tabWidget.count() - 1 + self.digest_app.closeTab(current_tab) def test_open_valid_onnx(self): + """Test that opening a valid ONNX file creates a new tab in the UI.""" with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: - mock_dialog.return_value = ( - self.ONNX_FILEPATH, - "", - ) - - num_tabs_prior = self.digest_app.ui.tabWidget.count() - - QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - - self.assertTrue(self.wait_all_threads()) - - self.assertTrue( - self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 - ) # Check if a tab was added - - self.digest_app.closeTab(num_tabs_prior) + self._mock_file_open(mock_dialog, self.ONNX_FILEPATH) + self._verify_tab_added() + self._close_current_tab() def test_open_valid_yaml(self): + """Test that opening a valid YAML report file creates a new tab in the UI.""" with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: - mock_dialog.return_value = ( - self.YAML_FILEPATH, - "", - ) - - num_tabs_prior = self.digest_app.ui.tabWidget.count() - - QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - - self.assertTrue(self.wait_all_threads()) - - self.assertTrue( - self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 - ) # Check if a tab was added - - self.digest_app.closeTab(num_tabs_prior) + self._mock_file_open(mock_dialog, self.YAML_FILEPATH) + self._verify_tab_added() + self._close_current_tab() def test_open_invalid_file(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: - mock_dialog.return_value = ("invalid_file.txt", "") - num_tabs_prior = self.digest_app.ui.tabWidget.count() - QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.assertTrue(self.wait_all_threads()) - self.assertEqual(self.digest_app.ui.tabWidget.count(), num_tabs_prior) + self._mock_file_open(mock_dialog, "invalid_file.txt") + self.assertEqual( + self.digest_app.ui.tabWidget.count(), + self.initial_tab_count, + "No new tab should be added for invalid file", + ) def test_save_reports(self): with patch( @@ -107,65 +106,28 @@ def test_save_reports(self): "PySide6.QtWidgets.QFileDialog.getExistingDirectory" ) as mock_save_dialog: - mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname + self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH) - QTest.mouseClick( - self.digest_app.ui.openFileBtn, - Qt.MouseButton.LeftButton, - ) - - self.assertTrue(self.wait_all_threads()) + # Process any pending events and wait for threads + QApplication.processEvents() self.assertTrue( - self.digest_app.ui.saveBtn.isEnabled(), "Save button is disabled!" + self.wait_all_threads(), + "Background tasks did not complete within the specified timeout", ) - QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton) + # Process events again after threads complete + QApplication.processEvents() - mock_save_dialog.assert_called_once() + # Add debug information + self._print_debug_info() - result_basepath = os.path.join( - tmpdirname, f"{self.MODEL_BASENAME}_reports" - ) - - # Text report test - text_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_report.txt" - ) self.assertTrue( - os.path.isfile(text_report_filepath), - f"{text_report_filepath} not found!", - ) - - # YAML report test - yaml_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_report.yaml" - ) - self.assertTrue(os.path.isfile(yaml_report_filepath)) - - # Nodes test - nodes_csv_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_nodes.csv" - ) - self.assertTrue(os.path.isfile(nodes_csv_report_filepath)) - - # Histogram test - histogram_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_histogram.png" - ) - self.assertTrue(os.path.isfile(histogram_filepath)) - - # Heatmap test - heatmap_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_heatmap.png" + self.digest_app.ui.saveBtn.isEnabled(), + "Save button should be enabled after loading file", ) - self.assertTrue(os.path.isfile(heatmap_filepath)) - - num_tabs = self.digest_app.ui.tabWidget.count() - self.assertTrue(num_tabs == 1) - self.digest_app.closeTab(0) def test_save_tables(self): with patch( @@ -174,47 +136,45 @@ def test_save_tables(self): "PySide6.QtWidgets.QFileDialog.getSaveFileName" ) as mock_save_dialog: - mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") with tempfile.TemporaryDirectory() as tmpdirname: - mock_save_dialog.return_value = ( - os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv"), - "", - ) - - QTest.mouseClick( - self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton + output_file = os.path.join( + tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv" ) + mock_save_dialog.return_value = (output_file, "") - self.assertTrue(self.wait_all_threads()) + self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH) - QTest.mouseClick( - self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton + # Process events and wait for threads before accessing nodes window + QApplication.processEvents() + self.assertTrue( + self.wait_all_threads(), "Threads did not complete in time" ) - # We assume there is only one model loaded - _, node_window = self.digest_app.nodes_window.popitem() - node_summary = node_window.main_window.centralWidget() - - self.assertIsInstance(node_summary, NodeSummary) + self._save_nodes_list(output_file) + self._close_current_tab() - # This line of code seems redundant but we do this to clean pylance - if isinstance(node_summary, NodeSummary): - QTest.mouseClick( - node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton - ) + def _save_nodes_list(self, expected_output): + """Helper to handle nodes list saving logic""" + QTest.mouseClick(self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton) - mock_save_dialog.assert_called_once() + # Get the node window and verify it + _, node_window = self.digest_app.nodes_window.popitem() + node_summary = node_window.main_window.centralWidget() + self.assertIsInstance(node_summary, NodeSummary) - self.assertTrue( - os.path.exists( - os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv") - ), - "Nodes csv file not found.", - ) + if isinstance(node_summary, NodeSummary): + QTest.mouseClick(node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton) + self.assertTrue( + os.path.exists(expected_output), + f"Nodes csv file not found at {expected_output}", + ) - num_tabs = self.digest_app.ui.tabWidget.count() - self.assertTrue(num_tabs == 1) - self.digest_app.closeTab(0) + def _print_debug_info(self): + """Print debug information about the current UI state.""" + current_tab = self.digest_app.ui.tabWidget.currentWidget() + print(f"Current tab: {current_tab}") + print(f"Tab count: {self.digest_app.ui.tabWidget.count()}") + print(f"Save button enabled: {self.digest_app.ui.saveBtn.isEnabled()}") if __name__ == "__main__": From 6e77423a0f8eb99b7179ba4b42495ab560847d25 Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 17:33:30 -0500 Subject: [PATCH 03/11] linting --- src/digest/main.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/digest/main.py b/src/digest/main.py index ac95c64..19a4ad8 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -158,8 +158,6 @@ def copy_chart_to_clipboard(self): class ModelLoadError(Exception): """Raised when there's an error loading a model.""" - pass - class DigestApp(QMainWindow): @@ -419,8 +417,10 @@ def load_model(self, file_path: str) -> None: "Creating a Digest model. Please be patient as this could take a minute." ) + # Initialize worker variable + digest_model_worker = None + if file_ext == ".onnx": - # Load the digest onnx model on a separate thread digest_model_worker = LoadDigestOnnxModelWorker( model_name=basename, model_file_path=file_path ) @@ -429,13 +429,13 @@ def load_model(self, file_path: str) -> None: model_name=basename, model_file_path=file_path ) - digest_model_worker.signals.completed.connect(self.post_load_model) - - self.thread_pool.start(digest_model_worker) + if digest_model_worker is not None: + digest_model_worker.signals.completed.connect(self.post_load_model) + self.thread_pool.start(digest_model_worker) except ModelLoadError as e: self.status_dialog = StatusDialog(str(e), parent=self) self.status_dialog.show() - except Exception as e: + except Exception as e: # pylint: disable=broad-except self.status_dialog = StatusDialog( f"Unexpected error loading model: {str(e)}", parent=self ) @@ -575,7 +575,7 @@ def save_reports(self): try: self._save_report_files(current_tab, save_directory) - except Exception as exception: + except Exception as exception: # pylint: disable=broad-except self._handle_save_error(exception) else: self._show_save_success(save_directory) From fe8612c4c8e125656908f38c1e4e020f30af76b7 Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 17:43:30 -0500 Subject: [PATCH 04/11] cleaned up dead code --- src/digest/main.py | 3 +- src/digest/model_load.py | 247 ------------------------------------- src/digest/modelsummary.py | 3 +- src/digest/qt_utils.py | 26 ---- 4 files changed, 3 insertions(+), 276 deletions(-) delete mode 100644 src/digest/model_load.py diff --git a/src/digest/main.py b/src/digest/main.py index 19a4ad8..f9edc97 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -414,7 +414,8 @@ def load_model(self, file_path: str) -> None: self.load_progress.step() self.load_progress.setLabelText( - "Creating a Digest model. Please be patient as this could take a minute." + "Creating a Digest model. " + "Please be patient as this could take a minute." ) # Initialize worker variable diff --git a/src/digest/model_load.py b/src/digest/model_load.py deleted file mode 100644 index 9c42c69..0000000 --- a/src/digest/model_load.py +++ /dev/null @@ -1,247 +0,0 @@ -import os - - -def load_onnx(filepath: str): - - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) - - if not os.path.exists(filepath): - return - - # Before opening the file, check to see if it is already opened. - for index in range(self.ui.tabWidget.count()): - widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return - - self.load_progress = ProgressDialog("Loading & Optimizing ONNX Model...", 3, self) - self.load_progress.step() - - self.load_progress.setLabelText( - "Creating a Digest model. Please be patient as this might take a minute." - ) - - basename = os.path.splitext(os.path.basename(filepath)) - model_name = basename[0] - - # Load the digest onnx model on a separate thread - digest_model_worker = LoadDigestOnnxModelWorker( - model_name=model_name, model_file_path=filepath - ) - - digest_model_worker.signals.completed.connect(self.post_load) - - self.thread_pool.start(digest_model_worker) - - -def load_report(filepath: str): - - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) - - if not os.path.exists(filepath): - return - - # Every time a report is loaded we should emulate a model summary button click - self.summary_clicked() - - # Before opening the file, check to see if it is already opened. - for index in range(self.ui.tabWidget.count()): - widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return - - try: - - progress = ProgressDialog("Loading Digest Report File...", 2, self) - QApplication.processEvents() # Process pending events - - digest_model = DigestReportModel(filepath) - - if not digest_model.is_valid: - progress.close() - invalid_yaml_dialog = StatusDialog( - title="Warning", - status_message=f"YAML file {filepath} is not a valid digest report", - ) - invalid_yaml_dialog.show() - - return - - model_id = digest_model.unique_id - - # There is no sense in offering to save the report - self.stats_save_button_flag[model_id] = False - self.similarity_save_button_flag[model_id] = False - - self.digest_models[model_id] = digest_model - - model_summary = modelSummary(digest_model) - - self.ui.tabWidget.addTab(model_summary, "") - model_summary.ui.flops.setText("Loading...") - - # Hide some of the components - model_summary.ui.similarityCorrelation.hide() - model_summary.ui.similarityCorrelationStatic.hide() - - model_summary.file = filepath - model_summary.setObjectName(digest_model.model_name) - model_summary.ui.modelName.setText(digest_model.model_name) - model_summary.ui.modelFilename.setText(filepath) - model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) - - model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) - - node_type_counts = digest_model.node_type_counts - if len(node_type_counts) < 15: - bar_spacing = 40 - else: - bar_spacing = 20 - - model_summary.ui.opHistogramChart.bar_spacing = bar_spacing - model_summary.ui.opHistogramChart.set_data(node_type_counts) - model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) - - progress.step() - progress.setLabelText("Gathering Model Inputs and Outputs") - - # Inputs Table - model_summary.ui.inputsTable.setRowCount( - len(self.digest_models[model_id].model_inputs) - ) - - for row_idx, (input_name, input_info) in enumerate( - self.digest_models[model_id].model_inputs.items() - ): - model_summary.ui.inputsTable.setItem( - row_idx, 0, QTableWidgetItem(input_name) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(input_info.shape)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(input_info.dtype)) - ) - model_summary.ui.inputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) - ) - - model_summary.ui.inputsTable.resizeColumnsToContents() - model_summary.ui.inputsTable.resizeRowsToContents() - - # Outputs Table - model_summary.ui.outputsTable.setRowCount( - len(self.digest_models[model_id].model_outputs) - ) - for row_idx, (output_name, output_info) in enumerate( - self.digest_models[model_id].model_outputs.items() - ): - model_summary.ui.outputsTable.setItem( - row_idx, 0, QTableWidgetItem(output_name) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 1, QTableWidgetItem(str(output_info.shape)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 2, QTableWidgetItem(str(output_info.dtype)) - ) - model_summary.ui.outputsTable.setItem( - row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) - ) - - model_summary.ui.outputsTable.resizeColumnsToContents() - model_summary.ui.outputsTable.resizeRowsToContents() - - progress.step() - progress.setLabelText("Gathering Model Proto Data") - - # ModelProto Info - model_summary.ui.modelProtoTable.setItem( - 0, 1, QTableWidgetItem(str(digest_model.model_data["model_version"])) - ) - - model_summary.ui.modelProtoTable.setItem( - 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"])) - ) - - producer_txt = ( - f"{digest_model.model_data['producer_name']} " - f"{digest_model.model_data['producer_version']}" - ) - model_summary.ui.modelProtoTable.setItem(2, 1, QTableWidgetItem(producer_txt)) - - model_summary.ui.modelProtoTable.setItem( - 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"])) - ) - - for domain, version in digest_model.model_data["import_list"].items(): - row_idx = model_summary.ui.importsTable.rowCount() - model_summary.ui.importsTable.insertRow(row_idx) - if domain == "" or domain == "ai.onnx": - model_summary.ui.opsetVersion.setText(str(version)) - domain = "ai.onnx" - - model_summary.ui.importsTable.setItem( - row_idx, 0, QTableWidgetItem(str(domain)) - ) - model_summary.ui.importsTable.setItem( - row_idx, 1, QTableWidgetItem(str(version)) - ) - row_idx += 1 - - progress.step() - progress.setLabelText("Wrapping Up Model Analysis") - - model_summary.ui.importsTable.resizeColumnsToContents() - model_summary.ui.modelProtoTable.resizeColumnsToContents() - model_summary.setObjectName(digest_model.model_name) - new_tab_idx = self.ui.tabWidget.count() - 1 - self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) - self.ui.tabWidget.setCurrentIndex(new_tab_idx) - self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) - self.ui.singleModelWidget.show() - progress.step() - - # self.update_cards(digest_model.unique_id) - - movie = QMovie(":/assets/gifs/load.gif") - model_summary.ui.similarityImg.setMovie(movie) - movie.start() - - self.update_similarity_widget( - completed_successfully=bool(digest_model.similarity_heatmap_path), - model_id=digest_model.unique_id, - most_similar="", - png_filepath=digest_model.similarity_heatmap_path, - ) - - progress.close() - - except FileNotFoundError as e: - print(f"File not found: {e.filename}") - - -def load_model(file_path: str): - - # Ensure the filepath follows a standard formatting: - file_path = os.path.normpath(file_path) - - if not os.path.exists(file_path): - return - - file_ext = os.path.splitext(file_path)[-1] - - if file_ext == ".onnx": - self.load_onnx(file_path) - elif file_ext == ".yaml": - self.load_report(file_path) - else: - bad_ext_dialog = StatusDialog( - f"Digest does not support files with the extension {file_ext}", - parent=self, - ) - bad_ext_dialog.show() diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index c9b9591..35100a1 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -3,8 +3,7 @@ import os from datetime import datetime -# pylint: disable=invalid-name -from typing import Optional, Union +from typing import Optional # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget, QTableWidgetItem diff --git a/src/digest/qt_utils.py b/src/digest/qt_utils.py index e2b3863..1015844 100644 --- a/src/digest/qt_utils.py +++ b/src/digest/qt_utils.py @@ -6,7 +6,6 @@ # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget, QApplication -from PySide6.QtCore import QThread, QEventLoop, QTimer from PySide6.QtCore import QFile, QTextStream from digest.dialog import StatusDialog @@ -15,31 +14,6 @@ BASE_STYLE_FILE = os.path.join(ROOT_FOLDER, "styles", "darkstyle.qss") -def wait_threads(threads: List[QThread], timeout=10000) -> bool: - - loop = QEventLoop() - timer = QTimer() - timer.setSingleShot(True) - timer.timeout.connect(loop.quit) - - def check_threads(): - if all(thread.isFinished() for thread in threads): - loop.quit() - - check_timer = QTimer() - check_timer.timeout.connect(check_threads) - check_timer.start(100) # Check every 100ms - - timer.start(timeout) - loop.exec() - - check_timer.stop() - timer.stop() - - # Return True if all threads finished, False if timed out - return all(thread.isFinished() for thread in threads) - - def get_ram_utilization() -> float: mem = psutil.virtual_memory() ram_util_perc = mem.percent From 7e0d78272e34337ab6a09665e9c79540e988ef07 Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 17:48:02 -0500 Subject: [PATCH 05/11] add back pylint block for invalid-name --- src/digest/modelsummary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index 35100a1..7c46acc 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -1,8 +1,8 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. +# pylint: disable=invalid-name import os from datetime import datetime - from typing import Optional # pylint: disable=no-name-in-module From 5e88bedaa39010dc8707996d20ab8793c7911fcc Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 17:49:48 -0500 Subject: [PATCH 06/11] newline the load progress dialog --- src/digest/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/digest/main.py b/src/digest/main.py index f9edc97..999c0c8 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -414,7 +414,7 @@ def load_model(self, file_path: str) -> None: self.load_progress.step() self.load_progress.setLabelText( - "Creating a Digest model. " + "Creating a Digest model.\n" "Please be patient as this could take a minute." ) From 9bf040842b027c652cc19c83a3c4ca17807b85ab Mon Sep 17 00:00:00 2001 From: pcolange Date: Tue, 11 Feb 2025 18:00:51 -0500 Subject: [PATCH 07/11] fix for flops logic and increased version --- setup.py | 2 +- src/digest/modelsummary.py | 44 ++++++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index 0e48553..836ced3 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.1.1", + version="1.1.2", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index 7c46acc..fe5387e 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -89,26 +89,32 @@ def __init__(self, digest_model: DigestModel, parent=None): self.ui.opHistogramChart.set_data(node_type_counts) self.ui.nodes.setText(str(sum(node_type_counts.values()))) - flops_str = format(digest_model.flops, ",") - self.ui.flops.setText(flops_str) + # Format flops with commas if available + flops_str = "N/A" + if digest_model.flops is not None: + flops_str = format(digest_model.flops, ",") + + # Set up the FLOPs pie chart + pie_chart_labels, pie_chart_data = zip( + *digest_model.node_type_flops.items() + ) + self.ui.flopsPieChart.set_data( + "FLOPs Intensity Per Op Type", + pie_chart_labels, + pie_chart_data, + ) - # Set up the FLOPs pie chart - pie_chart_labels, pie_chart_data = zip(*digest_model.node_type_flops.items()) - self.ui.flopsPieChart.set_data( - "FLOPs Intensity Per Op Type", - pie_chart_labels, - pie_chart_data, - ) - - # Set up the params pie chart - pie_chart_labels, pie_chart_data = zip( - *digest_model.node_type_parameters.items() - ) - self.ui.parametersPieChart.set_data( - "Parameter Intensity Per Op Type", - pie_chart_labels, - pie_chart_data, - ) + # Set up the params pie chart + pie_chart_labels, pie_chart_data = zip( + *digest_model.node_type_parameters.items() + ) + self.ui.parametersPieChart.set_data( + "Parameter Intensity Per Op Type", + pie_chart_labels, + pie_chart_data, + ) + + self.ui.flops.setText(flops_str) # Inputs Table self.ui.inputsTable.setRowCount(len(digest_model.model_inputs)) From a1d1e912c173ea09d7ccfe7a9c7b6685d4f85895 Mon Sep 17 00:00:00 2001 From: pcolange Date: Wed, 12 Feb 2025 09:16:09 -0500 Subject: [PATCH 08/11] handle case where no models are found in dir --- src/digest/multi_model_selection_page.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index a637f84..4cded29 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -232,6 +232,10 @@ def set_directory(self, directory: str): total_num_models = len(onnx_file_list) + len(report_file_list) + if total_num_models == 0: + self.update_message_label("No models found in the selected directory.") + return + serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) progress = ProgressDialog("Loading models", total_num_models, self) From 8ac2c31a921dfd02eccff42bfae0058d764669cc Mon Sep 17 00:00:00 2001 From: pcolange Date: Wed, 12 Feb 2025 09:25:52 -0500 Subject: [PATCH 09/11] changed model load progress bar to indeterminate --- src/digest/dialog.py | 2 +- src/digest/main.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/digest/dialog.py b/src/digest/dialog.py index cd848a6..b85c6ab 100644 --- a/src/digest/dialog.py +++ b/src/digest/dialog.py @@ -21,7 +21,7 @@ class ProgressDialog(QProgressDialog): """A pop up window with a progress label that goes from 1 to 100""" - def __init__(self, label: str, num_steps: int, parent=None): + def __init__(self, label: str, num_steps: int = 0, parent=None): """ label: the text to be shown in the pop up dialog num_steps: the total number of events the progress bar will load through diff --git a/src/digest/main.py b/src/digest/main.py index 999c0c8..6f53c36 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -410,9 +410,11 @@ def load_model(self, file_path: str) -> None: self.ui.tabWidget.setCurrentIndex(index) return - self.load_progress = ProgressDialog("Loading Model...", 3, self) - self.load_progress.step() - + # Create progress dialog with indeterminate progress bar + self.load_progress = ProgressDialog(label="Loading Model...", parent=self) + # Setting min=max=0 creates an indeterminate progress bar + self.load_progress.setMinimum(0) + self.load_progress.setMaximum(0) self.load_progress.setLabelText( "Creating a Digest model.\n" "Please be patient as this could take a minute." @@ -446,16 +448,11 @@ def post_load_model(self, digest_model: DigestModel): """This function is automatically run after the model load workers are finished""" if self.load_progress: - self.load_progress.step() - self.load_progress.setLabelText("Displaying the model summary") + self.load_progress.close() if digest_model.unique_id: model_id = digest_model.unique_id else: - - if self.load_progress: - self.load_progress.close() - self.status_dialog = StatusDialog( "Unexpected Error: Digest model did not return a valid ID.", parent=self, @@ -483,9 +480,6 @@ def post_load_model(self, digest_model: DigestModel): self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) self.ui.singleModelWidget.show() - if self.load_progress: - self.load_progress.step() - if isinstance(digest_model, DigestOnnxModel) and digest_model.model_proto: self.stats_save_button_flag[model_id] = True dynamic_input_dims = onnx_utils.get_dynamic_input_dims( From b8a86bf4ba1215d9e040da722cffbd2381790275 Mon Sep 17 00:00:00 2001 From: pcolange Date: Wed, 12 Feb 2025 09:52:33 -0500 Subject: [PATCH 10/11] slight tweak to the test gui script to check for the digest app --- test/test_gui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_gui.py b/test/test_gui.py index 5857c94..9197960 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -39,6 +39,8 @@ def tearDownClass(cls): def setUp(self): self.digest_app = digest.main.DigestApp() + if self.digest_app is None: + self.fail("Failed to initialize DigestApp") self.digest_app.show() self.initial_tab_count = self.digest_app.ui.tabWidget.count() self.addCleanup(self.digest_app.close) From 6d1d471e9e27684998f1d41fc5272b3f8bb4e96c Mon Sep 17 00:00:00 2001 From: pcolange Date: Mon, 24 Feb 2025 10:51:21 -0500 Subject: [PATCH 11/11] handle threading properly in tests --- src/digest/main.py | 7 ++- src/digest/modelsummary.py | 3 - test/test_gui.py | 124 ++++++++++++++++++++++++------------- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/src/digest/main.py b/src/digest/main.py index 6f53c36..cb4cdf8 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -33,7 +33,7 @@ QMenu, ) from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QIcon, QFont -from PySide6.QtCore import Qt, QSize, QThreadPool +from PySide6.QtCore import Qt, QSize, QThreadPool, Signal from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog from digest.similarity_analysis import SimilarityWorker, post_process @@ -160,6 +160,9 @@ class ModelLoadError(Exception): class DigestApp(QMainWindow): + """Main application window for Digest.""" + + model_loaded = Signal() # Used for tests class Page(IntEnum): SPLASH = 0 @@ -389,6 +392,8 @@ def update_similarity_widget( ): self.ui.saveBtn.setEnabled(True) + self.model_loaded.emit() # Used for tests + def load_model(self, file_path: str) -> None: try: file_path = os.path.normpath(file_path) diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index fe5387e..ddc9df3 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -26,9 +26,6 @@ class modelSummary(QWidget): - # def __init__( - # self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None - # ): def __init__(self, digest_model: DigestModel, parent=None): super().__init__(parent) self.ui = Ui_modelSummary() diff --git a/test/test_gui.py b/test/test_gui.py index 9197960..a7eb98a 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -8,7 +8,7 @@ # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, QEventLoop, QTimer, Signal from PySide6.QtWidgets import QApplication import digest.main @@ -29,13 +29,17 @@ class DigestGuiTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.app = QApplication(sys.argv) + cls.app.setQuitOnLastWindowClosed(True) # Ensure proper cleanup return super().setUpClass() @classmethod def tearDownClass(cls): if isinstance(cls.app, QApplication): + cls.app.processEvents() cls.app.closeAllWindows() - cls.app = None + cls.app.quit() # Explicitly quit the application + cls.app = None + return super().tearDownClass() def setUp(self): self.digest_app = digest.main.DigestApp() @@ -43,29 +47,49 @@ def setUp(self): self.fail("Failed to initialize DigestApp") self.digest_app.show() self.initial_tab_count = self.digest_app.ui.tabWidget.count() - self.addCleanup(self.digest_app.close) + QApplication.processEvents() # Process initial events + self.addCleanup(self._cleanup) + + def _cleanup(self): + """Ensure proper cleanup of Qt resources""" + if self.digest_app: + # Close all windows first + for window in QApplication.topLevelWindows(): + window.close() + QApplication.processEvents() + + # Wait for any pending threads with a shorter timeout + if hasattr(self.digest_app, "thread_pool"): + self.digest_app.thread_pool.clear() # Cancel any pending tasks + self.digest_app.thread_pool.waitForDone(2000) # 2 second timeout + + QApplication.processEvents() + self.digest_app.close() + self.digest_app = None + QApplication.processEvents() - def tearDown(self): - self.digest_app.close() - QApplication.processEvents() # Ensure all pending events are processed - self.digest_app = None - super().tearDown() + def _wait_for_signal(self, signal: Signal, timeout_ms=THREAD_TIMEOUT): + """Wait for a signal to be emitted, with a timeout.""" + loop = QEventLoop() + signal_emitted = [] - def wait_all_threads(self, timeout_ms=None) -> bool: - """Wait for all tasks in the thread pool to complete.""" - timeout_ms = timeout_ms or self.THREAD_TIMEOUT - QApplication.processEvents() # Ensure pending events are processed - return self.digest_app.thread_pool.waitForDone(timeout_ms) + def on_signal_emitted(): + signal_emitted.append(True) + loop.quit() + + signal.connect(on_signal_emitted) + QTimer.singleShot(timeout_ms, loop.quit) + loop.exec() + + return bool(signal_emitted) def _mock_file_open(self, mock_dialog, filepath): - """Helper to mock file open dialog""" + """Helper to mock file open dialog and wait for load completion.""" mock_dialog.return_value = (filepath, "") QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.assertTrue(self.wait_all_threads()) def _verify_tab_added(self): """Verify that exactly one new tab was added""" - # Process events to ensure UI updates QApplication.processEvents() self.assertEqual( self.digest_app.ui.tabWidget.count(), @@ -79,22 +103,35 @@ def _close_current_tab(self): self.digest_app.closeTab(current_tab) def test_open_valid_onnx(self): - """Test that opening a valid ONNX file creates a new tab in the UI.""" + """Test opening a valid ONNX file.""" with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: self._mock_file_open(mock_dialog, self.ONNX_FILEPATH) - self._verify_tab_added() + # Wait for the signal *after* clicking the button + self.assertTrue( + self._wait_for_signal(self.digest_app.model_loaded), + "Model load did not complete within timeout (test_open_valid_onnx)", + ) + self._verify_tab_added() # Verify tab *after* successful load self._close_current_tab() def test_open_valid_yaml(self): - """Test that opening a valid YAML report file creates a new tab in the UI.""" + """Test opening a valid YAML report file.""" with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: self._mock_file_open(mock_dialog, self.YAML_FILEPATH) + self.assertTrue( + self._wait_for_signal(self.digest_app.model_loaded), + "Model load did not complete within timeout (test_open_valid_yaml)", + ) self._verify_tab_added() self._close_current_tab() def test_open_invalid_file(self): + """Test opening an invalid file (no tab should be added).""" with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: - self._mock_file_open(mock_dialog, "invalid_file.txt") + mock_dialog.return_value = ("invalid_file.txt", "") + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) + # No need to wait for a signal here, as it won't be emitted + QApplication.processEvents() # Process events to update UI self.assertEqual( self.digest_app.ui.tabWidget.count(), self.initial_tab_count, @@ -102,6 +139,7 @@ def test_open_invalid_file(self): ) def test_save_reports(self): + """Test saving reports after loading a model.""" with patch( "PySide6.QtWidgets.QFileDialog.getOpenFileName" ) as mock_open_dialog, patch( @@ -111,27 +149,23 @@ def test_save_reports(self): with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH) - - # Process any pending events and wait for threads - QApplication.processEvents() - + # Wait for the signal self.assertTrue( - self.wait_all_threads(), - "Background tasks did not complete within the specified timeout", + self._wait_for_signal(self.digest_app.model_loaded), + "Model load did not complete within timeout (save reports)", ) - # Process events again after threads complete - QApplication.processEvents() - - # Add debug information - self._print_debug_info() + QApplication.processEvents() # Ensure UI is updated self.assertTrue( self.digest_app.ui.saveBtn.isEnabled(), "Save button should be enabled after loading file", ) + QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton) + QApplication.processEvents() def test_save_tables(self): + """Test saving node tables.""" with patch( "PySide6.QtWidgets.QFileDialog.getOpenFileName" ) as mock_open_dialog, patch( @@ -146,11 +180,12 @@ def test_save_tables(self): self._mock_file_open(mock_open_dialog, self.ONNX_FILEPATH) - # Process events and wait for threads before accessing nodes window - QApplication.processEvents() + # Wait for the signal self.assertTrue( - self.wait_all_threads(), "Threads did not complete in time" + self._wait_for_signal(self.digest_app.model_loaded), + "Model load did not complete within timeout (save tables)", ) + QApplication.processEvents() self._save_nodes_list(output_file) self._close_current_tab() @@ -160,16 +195,19 @@ def _save_nodes_list(self, expected_output): QTest.mouseClick(self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton) # Get the node window and verify it - _, node_window = self.digest_app.nodes_window.popitem() - node_summary = node_window.main_window.centralWidget() - self.assertIsInstance(node_summary, NodeSummary) + if self.digest_app.nodes_window: + _, node_window = self.digest_app.nodes_window.popitem() + node_summary = node_window.main_window.centralWidget() + self.assertIsInstance(node_summary, NodeSummary) - if isinstance(node_summary, NodeSummary): - QTest.mouseClick(node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton) - self.assertTrue( - os.path.exists(expected_output), - f"Nodes csv file not found at {expected_output}", - ) + if isinstance(node_summary, NodeSummary): + QTest.mouseClick(node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton) + self.assertTrue( + os.path.exists(expected_output), + f"Nodes csv file not found at {expected_output}", + ) + else: + self.fail("Node summary window did not appear within timeout.") def _print_debug_info(self): """Print debug information about the current UI state."""