Exporting fp16 model to onnx produces invalid onnx model #107
Comments
Hi @dkloving. The ONNX model exported by …
Thanks, I am also trying to check on it. I'm not confident anymore about the information in my bug report; the issue may be caused elsewhere. I seem to have a lot to learn about how PyTorch exports to ONNX. I have started empirically testing the individual parts of the model that I can isolate, and I can confidently say that I'm currently stuck on testing … To export this piece of the model, I am doing:
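(For reference, a minimal sketch of what such an isolated-submodule export can look like; the import path, submodule attribute, and input shape are my assumptions, not the exact code from this comment.)

```python
import torch
from yolort.models.yolo import yolov5_darknet_pan_s_r31  # import path is an assumption

model = yolov5_darknet_pan_s_r31(pretrained=False, num_classes=2).eval()
submodule = model.backbone  # hypothetical: the piece under test

dummy = torch.randn(1, 3, 640, 640)
torch.onnx.export(
    submodule,
    dummy,
    "submodule.onnx",
    opset_version=11,
    input_names=["images"],
)
```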
@dkloving PyTorch 1.7.1's `torch.onnx.export` doesn't support exporting `SiLU`. You can work around that with the export-friendly helper:

```python
model = yolov5_darknet_pan_s_r31(pretrained=False, progress=True, num_classes=2)
_export_module_friendly(model)
model = model.eval()
```

Yes, we use the above parameter … Another option is to update PyTorch to 1.8.1, which natively supports exporting `SiLU`.
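Conceptually, the export-friendly helper swaps `nn.SiLU` for a hand-written equivalent that older exporters can trace. A minimal sketch of the idea (not yolort's exact implementation):

```python
import torch
import torch.nn as nn

class ExportFriendlySiLU(nn.Module):
    """SiLU written as x * sigmoid(x), which exports on PyTorch < 1.8."""

    def forward(self, x):
        return x * torch.sigmoid(x)

def replace_silu(module: nn.Module) -> None:
    # Recursively swap every nn.SiLU for the export-friendly version.
    for name, child in module.named_children():
        if isinstance(child, nn.SiLU):
            setattr(module, name, ExportFriendlySiLU())
        else:
            replace_silu(child)
```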
Thanks, I was able to make progress simply by updating to PyTorch 1.8.1 as you suggested, and it was also helpful for me to look at the export-friendly code. I was wrong in my initial bug report: the error occurs not on … or … Also, when I am exporting just …
Another update. I can confirm that the problem is in exporting the Darknet backbone. The following code produces an ONNX model that does not behave correctly: it will let you create an onnxruntime inference session and run inference, but its outputs are …

Inspecting the ONNX file with Netron shows that the single … It looks like this is an issue with the ONNX converter itself. I am still investigating a fix or workaround.
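(A sketch of this kind of backbone sanity check, comparing PyTorch and onnxruntime outputs on the same input; the import path and attribute name are assumptions.)

```python
import numpy as np
import onnxruntime as ort
import torch
from yolort.models.yolo import yolov5_darknet_pan_s_r31  # path is an assumption

model = yolov5_darknet_pan_s_r31(pretrained=False, num_classes=2).eval()
backbone = model.backbone  # hypothetical attribute

x = torch.randn(1, 3, 640, 640)
torch.onnx.export(backbone, x, "backbone.onnx", opset_version=11,
                  input_names=["images"])

with torch.no_grad():
    torch_outs = backbone(x)

sess = ort.InferenceSession("backbone.onnx", providers=["CPUExecutionProvider"])
ort_outs = sess.run(None, {"images": x.numpy()})

# A faithful export should agree to within float tolerance.
for t, o in zip(torch_outs, ort_outs):
    print(np.abs(t.numpy() - o).max())
```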
One issue seems to be with the conversion of …
Hi @dkloving, this is because PyTorch doesn't currently support exporting the `SiLU` activation to ONNX. Both …
Thanks @zhiqwang. It looks like PyTorch has added support for SiLU. I exported a …
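(A quick way to confirm SiLU export support on PyTorch >= 1.8; the file name is arbitrary.)

```python
import torch

# On PyTorch >= 1.8 this exports cleanly; on 1.7.x it fails because
# there is no ONNX symbolic for aten::silu.
m = torch.nn.SiLU()
torch.onnx.export(m, torch.randn(1, 3, 8, 8), "silu.onnx", opset_version=11)
```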
AnchorGenerator.grid_anchors had fp32 hard-coded which could result in forward pass returning mismatched datatypes, for example (fp32, fp16, fp16). Fix for zhiqwang#107
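The fix amounts to deriving the anchor dtype from the inputs instead of hard-coding fp32. A simplified sketch of the idea (not the exact yolort code):

```python
import torch

def grid_anchors(grid_size, stride, dtype=torch.float32, device="cpu"):
    # Use the caller-supplied dtype (e.g. torch.float16) rather than
    # hard-coding torch.float32, so fp16 models stay fp16 end to end.
    h, w = grid_size
    ys = torch.arange(h, dtype=dtype, device=device) * stride
    xs = torch.arange(w, dtype=dtype, device=device) * stride
    cy, cx = torch.meshgrid(ys, xs)
    return torch.stack((cx.reshape(-1), cy.reshape(-1)), dim=1)
```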
Export to ONNX fp16 is still not working. The exported version of … Tracking down the issue with torchvision is driving me bonkers. Somehow copy-pasting the code from here, for example, gives me a working exportable fp16 …
Hi @dkloving, it seems that torchvision's … A more practical route is that we could also separate … Actually the …

Edited: [Maybe I'm wrong here, check the following comment.]
Hi @dkloving, actually the … And the torchvision's … So we should do two things: …
FYI, torchvision is using the …
A temporary workaround for anyone who needs it is to force fp32 for post-processing by wrapping a yolort model like so:
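(A sketch of such a wrapper, running the network in fp16 and casting back to fp32 before decoding/NMS; the `backbone`/`head`/`postprocess` attribute split is an assumption about yolort's internals, not its exact API.)

```python
import torch
import torch.nn as nn

class FP32PostProcess(nn.Module):
    """Run the network in fp16 but keep detection post-processing in fp32."""

    def __init__(self, model):
        super().__init__()
        # Hypothetical split: raw network vs. post-processing submodules.
        self.backbone = model.backbone.half()
        self.head = model.head.half()
        self.postprocess = model.postprocess  # stays fp32

    def forward(self, x):
        features = self.backbone(x.half())
        preds = self.head(features)
        preds = [p.float() for p in preds]  # force fp32 before decoding/NMS
        return self.postprocess(preds)
```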
🐛 Bug
When exporting a half-precision (fp16) model to ONNX, it creates an invalid ONNX file. This appears to be because of a node that remains in fp32 as a result of a call to `torch.nn.functional.interpolate`.
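(A minimal reproduction of the suspected pattern, assuming the culprit is an `interpolate` call with a float `scale_factor`, which the exporter records as an fp32 constant that can clash with fp16 tensors elsewhere in the graph.)

```python
import torch
import torch.nn.functional as F

class Upsample(torch.nn.Module):
    def forward(self, x):
        # scale_factor becomes an fp32 constant in the traced graph.
        return F.interpolate(x, scale_factor=2.0, mode="nearest")

m = Upsample().half().cuda()
x = torch.randn(1, 16, 32, 32, dtype=torch.float16, device="cuda")
torch.onnx.export(m, x, "upsample_fp16.onnx", opset_version=11)
```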
To Reproduce (REQUIRED)
Steps to reproduce the behavior:
1. After `model = model.to(device)`, add the line `model = model.half()`.
2. Run `torch.onnx.export(...)`. The error will occur at `onnx_model = onnx.load(export_onnx_name)`.
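(Put together, a self-contained version of these steps might look like the following; the model builder, input shape, and list-of-tensors input packaging follow the detection-model export convention and are assumptions on my part.)

```python
import onnx
import torch
from yolort.models import yolov5s  # model choice is an assumption

device = torch.device("cuda")
model = yolov5s(pretrained=False).eval().to(device)
model = model.half()  # the added line that triggers the bug

# Detection models take a list of CHW tensors.
images = [torch.randn(3, 640, 640, dtype=torch.float16, device=device)]

export_onnx_name = "yolov5s_fp16.onnx"
torch.onnx.export(model, (images,), export_onnx_name, opset_version=11)

onnx_model = onnx.load(export_onnx_name)      # error reported here
onnx.checker.check_model(onnx_model)
```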
Relevant warnings on export appear to be: …
The error on loading the ONNX model is: …
Expected behavior
Successful execution of the tutorial notebook when the model is converted to half precision.
Environment
[pip3] numpy==1.19.2
[pip3] pytorch-lightning==1.3.0rc1
[pip3] torch==1.7.1
[pip3] torchaudio==0.7.0a0+a853dff
[pip3] torchmetrics==0.3.2
[pip3] torchvision==0.8.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py37he8ac12f_0
[conda] mkl_fft 1.3.0 py37h54f3939_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.19.2 py37h54aff64_0
[conda] numpy-base 1.19.2 py37hfa32c7d_0
[conda] pytorch 1.7.1 py3.7_cuda10.2.89_cudnn7.6.5_0 pytorch
[conda] pytorch-lightning 1.3.0rc1 pypi_0 pypi
[conda] torchaudio 0.7.2 py37 pytorch
[conda] torchmetrics 0.3.2 pypi_0 pypi
[conda] torchvision 0.8.2 py37_cu102 pytorch
Additional context
It looks like a PyTorch issue, but I'm not sure how we are using this interpolate function. Perhaps we can find a workaround?
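(One plausible workaround, if the problem really is the fp32 constant introduced by `scale_factor`: do the resize in fp32 and cast back. A sketch, not a confirmed fix.)

```python
import torch
import torch.nn.functional as F

def interpolate_fp16_safe(x, scale_factor=2.0, mode="nearest"):
    # Resize in fp32, then cast back to the input dtype so the exported
    # graph keeps a single dtype at the module boundary.
    return F.interpolate(x.float(), scale_factor=scale_factor, mode=mode).to(x.dtype)
```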