Skip to content

Arm backend: Add function to return quant params for lowered graph #12390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions exir/backend/io_quant_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Sequence

import torch.fx as fx
from executorch.exir import EdgeProgramManager
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs


def extract_io_quant_params(
    edge_prog: EdgeProgramManager,
    *,
    input_idxs: Sequence[int] = (0,),
    output_idxs: Sequence[int] = (0,),
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Returns quantization parameters such as scale/zero_point:

    {
        "inputs": {
            <placeholder_name>: {"scale": float, "zero_point": int}
        },
        "outputs": {
            <node_name>: {"scale": float, "zero_point": int}
        }
    }

    Note that this function will strip out the IO quantize/dequantize ops as
    it records their parameters, so if you need to preserve the original graph
    you need to make a copy with copy.deepcopy before.

    Note that `to_edge_transform_and_lower` should be called before.

    Args:
        edge_prog: The program manager to run the IO quantization passes on.
        input_idxs: Indices of the inputs to process; one `QuantizeInputs`
            pass is created per index.
        output_idxs: Indices of the outputs to process; one `QuantizeOutputs`
            pass is created per index.

    Returns:
        A dict with the two sections "inputs" and "outputs", each mapping a
        node name to the parameters found for it in the config methods.

    Raises:
        IndexError: If a config key refers to an index for which no
            placeholder / output node exists in the graph.
    """
    # NOTE(review): reviewer suggestion on this function:
    #   "Perhaps move this to the quantize_io_pass.py?"
    # NOTE(review): reviewer question on this function:
    #   "can't we get these after quantize_io_pass and then the config
    #   methods it adds?"

    # Use IO passes: one pass per requested index, inputs before outputs.
    passes = [QuantizeInputs(edge_prog, [idx]) for idx in input_idxs]
    passes.extend(QuantizeOutputs(edge_prog, [idx]) for idx in output_idxs)

    # Apply them
    edge_prog = edge_prog.transform(passes)

    # The parameters are read back from the manager's config methods; a
    # missing or empty attribute simply yields an empty result.
    cfg = getattr(edge_prog, "_config_methods", {}) or {}

    # We need GraphModule to find node names
    gm = edge_prog.exported_program().graph_module

    # NOTE(review): this assumes config index N corresponds to the N-th
    # placeholder / output node in graph order. If other placeholders (e.g.
    # lifted parameters) precede the user inputs, the mapping may be off --
    # TODO confirm against the passes' own indexing.
    input_names = _gather_io_names(gm, side="input")
    output_names = _gather_io_names(gm, side="output")

    # Build the result dict
    result: Dict[str, Dict[str, Dict[str, Any]]] = {"inputs": {}, "outputs": {}}
    for key, val in cfg.items():
        if key.startswith("input"):
            prefix, section, names = "input", "inputs", input_names
        elif key.startswith("output"):
            prefix, section, names = "output", "outputs", output_names
        else:
            continue

        # Expected key shape is "<prefix><idx>_<param>". Unrelated config
        # methods that merely share the prefix (e.g. "input_size") are
        # skipped instead of crashing the whole extraction.
        idx_str, sep, param = key[len(prefix) :].partition("_")
        if not sep:
            continue
        try:
            idx = int(idx_str)
        except ValueError:
            continue

        name = names[idx]
        # We need to map 'zp' to 'zero_point'
        out_param = "zero_point" if param in ("zp", "zero_point") else param
        result[section].setdefault(name, {})[out_param] = val

    return result


def _gather_io_names(gm: fx.GraphModule, side: str):
    """
    Collect node names for one side of the graph.

    For 'input', returns placeholder names in graph order.
    For 'output', returns the names of every fx.Node found in the output
    node's args, walking nested tuples/lists depth-first, left to right.
    Any other value of `side` raises ValueError.
    """
    if side not in ("input", "output"):
        raise ValueError(f"Unknown side: {side}")

    if side == "input":
        return [node.name for node in gm.graph.nodes if node.op == "placeholder"]

    def _iter_nodes(arg):
        # Yield fx.Node leaves; anything that is neither a Node nor a
        # tuple/list (e.g. None or a constant) is ignored.
        if isinstance(arg, fx.Node):
            yield arg
        elif isinstance(arg, (tuple, list)):
            for item in arg:
                yield from _iter_nodes(item)

    output_node = next(node for node in gm.graph.nodes if node.op == "output")
    return [node.name for node in _iter_nodes(output_node.args)]
93 changes: 93 additions & 0 deletions exir/backend/test/test_io_quant_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.io_quant_params import extract_io_quant_params

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleAdd(torch.nn.Module):
    """Minimal two-input module used as the model under test."""

    def forward(self, x, y):
        """Return the sum of the two inputs."""
        total = x + y
        return total


class TestExtractIOQuantParamsPT2E(unittest.TestCase):
    """Checks extract_io_quant_params on a PT2E-quantized SimpleAdd model."""

    def setUp(self):
        """Quantize SimpleAdd with the XNNPACK quantizer and lower it."""
        ones = torch.ones(1, 5)
        twos = torch.full((1, 5), 2.0)
        self.example_inputs = (ones, twos)
        self.mod = SimpleAdd().eval()

        # Setup XNNPACK quantizer for example
        self.quantizer = XNNPACKQuantizer()
        self.quantizer.set_global(get_symmetric_quantization_config())

        # The first export gets its own copy of the example inputs.
        training_program = torch.export.export_for_training(
            self.mod,
            copy.deepcopy(self.example_inputs),
            strict=True,
        )
        observed = prepare_pt2e(training_program.module(), self.quantizer)

        # Call observers to calibrate
        observed(*self.example_inputs)

        quantized = convert_pt2e(observed)

        # Export again with quant parameters
        quantized_program = torch.export.export_for_training(
            quantized,
            self.example_inputs,
            strict=True,
        )

        # Lower to EdgeProgramManager
        self.edge_prog = to_edge_transform_and_lower(quantized_program)

    def _assert_qparam_entry(self, name, params):
        """Assert `name` is a str with a float 'scale' and int 'zero_point'."""
        self.assertIsInstance(name, str)
        self.assertIsInstance(params["scale"], float)
        self.assertIsInstance(params["zero_point"], int)

    def test_roundtrip_extracts_io_params(self):
        """Both inputs and the single output must report scale/zero_point."""
        # Get dict with quant parameters
        q = extract_io_quant_params(
            self.edge_prog,
            input_idxs=(0, 1),
            output_idxs=(0,),
        )

        # Validate structure
        for section in ("inputs", "outputs"):
            self.assertIn(section, q)
        self.assertEqual(len(q["inputs"]), 2)
        self.assertEqual(len(q["outputs"]), 1)

        # Each entry must have a float 'scale' and int 'zero_point'
        for in_name, in_params in q["inputs"].items():
            self._assert_qparam_entry(in_name, in_params)

        out_name, out_params = next(iter(q["outputs"].items()))
        self._assert_qparam_entry(out_name, out_params)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
Loading