-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
27ab4a4
commit 07e7f1f
Showing
7 changed files
with
366 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .CVC import build_dataset, KvasirDataSet | ||
from .CVC import build_dataset, KvasirDataSet, get_transform | ||
from .transforms import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
""" | ||
ONNX export script | ||
Export PyTorch models as ONNX graphs. | ||
This export script originally started as an adaptation of code snippets found at | ||
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html | ||
The default parameters work with PyTorch 2.3 and ONNX 1.13 and produce an optimal ONNX graph | ||
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible | ||
""" | ||
|
||
import argparse | ||
import torch | ||
import numpy as np | ||
import onnx | ||
import models | ||
from timm.models import create_model | ||
|
||
|
||
parser = argparse.ArgumentParser(description='PyTorch ONNX Deployment') | ||
parser.add_argument('--output', metavar='ONNX_FILE', default=None, type=str, | ||
help='output model filename') | ||
|
||
# Model & dataset params | ||
parser.add_argument('--model', type=str, default='UKAN_large', | ||
choices=['UKAN_samll', 'UKAN_base', 'UKAN_large'], | ||
help='model architecture (default: UKAN_large)') | ||
parser.add_argument('--checkpoint', default='./output/UKAN_large_best_model.pth', type=str, metavar='PATH', | ||
help='path to checkpoint (default: none)') | ||
parser.add_argument('--batch-size', default=1, type=int, | ||
metavar='N', help='mini-batch size (default: 1)') | ||
parser.add_argument('--img-size', default=256, type=int, | ||
metavar='N', help='Input image dimension, uses model default if empty') | ||
parser.add_argument('--nb-classes', type=int, default=2, | ||
help='Number classes in dataset') | ||
|
||
parser.add_argument('--opset', type=int, default=10, | ||
help='ONNX opset to use (default: 10)') | ||
parser.add_argument('--keep-init', action='store_true', default=False, | ||
help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') | ||
parser.add_argument('--aten-fallback', action='store_true', default=False, | ||
help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') | ||
parser.add_argument('--dynamic-size', action='store_true', default=False, | ||
help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.') | ||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', | ||
help='Override mean pixel value of dataset') | ||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', | ||
help='Override std deviation of of dataset') | ||
|
||
|
||
|
||
|
||
def main(): | ||
args = parser.parse_args() | ||
|
||
# args.pretrained = True | ||
# if args.checkpoint: | ||
# args.pretrained = False | ||
|
||
if args.output == None: | ||
args.output = f'./{args.model}.onnx' | ||
|
||
print("==> Creating PyTorch {} model".format(args.model)) | ||
# NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers | ||
# for models using SAME padding | ||
model = create_model( | ||
args.model, | ||
num_classes=args.nb_classes, | ||
# exportable=True | ||
) | ||
|
||
model.load_state_dict(torch.load(args.checkpoint)['model_state']) | ||
model.eval() | ||
|
||
example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True) | ||
|
||
# Run model once before export trace, sets padding for models with Conv2dSameExport. This means | ||
# that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for | ||
# the input img_size specified in this script. | ||
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to | ||
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit | ||
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... | ||
model(example_input) | ||
|
||
print("==> Exporting model to ONNX format at '{}'".format(args.output)) | ||
input_names = ["input0"] | ||
output_names = ["output0"] | ||
dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} | ||
if args.dynamic_size: | ||
dynamic_axes['input0'][2] = 'height' | ||
dynamic_axes['input0'][3] = 'width' | ||
if args.aten_fallback: | ||
export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK | ||
else: | ||
export_type = torch.onnx.OperatorExportTypes.ONNX | ||
|
||
torch_out = torch.onnx._export( | ||
model, example_input, args.output, export_params=True, verbose=True, input_names=input_names, | ||
output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes, | ||
opset_version=args.opset, operator_export_type=export_type) | ||
|
||
print("==> Loading and checking exported model from '{}'".format(args.output)) | ||
onnx_model = onnx.load(args.output) | ||
onnx.checker.check_model(onnx_model) # assuming throw on error | ||
print("==> Passed") | ||
|
||
if args.keep_init and args.aten_fallback: | ||
import caffe2.python.onnx.backend as onnx_caffe2 | ||
# Caffe2 loading only works properly in newer PyTorch/ONNX combos when | ||
# keep_initializers_as_inputs and aten_fallback are set to True. | ||
print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output)) | ||
caffe2_backend = onnx_caffe2.prepare(onnx_model) | ||
B = {onnx_model.graph.input[0].name: x.data.numpy()} | ||
c2_out = caffe2_backend.run(B)[0] | ||
np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5) | ||
print("==> Passed") | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" ONNX optimization script | ||
Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. | ||
NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), | ||
it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). | ||
Copyright 2020 Ross Wightman | ||
""" | ||
import argparse | ||
import warnings | ||
|
||
import onnx | ||
import onnxoptimizer as optimizer | ||
|
||
|
||
parser = argparse.ArgumentParser(description="Optimize ONNX model") | ||
|
||
parser.add_argument("--model", type=str, default='UKAN_large', | ||
choices=['UKAN_samll', 'UKAN_base', 'UKAN_large'], | ||
help="The ONNX model") | ||
parser.add_argument("--output", default=None, help="The optimized model output filename") | ||
|
||
|
||
def traverse_graph(graph, prefix=''): | ||
content = [] | ||
indent = prefix + ' ' | ||
graphs = [] | ||
num_nodes = 0 | ||
for node in graph.node: | ||
pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) | ||
assert isinstance(gs, list) | ||
content.append(pn) | ||
graphs.extend(gs) | ||
num_nodes += 1 | ||
for g in graphs: | ||
g_count, g_str = traverse_graph(g) | ||
content.append('\n' + g_str) | ||
num_nodes += g_count | ||
return num_nodes, '\n'.join(content) | ||
|
||
|
||
def main(): | ||
args = parser.parse_args() | ||
|
||
if args.output == None: | ||
args.output = f'./{args.model}_optim.onnx' | ||
|
||
args.model = f'./{args.model}.onnx' | ||
|
||
onnx_model = onnx.load(args.model) | ||
num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) | ||
|
||
# Optimizer passes to perform | ||
passes = [ | ||
#'eliminate_deadend', | ||
'eliminate_identity', | ||
'eliminate_nop_dropout', | ||
'eliminate_nop_pad', | ||
'eliminate_nop_transpose', | ||
'eliminate_unused_initializer', | ||
'extract_constant_to_initializer', | ||
'fuse_add_bias_into_conv', | ||
'fuse_bn_into_conv', | ||
'fuse_consecutive_concats', | ||
'fuse_consecutive_reduce_unsqueeze', | ||
'fuse_consecutive_squeezes', | ||
'fuse_consecutive_transposes', | ||
#'fuse_matmul_add_bias_into_gemm', | ||
'fuse_pad_into_conv', | ||
#'fuse_transpose_into_gemm', | ||
#'lift_lexical_references', | ||
] | ||
|
||
# Apply the optimization on the original serialized model | ||
# WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing | ||
# 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 | ||
# It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. | ||
warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." | ||
"Try onnxruntime optimization if this doesn't work.") | ||
optimized_model = optimizer.optimize(onnx_model, passes) | ||
|
||
num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) | ||
print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) | ||
print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) | ||
|
||
# Save the ONNX model | ||
onnx.save(optimized_model, args.output) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
""" ONNX-runtime validation script | ||
This script was created to verify accuracy and performance of exported ONNX | ||
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing | ||
pipeline for a fair comparison against the originals. | ||
Copyright 2020 Ross Wightman | ||
""" | ||
import argparse | ||
import numpy as np | ||
import torch | ||
import onnxruntime | ||
from util.utils import AverageMeter | ||
import time | ||
from datasets import KvasirDataSet, get_transform | ||
from util.metrics import Metrics | ||
import util.utils as utils | ||
|
||
|
||
parser = argparse.ArgumentParser(description='Pytorch ONNX Validation') | ||
parser.add_argument("--Kvasir_path", type=str, default='/mnt/d/MedicalSeg/Kvasir-SEG/', | ||
help="path to Kvasir Dataset") | ||
parser.add_argument("--ClinicDB_path", type=str, default='/mnt/d/MedicalSeg/CVC-ClinicDB/', | ||
help="path to CVC-ClinicDBDataset") | ||
parser.add_argument('--nb-classes', type=int, default=2, | ||
help='Number classes in dataset') | ||
parser.add_argument('--onnx-input', default='./UKAN_large_optim.onnx', type=str, metavar='PATH', | ||
help='path to onnx model/weights file') | ||
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', | ||
help='path to output optimized onnx graph') | ||
parser.add_argument('--profile', action='store_true', default=False, | ||
help='Enable profiler output.') | ||
parser.add_argument('--workers', default=2, type=int, metavar='N', | ||
help='number of data loading workers (default: 2)') | ||
parser.add_argument('--batch-size', default=4, type=int, | ||
metavar='N', help='mini-batch size (default: 4), as same as the train_batch_size in train_gpu.py') | ||
parser.add_argument('--img-size', default=256, type=int, | ||
metavar='N', help='Input image dimension, uses model default if empty') | ||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', | ||
help='Override mean pixel value of dataset') | ||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', | ||
help='Override std deviation of of dataset') | ||
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', | ||
help='Override default crop pct of 0.875') | ||
parser.add_argument('--interpolation', default='', type=str, metavar='NAME', | ||
help='Image resize interpolation type (overrides model)') | ||
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', | ||
help='use tensorflow mnasnet preporcessing') | ||
parser.add_argument('--print-freq', '-p', default=10, type=int, | ||
metavar='N', help='print frequency (default: 10)') | ||
|
||
|
||
def main(): | ||
args = parser.parse_args() | ||
args.gpu_id = 0 | ||
|
||
args.input_size = args.img_size | ||
|
||
# Set graph optimization level | ||
sess_options = onnxruntime.SessionOptions() | ||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL | ||
if args.profile: | ||
sess_options.enable_profiling = True | ||
if args.onnx_output_opt: | ||
sess_options.optimized_model_filepath = args.onnx_output_opt | ||
|
||
session = onnxruntime.InferenceSession(args.onnx_input, sess_options) | ||
|
||
|
||
val_set = build_dataset(args) | ||
|
||
loader = torch.utils.data.DataLoader( | ||
val_set, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
drop_last=False | ||
) | ||
|
||
input_name = session.get_inputs()[0].name | ||
|
||
batch_time = AverageMeter() | ||
end = time.time() | ||
|
||
metric = Metrics(args.nb_classes, ignore_label=255, device='cpu') | ||
confmat = utils.ConfusionMatrix(args.nb_classes) | ||
|
||
for i, (input, target) in enumerate(loader): | ||
# run the net and return prediction | ||
output = session.run([], {input_name: input.data.numpy()}) | ||
output = output[0] ## shape: [Batch_size, nb_classes, img_size, img_size] | ||
|
||
confmat.update(target.flatten(), output.argmax(1).flatten()) | ||
metric.update(output, target.flatten()) | ||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
|
||
if i % args.print_freq == 0: | ||
print(f'Test: [{i}/{len(loader)}]\t' | ||
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {(input.size(0) / batch_time.avg):.3f}/s, {(100 * batch_time.avg / input.size(0)):.3f} ms/sample) \t' | ||
f'val_meanF1: {metric.compute_f1()[1]}\t' | ||
f'val_meanACC: {metric.compute_pixel_acc()[1]}\t' | ||
f'val_mIOU: {round((confmat.compute()[2].mean().item() * 100), 2)}\t' | ||
) | ||
|
||
mean_iou = confmat.compute()[2].mean().item() * 100 | ||
mean_iou = round(mean_iou, 2) | ||
all_f1, mean_f1 = metric.compute_f1() | ||
all_acc, mean_acc = metric.compute_pixel_acc() | ||
print(f"**val_meanF1: {mean_f1}\n**val_meanACC: {mean_acc}\n**val_mIOU: {mean_iou}") | ||
|
||
|
||
def build_dataset(args): | ||
valid_ds = KvasirDataSet( | ||
args.Kvasir_path, | ||
args.ClinicDB_path, | ||
args.img_size, | ||
train_mode=False, | ||
transform=get_transform(train=False, args=args) | ||
) | ||
return valid_ds | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.