ORT_ENABLE_ALL produces incorrect output for QuantizeLinear-DequantizeLinear Resize(linear) pattern #28448

@ALinrunrun

Description

Describe the issue

ONNX Runtime's CPUExecutionProvider produces different results for a QDQ Resize(linear) pattern depending on whether graph optimizations are enabled.

The model applies:

QuantizeLinear -> DequantizeLinear -> Resize(linear, half_pixel) -> QuantizeLinear -> DequantizeLinear

With ORT_DISABLE_ALL, the output matches the expected result:

[2.0, 3.0, 6.0, 6.0, 5.0, 4.0]

With ORT_ENABLE_ALL, ORT returns:

[2.0, 3.0, 5.0, 6.0, 4.0, 4.0]

Graph optimization therefore changes the numerical output of the model.
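For reference, the expected values can be derived by hand: with scale 1.0 and zero-point 0, the QuantizeLinear/DequantizeLinear round trip is round-to-nearest-even plus clipping to [-128, 127], and Resize(linear, half_pixel) samples at src = (i + 0.5) * 0.5 - 0.5 with edge clamping. A minimal NumPy sketch of this float path (the qdq helper below is illustrative, not an ORT API):

```python
import numpy as np

x = np.array([2.0, 7.0, 4.0], dtype=np.float32)
out_len = 6
scale = len(x) / out_len  # 0.5

def qdq(v):
    # int8 QuantizeLinear -> DequantizeLinear with scale 1.0, zero-point 0:
    # round to nearest even, saturate to [-128, 127]
    return np.clip(np.rint(v), -128, 127).astype(np.int8).astype(np.float32)

xd = qdq(x)  # identity here, the inputs are already integral
out = np.empty(out_len, dtype=np.float32)
for i in range(out_len):
    src = (i + 0.5) * scale - 0.5          # half_pixel coordinate transform
    lo = int(np.floor(src))
    frac = src - lo
    a = xd[np.clip(lo, 0, len(x) - 1)]     # clamp neighbors at the edges
    b = xd[np.clip(lo + 1, 0, len(x) - 1)]
    out[i] = a * (1.0 - frac) + b * frac

y = qdq(out)
print(y.tolist())  # [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
```

This matches the ORT_DISABLE_ALL output exactly.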

To reproduce

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

inits = [
    numpy_helper.from_array(np.array([1, 1, 1, 6], dtype=np.int64), "sizes"),
    numpy_helper.from_array(np.float32(1.0), "sc"),
    numpy_helper.from_array(np.int8(0), "zp"),
]

nodes = [
    helper.make_node("QuantizeLinear", ["X", "sc", "zp"], ["Xq"]),
    helper.make_node("DequantizeLinear", ["Xq", "sc", "zp"], ["Xd"]),
    helper.make_node(
        "Resize",
        ["Xd", "", "", "sizes"],
        ["R"],
        mode="linear",
        coordinate_transformation_mode="half_pixel",
    ),
    helper.make_node("QuantizeLinear", ["R", "sc", "zp"], ["Yq"]),
    helper.make_node("DequantizeLinear", ["Yq", "sc", "zp"], ["Y"]),
]

g = helper.make_graph(
    nodes,
    "g",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 1, 3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 1, 6])],
    initializer=inits,
)

m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 19)])
m.ir_version = 9
mb = m.SerializeToString()

x = np.array([[[[2.0, 7.0, 4.0]]]], dtype=np.float32)

def run(opt_level):
    so = ort.SessionOptions()
    so.graph_optimization_level = opt_level
    sess = ort.InferenceSession(
        mb,
        sess_options=so,
        providers=["CPUExecutionProvider"],
    )
    return sess.run(None, {"X": x})[0].flatten()

out_disable = run(ort.GraphOptimizationLevel.ORT_DISABLE_ALL)
out_enable = run(ort.GraphOptimizationLevel.ORT_ENABLE_ALL)

expected = np.array([2.0, 3.0, 6.0, 6.0, 5.0, 4.0], dtype=np.float32)

print("expected:", expected.tolist())
print("ORT_DISABLE_ALL:", out_disable.tolist())
print("ORT_ENABLE_ALL: ", out_enable.tolist())
print("DISABLE matches:", np.array_equal(out_disable, expected))
print("ENABLE matches: ", np.array_equal(out_enable, expected))

Urgency

Expected output

expected: [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
ORT_DISABLE_ALL: [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
ORT_ENABLE_ALL:  [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
DISABLE matches: True
ENABLE matches:  True

Actual output

expected: [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
ORT_DISABLE_ALL: [2.0, 3.0, 6.0, 6.0, 5.0, 4.0]
ORT_ENABLE_ALL:  [2.0, 3.0, 5.0, 6.0, 4.0, 4.0]
DISABLE matches: True
ENABLE matches:  False
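Purely as a numerical observation (not a claim about how the fused kernel is actually implemented): the ORT_ENABLE_ALL values coincide with truncating, rather than rounding, the linear interpolation applied directly to the quantized values:

```python
import numpy as np

xq = [2.0, 7.0, 4.0]  # dequantized input values (scale 1.0, zero-point 0)
truncated = []
for i in range(6):
    src = (i + 0.5) * 0.5 - 0.5               # half_pixel, scale = 3/6
    lo = int(np.floor(src))
    frac = src - lo
    a = xq[min(max(lo, 0), 2)]                # clamp neighbor indices
    b = xq[min(max(lo + 1, 0), 2)]
    # floor instead of round-to-nearest-even
    truncated.append(float(np.floor(a * (1 - frac) + b * frac)))

print(truncated)  # [2.0, 3.0, 5.0, 6.0, 4.0, 4.0], matching ORT_ENABLE_ALL above
```

The interpolated values 3.25, 5.75, 6.25, and 4.75 round to 3, 6, 6, and 5 but truncate to 3, 5, 6, and 4, which accounts for every divergent element.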

Platform

Linux

OS Version

Linux-6.17.0-20-generic-x86_64-with-glibc2.39

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.25.1

ONNX Runtime API

Python

Architecture

X86

Execution Provider

Default CPU

Execution Provider Library Version

No response

Metadata

Labels

quantization (issues related to quantization)
