
failure of TensorRT 10.8.0.43 when running Unimatch Fp32 to Fp16 conversion on GPU Jetson Orin 8GB and NVIDIA RTX 4500 #4355

Open
danielmimimi opened this issue Feb 10, 2025 · 5 comments

Description

I have tried to convert the Unimatch FP32 gmflow-scale1 model to an FP16 engine. However, the FP16 model is not really usable; this conclusion comes from visually inspecting the results as well as from the Polygraphy tool. I used the following commands to convert the model:

  • trtexec --onnx=exportedOnnxyModel --saveEngine=trexec_fp16_model.engine --fp16
  • polygraphy run exportedOnnxyModel --fp16 --trt --save-engine polygraphy_fp16_model.engine

Both produce the same warning output:

[02/10/2025-11:52:28] [W] [TRT] Running layernorm after self-attention with FP16 Reduce or Pow may cause overflow. Forcing Reduce or Pow Layers in FP32 precision, or exporting the model to use INormalizationLayer (available with ONNX opset >= 17) can help preserving accuracy.

I then compared the FP32 and FP16 engines with Polygraphy:

polygraphy run --trt trexec_fp32_model.engine \
   --save-inputs inputs.json --save-outputs outputs_fp32.json

polygraphy run --trt trexec_fp16_model.engine \
   --load-inputs inputs.json --load-outputs outputs_fp32.json \
   --atol 0.001 --rtol 0

Applying the last command reveals the following:

[I] Accuracy Comparison | trt-runner-N0-02/10/25-11:56:38 vs. trt-runner-N0-02/10/25-11:46:36
[I]     Comparing Output: '5242' (dtype=float32, shape=(2, 512, 512)) with '5242' (dtype=float32, shape=(2, 512, 512))
[I]         Tolerance: [abs=0.001, rel=0] | Checking elemwise error
[I]         trt-runner-N0-02/10/25-11:56:38: 5242 | Stats: mean=-6.6426, std-dev=3.1498, var=9.9214, median=-6.0293, min=-11.602 at (0, 198, 279), max=-1.5293 at (1, 511, 7), avg-magnitude=6.6426, p90=-2.8984, p95=-2.6699, p99=-2.207
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (-11.6, -10.6) |      42189 | ###########
                (-10.6, -9.51) |     113680 | ##############################
                (-9.51, -8.46) |      94404 | #########################
                (-8.46, -7.41) |      11855 | ###
                (-7.41, -6.37) |         16 | 
                (-6.37, -5.32) |          0 | 
                (-5.32, -4.27) |      31276 | ########
                (-4.27, -3.22) |     149044 | ########################################
                (-3.22, -2.18) |      77238 | ####################
                (-2.18, -1.13) |       4586 | #
[I]         trt-runner-N0-02/10/25-11:46:36: 5242 | Stats: mean=-4.1858, std-dev=1.7002, var=2.8905, median=-3.7579, min=-6.9407 at (1, 16, 55), max=-1.1289 at (0, 7, 151), avg-magnitude=4.1858, p90=-2.0205, p95=-1.7581, p99=-1.4932
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (-11.6, -10.6) |          0 | 
                (-10.6, -9.51) |          0 | 
                (-9.51, -8.46) |          0 | 
                (-8.46, -7.41) |          0 | 
                (-7.41, -6.37) |      42160 | #########
                (-6.37, -5.32) |     152490 | ###################################
                (-5.32, -4.27) |      65357 | ###############
                (-4.27, -3.22) |      26894 | ######
                (-3.22, -2.18) |     171081 | ########################################
                (-2.18, -1.13) |      66306 | ###############
[I]         Error Metrics: 5242
[I]             Minimum Required Tolerance: elemwise error | [abs=9.6093] OR [rel=7.0797] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=4.6614, std-dev=2.5247, var=6.374, median=3.6656, min=1.7329 at (1, 407, 496), max=9.6093 at (0, 1, 280), avg-magnitude=4.6614, p90=7.722, p95=8.1335, p99=8.8531
[I]                 ---- Histogram ----
                    Bin Range    |  Num Elems | Visualization
                    (1.73, 2.52) |     249836 | ########################################
                    (2.52, 3.31) |      12308 | #
                    (3.31, 4.1 ) |          0 | 
                    (4.1 , 4.88) |        148 | 
                    (4.88, 5.67) |       8790 | #
                    (5.67, 6.46) |      46781 | #######
                    (6.46, 7.25) |      89001 | ##############
                    (7.25, 8.03) |      86080 | #############
                    (8.03, 8.82) |      25450 | ####
                    (8.82, 9.61) |       5894 | 
[I]             Relative Difference | Stats: mean=1.6571, std-dev=1.4322, var=2.0513, median=1.0151, min=0.26921 at (1, 7, 272), max=7.0797 at (0, 15, 39), avg-magnitude=1.6571, p90=3.7972, p95=4.2712, p99=5.3135
[I]                 ---- Histogram ----
                    Bin Range     |  Num Elems | Visualization
                    (0.269, 0.95) |     262144 | ########################################
                    (0.95 , 1.63) |        288 | 
                    (1.63 , 2.31) |      81557 | ############
                    (2.31 , 2.99) |      84821 | ############
                    (2.99 , 3.67) |      36171 | #####
                    (3.67 , 4.36) |      36182 | #####
                    (4.36 , 5.04) |      15488 | ##
                    (5.04 , 5.72) |       4978 | 
                    (5.72 , 6.4 ) |       2027 | 
                    (6.4  , 7.08) |        632 | 
[E]         FAILED | Output: '5242' | Difference exceeds tolerance (rel=0, abs=0.001)
[E]     FAILED | Mismatched outputs: ['5242']

The output is completely different. Following the layernorm warning, I tried to keep those layers from being converted using the TensorRT Python API, with the goal of running the entire attention block in FP32. However, this yielded the same result.

    if layer.type in [trt.LayerType.NORMALIZATION, trt.LayerType.REDUCE, trt.LayerType.MATRIX_MULTIPLY,
                      trt.LayerType.SOFTMAX, trt.LayerType.ACTIVATION]:
        layer.precision = trt.float32
        for output_idx in range(layer.num_outputs):
            layer.set_output_type(output_idx, trt.float32)
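
For context, here is a fuller sketch of how such per-layer constraints are typically wired into the build (the network and config objects come from a standard TensorRT build script and are placeholders here; note that layer.precision is only honored when a precision-constraints builder flag is also set):

  import tensorrt as trt

  def force_fp32_layers(network, config):
      """Sketch: pin normalization/attention-related layers to FP32 in an otherwise FP16 build."""
      config.set_flag(trt.BuilderFlag.FP16)
      # Per-layer precision hints are ignored unless a precision-constraints flag is set.
      config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
      fp32_types = (trt.LayerType.NORMALIZATION, trt.LayerType.REDUCE,
                    trt.LayerType.MATRIX_MULTIPLY, trt.LayerType.SOFTMAX,
                    trt.LayerType.ACTIVATION)
      for i in range(network.num_layers):
          layer = network.get_layer(i)
          if layer.type in fp32_types:
              layer.precision = trt.float32
              for output_idx in range(layer.num_outputs):
                  layer.set_output_type(output_idx, trt.float32)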

I observed the same behaviour when converting this network on the Jetson Orin 8GB (TensorRT 8.6 and 10.0). For extended tests I switched to another device.

Environment

I am using the nvcr.io/nvidia/tensorrt:25.01-py3 Docker environment.

TensorRT Version:
10.8.0.43-1
NVIDIA GPU:
NVIDIA RTX 4500
NVIDIA Driver Version:
535.183.01
CUDA Version:
cuda12.8
CUDNN Version:

Operating System:
Ubuntu 24.04.1 LTS
Python Version (if applicable):
3.12.3
Tensorflow Version (if applicable):
No
PyTorch Version (if applicable):
No
Baremetal or Container (if so, version):
nvcr.io/nvidia/tensorrt:25.01-py3

Relevant Files

Model link:
Unimatch gmflow-scale1

Steps To Reproduce

  • trtexec --onnx=gmflow-scale1_simplified.onnx --saveEngine=trexec_fp32_model.engine
  • trtexec --onnx=gmflow-scale1_simplified.onnx --saveEngine=trexec_fp16_model.engine --fp16
  • polygraphy run --trt trexec_fp32_model.engine --save-inputs inputs.json --save-outputs outputs_fp32.json
  • polygraphy run --trt trexec_fp16_model.engine --load-inputs inputs.json --load-outputs outputs_fp32.json --atol 0.001 --rtol 0

Commands or scripts:

Have you tried the latest release?:
Yes, the Docker image is quite new.

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):
Yes, but not in FP16, I suppose.

@danielmimimi danielmimimi changed the title XXX failure of TensorRT X.Y when running XXX on GPU XXX failure of TensorRT 10.8.0.43 when running Unimatch Fp32 to Fp16 conversion on GPU Jetson Orin 8GB and NVIDIA RTX 4500 Feb 10, 2025
@LeoZDong LeoZDong self-assigned this Feb 10, 2025
@LeoZDong LeoZDong added triaged Issue has been triaged by maintainers Module:Polygraphy Issues with Polygraphy Investigating Issue needs further investigation labels Feb 10, 2025
@brnguyen2 brnguyen2 added Module:Accuracy Output mismatch between TensorRT and other frameworks and removed Module:Polygraphy Issues with Polygraphy labels Feb 12, 2025

galagam commented Feb 16, 2025

As a workaround, please consider using strongly typed mode: https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/advanced.html#strongly-typed-networks
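
For example, the same mode can be requested from the command-line tools used above (file names here are placeholders):

  trtexec --onnx=gmflow-scale1_simplified.onnx --saveEngine=strongly_typed.engine --stronglyTyped
  polygraphy run gmflow-scale1_simplified.onnx --trt --strongly-typed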

@LeoZDong LeoZDong added the internal-bug-tracked Tracked internally, will be fixed in a future release. label Feb 18, 2025

galagam commented Feb 18, 2025

@danielmimimi Can you provide sample inputs for this model? Using random inputs is not always a good measurement of the network's accuracy.


danielmimimi commented Feb 25, 2025

Hello @galagam, I attached some images which should produce nice disparities with the FP32 model.

SAMPLE_10.2.zip

The following code shows how to preprocess them.

  import numpy as np
  import onnxruntime as ort
  import matplotlib.pyplot as plt
  from polygraphy.backend.trt import EngineFromBytes, TrtRunner
  from polygraphy.comparator import Comparator, DataLoader
  from polygraphy.common import TensorMetadata
  from PIL import Image

  def create_data_loader(image_left_path, image_right_path):

      height = 304
      width = 512

      # Load images as float32 RGB arrays
      left_image = np.array(Image.open(image_left_path).convert('RGB')).astype(np.float32)
      right_image = np.array(Image.open(image_right_path).convert('RGB')).astype(np.float32)

      imagenet_mean = np.array([0.485, 0.456, 0.406]).astype(np.float32)  # Shape [3]
      imagenet_std = np.array([0.229, 0.224, 0.225]).astype(np.float32)   # Shape [3]

      # HWC -> CHW
      left_image = np.transpose(left_image, (2, 0, 1))   # [C, H, W]
      right_image = np.transpose(right_image, (2, 0, 1))  # [C, H, W]

      # Normalize the images (per-channel)
      left_image = (left_image / 255.0 - imagenet_mean[:, None, None]) / imagenet_std[:, None, None]
      right_image = (right_image / 255.0 - imagenet_mean[:, None, None]) / imagenet_std[:, None, None]

      # Add batch dimension (1, C, H, W)
      left_image = np.expand_dims(left_image, axis=0)
      right_image = np.expand_dims(right_image, axis=0)

      # Create input metadata (update names if your model uses different input names)
      input_metadata = TensorMetadata()
      input_metadata.add("left_image", dtype=np.float32, shape=(1, 3, height, width))
      input_metadata.add("right_image", dtype=np.float32, shape=(1, 3, height, width))

      # Data loader that yields the preprocessed pair once
      class CustomDataLoader(DataLoader):
          def __init__(self, left, right):
              super().__init__()
              self.left = left
              self.right = right

          def __iter__(self):
              yield {"left_image": self.left, "right_image": self.right}

      return CustomDataLoader(left_image, right_image)
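
For completeness, a rough sketch of how this loader can be plugged into the Polygraphy Python API (the engine paths are placeholders, not the exact script used above):

  # Sketch only: engine paths are placeholders.
  from polygraphy.backend.common import BytesFromPath
  from polygraphy.backend.trt import EngineFromBytes, TrtRunner
  from polygraphy.comparator import Comparator

  data_loader = create_data_loader("left.png", "right.png")
  runners = [
      TrtRunner(EngineFromBytes(BytesFromPath("trexec_fp32_model.engine"))),
      TrtRunner(EngineFromBytes(BytesFromPath("trexec_fp16_model.engine"))),
  ]
  results = Comparator.run(runners, data_loader=data_loader)
  Comparator.compare_accuracy(results)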


galagam commented Feb 27, 2025

Hey @danielmimimi,

First, I would like to clarify that the ORT CPU EP (the default provider in Polygraphy) doesn't support FP16 compute. This means that the inputs and initializers are cast down to FP16, but the compute itself is done in FP32, resulting in accuracy that is typically close to the FP32 baseline. To compare TRT accuracy to ORT FP16 accuracy, we should compare both FP16 implementations to the same baseline.
Since the network outputs are relatively small, absolute error is the indicative measure. I’ll present the max abs error, median abs error and percentiles 90, 95, 99.

# Generate reference using CPU EP
$ polygraphy run gmflow-scale1_simplified.onnx --onnxrt --save-inputs rand-inputs.json --save-outputs ort-fp32-ref.json --providers cpu

# Run ORT CUDA EP and compare to reference (make sure onnxruntime-gpu and its dependencies are installed)
$ polygraphy run FP16_gmflow-scale1_simplified.onnx --load-inputs rand-inputs.json --load-outputs ort-fp32-ref.json --onnxrt --providers cuda
# max=3.5303, median=1.1649, p90=2.5107, p95=2.6117, p99=2.8665

# Run TRT and compare to the same reference
$ polygraphy run FP16_gmflow-scale1_simplified.onnx --load-inputs rand-inputs.json --load-outputs ort-fp32-ref.json --trt --fp16
# max=5.1532, median=3.0949, p90=4.6191, p95=4.7907, p99=4.9518

TRT and ORT accuracy are both not good, and are in the same ballpark. This means that whatever issue we have is inherent to the model/inputs, and is not implementation-specific.

Looking at the sample inputs, we get a better picture of the actual accuracy:
The supplied model accepts two (3,512,512) inputs, but the preprocess script generates (3,304,512) inputs. I modified the script to resize the images after loading, using Pillow resize (bicubic interpolation), presumably that’s close enough to the real preprocess. See preprocess-input.py in the attached zip.
In the following experiments I used image_2/inputs.json.
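
For reference, the resize step looks roughly like this (a sketch only; the exact version is preprocess-input.py in the attached zip):

  from PIL import Image
  import numpy as np

  def load_resized(image_path, size=(512, 512)):
      """Load an RGB image and resize it to the model's expected spatial size (bicubic)."""
      image = Image.open(image_path).convert("RGB")
      image = image.resize(size, resample=Image.BICUBIC)
      return np.array(image).astype(np.float32)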

Trying weakly typed mode with the provided inputs:

$ polygraphy run gmflow-scale1_simplified.onnx --trt --onnxrt --fp16 --load-inputs SAMPLE_10.2/image_2/inputs.json
# max=1.8176, median=0.057446, p90=0.18048, p95=0.2379, p99=0.39784

This is already much better than the error for random inputs. This is a common issue, which is why we always prefer to measure accuracy using typical input data distributions (or real sample inputs, in this case).

Comparing both ORT CUDA EP and TRT against the baseline:

# Generate reference using CPU EP
$ polygraphy run fp16_acc_repro/FP16_gmflow-scale1_simplified.onnx --load-inputs SAMPLE_10.2/image_2/inputs.json --onnxrt --providers cpu --save-outputs SAMPLE_10.2/image_2/ort-cpu-outputs.json

# Run ORT CUDA EP and compare to reference
$ polygraphy run fp16_acc_repro/FP16_gmflow-scale1_simplified.onnx --load-inputs SAMPLE_10.2/image_2/inputs.json --onnxrt --providers cuda --load-outputs SAMPLE_10.2/image_2/ort-cpu-outputs.json
# max=0.85938, median=0.046875, p90=0.11719, p95=0.15625, p99=0.30469

# Run TRT and compare to the same reference
$ polygraphy run fp16_acc_repro/FP16_gmflow-scale1_simplified.onnx --load-inputs SAMPLE_10.2/image_2/inputs.json  --trt --strongly-typed --load-outputs SAMPLE_10.2/image_2/ort-cpu-outputs.json
# max=1.4141, median=0.13281, p90=0.28408, p95=0.35742, p99=0.44775

Again, TRT is in the same ballpark as ORT.

Can we make this even better?
As you acknowledged above, running layer normalization in FP16 is problematic. If you are able to re-export this model using opset>=17, the exporter should convert layer normalization directly to ONNX’s LayerNormalization op. This will preserve accuracy for the internal LN computation.
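
For illustration, a rough re-export sketch (the model object, input names, and shapes are placeholders; adapt to the actual Unimatch export script):

  import torch

  def export_opset17(model, onnx_path="gmflow-scale1_opset17.onnx"):
      """Sketch: re-export with opset >= 17 so layer norm maps to ONNX LayerNormalization."""
      model.eval()
      left = torch.randn(1, 3, 512, 512)   # placeholder input shapes
      right = torch.randn(1, 3, 512, 512)
      torch.onnx.export(
          model, (left, right), onnx_path,
          opset_version=17,
          input_names=["left_image", "right_image"],
      )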

Assuming some layers should execute in FP32 to improve accuracy, let’s check the best-case and worst-case scenarios:
Please note that “--strongly-typed” is used here instead of “--fp16”; this means TRT respects the ONNX cast operations.

# Run everything in FP16:
$ polygraphy run fp16_acc_repro/FP16_gmflow-scale1_simplified.onnx --trt --onnxrt --strongly-typed --load-inputs SAMPLE_10.2/image_2/inputs.json
max=1.2188, median 0.07, p90=0.23834, p95=0.28906, p99=0.35156

# Run everything in FP32 (input and output are cast down to FP16): 
$ polygraphy run gmflow-scale1_simplified-flex-disable-all.onnx --trt --onnxrt --strongly-typed --load-inputs SAMPLE_10.2/image_2/inputs.json
max=0.015, median=0, p90=1.5259e-05, p95=0.00024414, p99=0.00097656

See gmflow-scale1_simplified-flex-disable-all.onnx in the attached zip.

Conclusions:

  1. ORT (CUDA EP) and TRT achieve results in the same ballpark for FP16 compute.
  2. The default random inputs used by Polygraphy are not a good fit for this model. Using the typical input distribution is critical for error measurement.
  3. Performing layer norm in FP16 is generally not advisable. The recommended approach would be to use opset>=17.
  4. If accuracy is not sufficient, you should consider forcing additional layers to FP32. The recommended mode for controlling layer precisions is strongly typed mode.
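
For reference, a minimal sketch of building in strongly typed mode with the Python API (paths are placeholders); in this mode the precisions expressed in the ONNX graph, e.g. explicit Cast nodes, are respected:

  import tensorrt as trt

  def build_strongly_typed_engine(onnx_path, engine_path):
      """Sketch: build an engine whose layer precisions follow the ONNX graph's types."""
      logger = trt.Logger(trt.Logger.WARNING)
      builder = trt.Builder(logger)
      network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
      parser = trt.OnnxParser(network, logger)
      with open(onnx_path, "rb") as f:
          if not parser.parse(f.read()):
              raise RuntimeError(parser.get_error(0))
      config = builder.create_builder_config()
      serialized = builder.build_serialized_network(network, config)
      with open(engine_path, "wb") as f:
          f.write(serialized)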

trt_github_4355.zip

@danielmimimi
Author

Hey @galagam!
Thank you so much for your feedback, it's truly appreciated!

I have a few follow-up questions if you don’t mind:

  1. Have you seen cases where a mix of precisions works? For example, if I explicitly set only the precision of normalization layers to FP32, like in the snippet below:

  if "fp16" in self.provider and builder.platform_has_fast_fp16:
      for i in range(network.num_layers):
          layer = network.get_layer(i)
          if layer.type in (trt.LayerType.ANYTHING_WITH_NORMALIZATION,):  # Placeholder flag!
              layer.precision = trt.float32


  2. You mention that "The recommended mode for controlling layer precisions is strongly typed mode." Does this mean I must use the following flag?

config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS) # The recommended mode for controlling layer precisions is strongly typed mode.

  3. Do you know of any specific ONNX/PyTorch versions that reliably support opset >= 17 without running into this issue? If not, I’d probably go with PyTorch 2.1.0 and ONNX 1.14.0—would that be a safe choice?
  4. Do you have any recommended literature or resources explaining the TensorRT FP16 conversion issues in detail?

Thanks again for your time and help.
