Skip to content

Commit

Permalink
[Resolves] Support ingesting a digest model report/cache (#3)
Browse files Browse the repository at this point in the history
* Adds support for ingesting digest reports

---------

Signed-off-by: Philip <[email protected]>
Co-authored-by: Philip Colangelo <[email protected]>
  • Loading branch information
pcolange and Philip Colangelo authored Jan 24, 2025
1 parent 19a6194 commit cc55d21
Show file tree
Hide file tree
Showing 38 changed files with 2,594 additions and 1,403 deletions.
11 changes: 0 additions & 11 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ enable =
expression-not-assigned,
confusing-with-statement,
unnecessary-lambda,
assign-to-new-keyword,
redeclared-assigned-name,
pointless-statement,
pointless-string-statement,
Expand Down Expand Up @@ -123,7 +122,6 @@ enable =
invalid-length-returned,
protected-access,
attribute-defined-outside-init,
no-init,
abstract-method,
invalid-overridden-method,
arguments-differ,
Expand Down Expand Up @@ -165,9 +163,7 @@ enable =
### format
# Line length, indentation, whitespace:
bad-indentation,
mixed-indentation,
unnecessary-semicolon,
bad-whitespace,
missing-final-newline,
line-too-long,
mixed-line-endings,
Expand All @@ -187,7 +183,6 @@ enable =
import-self,
preferred-module,
reimported,
relative-import,
deprecated-module,
wildcard-import,
misplaced-future,
Expand Down Expand Up @@ -282,12 +277,6 @@ indent-string = ' '
# black doesn't always obey its own limit. See pyproject.toml.
max-line-length = 100

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check =

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt = no
Expand Down
26 changes: 15 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ SPDX-License-Identifier: Apache-2.0

DigestAI
===========================

<h3>DigestAI is a powerful model analysis tool that extracts insights from your models, enabling optimization and direct modification.</h3>

[![python](https://img.shields.io/badge/python-3.10-blue)](https://github.com/onnx/digestai)
Expand All @@ -16,11 +17,14 @@ DigestAI

![logo](src/digest/assets/images/banner.png)
---

<div align="left">

DigestAI is a powerful model analysis tool that extracts insights from your models, enabling optimization and direct modification.

**Get started quickly!** Download the DigestAI installer directly from [coming soon!].

**Get started quickly!** Download the DigestAI executable directly from [link coming soon].


**Developers: Contribute to DigestAI** Follow the installation instruction below to get started.

Expand Down Expand Up @@ -65,19 +69,18 @@ The following steps are recommended because they are reproducible, however, ther
**Workflow**

1. **Open Qt Designer:**
- **Activate Conda Environment:** Ensure your `digest` Conda environment is activated.
- **Launch:** Run `pyside6-designer.exe` from your terminal.
* **Activate Conda Environment:** Ensure your `digest` Conda environment is activated.
* **Launch:** Run `pyside6-designer.exe` from your terminal.

2. **Work with UI Files:**
- Open any existing UI file (`.ui`) from `src/digest/ui`.
- Design your interface using the drag-and-drop tools and property editor.
- Resource Files (Optional): If your UI uses custom icons, images, or stylesheets, please leverage the Qt resource file (`.qrc`). This makes it easier to manage and package resources with the application.
- Please add any new `.ui` files to the `.pylintrc` file.
* Open any existing UI file (`.ui`) from `src/digest/ui`.
* Design your interface using the drag-and-drop tools and property editor.
* Resource Files (Optional): If your UI uses custom icons, images, or stylesheets, please leverage the Qt resource file (`.qrc`). This makes it easier to manage and package resources with the application.
* Please add any new `.ui` files to the `.pylintrc` file.

3. **Recompile UI Files (After Making Changes):**
- From your terminal, navigate to the project's root directory.
- Run: `python src/digest/compile_digest_gui.py`
* From your terminal, navigate to the project's root directory.
* Run: `python src/digest/compile_digest_gui.py`
## Building EXE for Windows Deployment
Expand Down Expand Up @@ -114,8 +117,9 @@ pytest test/test_gui.py
```
## License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE.txt) file for details.
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE.txt) file for details.
## Copyright
Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
40 changes: 21 additions & 19 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import csv
from collections import Counter, defaultdict
from tqdm import tqdm
from digest.model_class.digest_model import (
NodeShapeCounts,
NodeTypeCounts,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
)
from digest.model_class.digest_onnx_model import DigestOnnxModel
from utils.onnx_utils import (
get_dynamic_input_dims,
load_onnx,
DigestOnnxModel,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
NodeTypeCounts,
NodeShapeCounts,
)

GLOBAL_MODEL_HEADERS = [
Expand Down Expand Up @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str):

global_model_data[model_name] = {
"opset": digest_model.opset,
"parameters": digest_model.model_parameters,
"flops": digest_model.model_flops,
"parameters": digest_model.parameters,
"flops": digest_model.flops,
}

# Model summary text report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
digest_model.save_txt_report(summary_filepath)
digest_model.save_text_report(summary_filepath)

# Model summary yaml report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
digest_model.save_yaml_report(summary_filepath)

# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)

# Save csv containing node type counter
node_type_counter = digest_model.get_node_type_counts()
node_type_filepath = os.path.join(
output_dir, f"{model_name}_node_type_counts.csv"
)
if node_type_counter:
save_node_type_counts_csv_report(node_type_counter, node_type_filepath)

digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Update global data structure for node type counter
global_node_type_counter.update(node_type_counter)
global_node_type_counter.update(digest_model.node_type_counts)

# Save csv containing node shape counts per op_type
node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
output_dir, f"{model_name}_node_shape_counts.csv"
)
save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath)
digest_model.save_node_shape_counts_csv_report(node_shape_filepath)

# Update global data structure for node shape counter
for node_type, shape_counts in node_shape_counts.items():
for node_type, shape_counts in digest_model.get_node_shape_counts().items():
global_node_shape_counter[node_type].update(shape_counts)

if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
global_node_type_counter = NodeTypeCounts(
global_node_type_counter.most_common()
)
save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
save_node_type_counts_csv_report(global_node_type_counts, global_filepath)

global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.0.0",
version="1.1.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
14 changes: 12 additions & 2 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,23 @@ class WarnDialog(QDialog):

def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
self.setWindowTitle("Warning Message")

self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))

self.setWindowTitle("Warning Message")
self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)

self.setWindowModality(Qt.WindowModality.WindowModal)

layout = QVBoxLayout()

# Application Version
layout.addWidget(QLabel("<b>Something went wrong</b>"))
layout.addWidget(QLabel("<b>Warning</b>"))
layout.addWidget(QLabel(warning_message))

ok_button = QPushButton("OK")
ok_button.clicked.connect(self.accept) # Close dialog when clicked
layout.addWidget(ok_button)

self.setLayout(layout)
6 changes: 3 additions & 3 deletions src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)

self.plot_widget = pg.PlotWidget()
self.plot_widget.setMaximumHeight(150)
self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
Expand All @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25

def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):

title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
Expand All @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
self.plot_widget.setFixedWidth(500)

for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
Expand Down
Loading

0 comments on commit cc55d21

Please sign in to comment.