Failure of TensorRT 10.8.0.43 when running Unimatch FP32 to FP16 conversion on Jetson Orin 8GB and NVIDIA RTX 4500 GPUs #4355
Comments
As a workaround, please consider using strongly typed mode - https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/advanced.html#strongly-typed-networks
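For reference, a minimal sketch of building a strongly typed engine with the TensorRT Python API (the ONNX filename is a placeholder, not taken from this issue; trtexec exposes the same mode via --stronglyTyped). In strongly typed mode, precisions follow the model's own tensor types, so an FP16 engine requires an ONNX model whose tensors are already FP16:

```python
# Sketch only: strongly typed engine build with the TensorRT Python API.
# The ONNX filename below is a placeholder, not taken from this issue.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# STRONGLY_TYPED makes layer precisions follow the model's own tensor types
# instead of TensorRT's weakly typed FP16 autocasting.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file("gmflow-scale1_simplified.onnx"):
    raise RuntimeError("Failed to parse the ONNX model")

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
with open("gmflow_strongly_typed.engine", "wb") as f:
    f.write(engine_bytes)
```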
@danielmimimi Can you provide sample inputs for this model? Using random inputs is not always a good measurement of the network's accuracy.
Hello @galagam, I attached some images which should produce nice disparities with the FP32 model. The following code shows how to preprocess them.
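(The preprocessing snippet referenced above is not reproduced in this thread dump. Below is a typical GMFlow/Unimatch-style preprocessing sketch; the input names, resolution and ImageNet normalization are assumptions, not the author's original code.)

```python
# Hypothetical preprocessing sketch - resolution, normalization constants and
# input tensor names are assumptions, not the author's original code.
import cv2
import numpy as np

def preprocess(path, height=480, width=640):
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (width, height)).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = (img - mean) / std
    return np.ascontiguousarray(img.transpose(2, 0, 1)[None])  # 1x3xHxW

left = preprocess("left.png")
right = preprocess("right.png")
feed_dict = {"img0": left, "img1": right}  # hypothetical input names
```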
Hey @danielmimimi, First, I would like to clarify that ORT CPU EP (the default provider in Polygraphy) doesn't support FP16 compute. This means that the inputs and initializers are cast down to FP16, but the compute itself is done in FP32, resulting in accuracy that is typically close to the FP32 baseline. To compare TRT accuracy to ORT FP16 accuracy, we should compare both FP16 implementations to the same baseline.
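One way to obtain a genuine FP16 ONNX Runtime reference is to cast the model to FP16 offline and run it on the CUDA EP, which actually computes in FP16. A hedged sketch of that idea follows; the use of onnxconverter-common and the file names are assumptions, not necessarily what was done in this thread:

```python
# Sketch: build a true FP16 ONNX Runtime reference on the CUDA EP.
# File names and the onnxconverter-common dependency are assumptions.
import onnx
import onnxruntime as ort
from onnxconverter_common import float16

model = onnx.load("gmflow-scale1_simplified.onnx")
# keep_io_types=True keeps FP32 inputs/outputs, so the same feed_dict works.
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "gmflow-scale1_fp16.onnx")

sess = ort.InferenceSession("gmflow-scale1_fp16.onnx",
                            providers=["CUDAExecutionProvider"])
outputs = sess.run(None, feed_dict)  # feed_dict from the preprocessing sketch
```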
TRT and ORT accuracy are both not good, and are in the same ballpark. This means that whatever issue we have is inherent to the model/inputs, and is not implementation-specific. Looking at the sample inputs, we get a better picture of the actual accuracy: Trying weakly typed mode with the provided inputs:
This is already considerably better than the error for random inputs. This is a common situation, which is why we always prefer to measure accuracy using typical input data distributions (or real sample inputs, as in this case). Comparing both ORT CUDA EP and TRT against the baseline:
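The exact commands and error metrics for this comparison are not included in this dump. A sketch of the same kind of comparison using Polygraphy's Python API is shown below; the model path, input names and tolerance are placeholders, not the exact setup used here:

```python
# Sketch: compare an ONNX Runtime FP32 baseline against a weakly typed TRT FP16
# engine on real preprocessed inputs. Paths and tolerances are placeholders.
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork,
                                    NetworkFromOnnxPath, TrtRunner)
from polygraphy.comparator import Comparator, CompareFunc

model = "gmflow-scale1_simplified.onnx"
build_engine = EngineFromNetwork(NetworkFromOnnxPath(model),
                                 config=CreateConfig(fp16=True))
runners = [
    OnnxrtRunner(SessionFromOnnx(model)),  # FP32 baseline
    TrtRunner(build_engine),               # weakly typed TRT FP16
]
# Feed real sample inputs instead of Polygraphy's default random data loader.
# left/right come from the preprocessing sketch above; input names are guesses.
results = Comparator.run(runners, data_loader=[{"img0": left, "img1": right}])
Comparator.compare_accuracy(results,
                            compare_func=CompareFunc.simple(atol=1e-3))
```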
Again, TRT is in the same ballpark as ORT. Can we make this even better? Assuming some layers should execute in FP32 to improve accuracy, let’s check the best-case and worst-case scenarios:
See gmflow-scale1_simplified-flex-disable-all.onnx.onnx in the attached zip. Conclusions:
Hey @galagam ! I have a few follow-up questions if you don’t mind:
Thanks again for your time and help.
Description
I have tried to convert the Unimatch FP32 gmflow-scale1 model to an FP16 TensorRT engine. However, the FP16 model is not really usable; this conclusion comes from visually inspecting the results as well as from the Polygraphy tool. For converting the model I used the following commands:
Both share the same warning output:
I then compared the FP32 and FP16 models with Polygraphy:
Applying the last command reveals the following:
The output is completely different. Based on the warning concerning the LayerNorm, I tried to keep it in FP32 using the TensorRT Python API, with the goal of letting the entire attention block run in FP32. However, it yielded the same result.
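For reference, a sketch of how specific layers can be pinned to FP32 with the TensorRT Python API; the name-based filter below is an assumption and has to be adapted to the actual layer names of this network:

```python
# Sketch: build an FP16 engine while forcing selected layers to stay in FP32.
# The "attn"/"LayerNorm" name filter is hypothetical.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)
parser = trt.OnnxParser(network, logger)
parser.parse_from_file("gmflow-scale1_simplified.onnx")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
# Without this flag TensorRT may ignore the per-layer precision requests.
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

for i in range(network.num_layers):
    layer = network.get_layer(i)
    if "attn" in layer.name or "LayerNorm" in layer.name:
        layer.precision = trt.float32
        for j in range(layer.num_outputs):
            layer.set_output_type(j, trt.float32)

engine_bytes = builder.build_serialized_network(network, config)
```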
I observed the same behaviour when converting this network on the Jetson Orin 8GB (TensorRT 8.6 and 10.0). For extended tests I switched to another device.
Environment
I am using the nvcr.io/nvidia/tensorrt:25.01-py3 Docker environment.
TensorRT Version:
10.8.0.43-1
NVIDIA GPU:
NVIDIA RTX 4500
NVIDIA Driver Version:
535.183.01
CUDA Version:
12.8
CUDNN Version:
Operating System:
Ubuntu 24.04.1 LTS
Python Version (if applicable):
3.12.3
Tensorflow Version (if applicable):
No
PyTorch Version (if applicable):
No
Baremetal or Container (if so, version):
nvcr.io/nvidia/tensorrt:25.01-py3
Relevant Files
Model link:
Unimatch gmflow-scale1
Steps To Reproduce
--save-inputs inputs.json --save-outputs outputs_fp32.json
polygraphy run --trt trexec_fp16_model.engine
--load-inputs inputs.json --load-outputs outputs_fp32.json
--atol 0.001 --rtol 0
Commands or scripts:
Have you tried the latest release?:
Yes, the Docker image is quite recent.
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): Yes, but not in FP16 I suppose.