LayerNormalization returns NaN for finite large float32 inputs #28463

@ALinrunrun

Describe the issue

ONNX Runtime CPUExecutionProvider returns NaN for LayerNormalization on finite float32 inputs with large values and small variance.

For inputs such as [40000, 40001, 40002, 40003], the mean (40001.5) and variance (1.25) are finite, so the normalized output should also be finite. However, ORT returns NaN for every element.

This looks like a numerical stability issue in the variance computation for LayerNormalization.
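One possible explanation, stated as an assumption and not confirmed against the ORT source: if the kernel computes the variance in a single pass as E[x^2] - E[x]^2 in float32, the two terms are around 1.6e9, where adjacent float32 values are 128 apart. The true variance of 1.25 is lost in the subtraction, and the result can come out zero or negative, so sqrt(var + epsilon) may be NaN. The NumPy-only sketch below compares that hypothetical one-pass formula with the two-pass formula used by the reference; the function names are illustrative and do not come from ORT.

```python
import numpy as np

EPS = np.float32(1e-5)

def var_one_pass(x):
    """Variance as E[x^2] - E[x]^2 in float32 (hypothetical kernel formula)."""
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(dtype=np.float32)
    mean_sq = (x * x).mean(dtype=np.float32)
    # Subtracting two nearly equal values of ~1.6e9 cancels almost all digits.
    return mean_sq - mean * mean

def var_two_pass(x):
    """Variance as E[(x - E[x])^2] in float32 (same as the NumPy reference)."""
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(dtype=np.float32)
    d = x - mean  # small deviations, exactly representable here
    return (d * d).mean(dtype=np.float32)

x = np.array([40000.0, 40001.0, 40002.0, 40003.0], dtype=np.float32)
one, two = var_one_pass(x), var_two_pass(x)
print("one-pass variance:", one)
print("two-pass variance:", two)
with np.errstate(invalid="ignore"):
    # NaN here would correspond to the all-NaN output reported for ORT.
    print("one-pass sqrt(var + eps):", np.sqrt(one + EPS))
```

If this is the cause, subtracting the mean before squaring (or using a Welford-style update) in the kernel would avoid the cancellation.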

To reproduce

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node(
    "LayerNormalization",
    ["x", "s", "b"],
    ["y"],
    axis=-1,
    epsilon=1e-5,
)

graph = helper.make_graph(
    [node],
    "g",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]),
        helper.make_tensor_value_info("s", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 9

sess = ort.InferenceSession(
    model.SerializeToString(),
    providers=["CPUExecutionProvider"],
)

s = np.ones(4, dtype=np.float32)
b = np.zeros(4, dtype=np.float32)

for v in (40000.0, 80000.0):
    x = np.array([[v, v + 1, v + 2, v + 3]], dtype=np.float32)
    y = sess.run(None, {"x": x, "s": s, "b": b})[0]

    # Two-pass float32 reference in NumPy: subtract the mean before squaring.
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    ref = (x - mean) / np.sqrt(var + 1e-5)

    print(f"input base: {v}")
    print("ORT:", y)
    print("ref:", ref)
    print("ORT finite:", np.isfinite(y).all())

Expected output

input base: 40000.0
ORT: [[-1.3416355 -0.4472118  0.4472118  1.3416355]]
ref: [[-1.3416355 -0.4472118  0.4472118  1.3416355]]
ORT finite: True

input base: 80000.0
ORT: [[-1.3416355 -0.4472118  0.4472118  1.3416355]]
ref: [[-1.3416355 -0.4472118  0.4472118  1.3416355]]
ORT finite: True

Actual output

v=40000.0: ORT -> [[nan nan nan nan]]
v=80000.0: ORT -> [[nan nan nan nan]]

Platform

Linux

OS Version

Linux-6.17.0-20-generic-x86_64-with-glibc2.39

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.25.1

ONNX Runtime API

Python

Architecture

X86

Execution Provider

Default CPU

Execution Provider Library Version

No response
