import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper
# Minimal repro: the ONNX Runtime CPU LayerNormalization kernel is reported to
# return NaN for finite float32 rows that have a large mean and a small
# variance (here the mean is ~4e4 or ~8e4 and the variance is 1.25).
#
# Suspected cause (unconfirmed): a one-pass variance E[x^2] - E[x]^2 evaluated
# in float32. For x ~ 4e4, E[x^2] ~ 1.6e9, where adjacent float32 values are
# 128 apart, so the true variance is lost to cancellation and may even come
# out negative; the square root of a negative number is NaN.

# Shared by the ONNX node and the NumPy reference so the two cannot drift.
EPSILON = 1e-5

node = helper.make_node(
    "LayerNormalization",
    ["x", "s", "b"],  # input, scale, bias
    ["y"],
    axis=-1,  # normalize over the last axis
    epsilon=EPSILON,
)
graph = helper.make_graph(
    [node],
    "g",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]),
        helper.make_tensor_value_info("s", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)
# LayerNormalization is defined in the default ONNX domain from opset 17.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
# IR version pinned explicitly; presumably so the installed runtime accepts
# the model regardless of the onnx package default -- confirm if changed.
model.ir_version = 9
sess = ort.InferenceSession(
    model.SerializeToString(),
    providers=["CPUExecutionProvider"],
)

# Identity scale and zero bias, so the output is the plain normalized input.
s = np.ones(4, dtype=np.float32)
b = np.zeros(4, dtype=np.float32)

for v in (40000.0, 80000.0):
    # Four consecutive values: large mean, true variance exactly 1.25.
    x = np.array([[v, v + 1, v + 2, v + 3]], dtype=np.float32)
    y = sess.run(None, {"x": x, "s": s, "b": b})[0]

    # Ground truth: two-pass LayerNorm in float64, including scale and bias,
    # so the reference itself cannot suffer from the cancellation under test.
    x64 = x.astype(np.float64)
    mean = x64.mean(axis=-1, keepdims=True)
    var = ((x64 - mean) ** 2).mean(axis=-1, keepdims=True)
    ref = ((x64 - mean) / np.sqrt(var + EPSILON) * s + b).astype(np.float32)

    # Diagnostic only: the one-pass variance computed in float32, to show how
    # far it can drift from the true value (`var`) for these inputs.
    x_mean32 = x.mean(axis=-1, keepdims=True)
    naive_var32 = (x * x).mean(axis=-1, keepdims=True) - x_mean32 * x_mean32

    print(f"input base: {v}")
    print("ORT:", y)
    print("ref:", ref)
    print("ORT finite:", np.isfinite(y).all())
    print(
        "true variance:",
        var.ravel(),
        "| one-pass float32 variance:",
        naive_var32.ravel(),
    )
    # allclose treats NaN as unequal, so an all-NaN ORT output reports False.
    print("ORT matches ref:", np.allclose(y, ref, rtol=1e-4, atol=1e-4))
Expected output (every value finite and matching the NumPy reference):
input base: 40000.0
ORT: [[-1.3416355 -0.4472118 0.4472118 1.3416355]]
ref: [[-1.3416355 -0.4472118 0.4472118 1.3416355]]
ORT finite: True
input base: 80000.0
ORT: [[-1.3416355 -0.4472118 0.4472118 1.3416355]]
ref: [[-1.3416355 -0.4472118 0.4472118 1.3416355]]
ORT finite: True
Actual output (CPUExecutionProvider, every element is NaN):
v=40000.0: ORT -> [[nan nan nan nan]]
v=80000.0: ORT -> [[nan nan nan nan]]
Describe the issue
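ONNX Runtime `CPUExecutionProvider` returns `NaN` for `LayerNormalization` on finite `float32` inputs with large values and small variance. For inputs such as
`[40000, 40001, 40002, 40003]`, the mean (40001.5) and variance (1.25) are finite, so the normalized output should also be finite. However, ORT returns all `NaN`. This looks like a numerical stability issue in the variance computation for
`LayerNormalization`. Suppose the variance is computed as E[x²] − E[x]² in `float32`. Then E[x²] ≈ 1.6 × 10⁹, where adjacent `float32` values are 128 apart. The true variance of 1.25 is therefore lost to cancellation and can even come out negative, which makes the square root return `NaN`.
To reproduce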
Urgency
Expected output
Actual output
Platform
Linux
OS Version
Linux-6.17.0-20-generic-x86_64-with-glibc2.39
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.25.1
ONNX Runtime API
Python
Architecture
X86
Execution Provider
Default CPU
Execution Provider Library Version
No response