Remove unused parameters before serializing #9659

Open
GregoryComer opened this issue Mar 26, 2025 · 0 comments
Labels
module: runtime (Issues related to the core runtime and code under runtime/)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


GregoryComer commented Mar 26, 2025

🐛 Describe the bug

Now that torch.export is non-strict by default, unused parameters are left in the graph. As a result, the original unquantized weights get serialized alongside their quantized copies, bloating PTE size by 5x or more. We should strip out unused parameters somewhere in to_edge or to_executorch.
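
One possible shape for such a pass, sketched below. This is a hypothetical helper, not an existing ExecuTorch API; it assumes in-place mutation of the exported program is acceptable, and the graph-signature/state_dict internals it touches may differ across PyTorch versions:

from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind

def strip_unused_parameters(ep: ExportedProgram) -> None:
    # Hypothetical sketch: drop parameter placeholders that have no users,
    # along with their graph-signature entries and state_dict tensors.
    graph = ep.graph_module.graph
    specs_by_name = {spec.arg.name: spec for spec in ep.graph_signature.input_specs}
    for node in list(graph.nodes):
        spec = specs_by_name.get(node.name)
        if (
            node.op == "placeholder"
            and not node.users
            and spec is not None
            and spec.kind == InputKind.PARAMETER
        ):
            graph.erase_node(node)
            ep.graph_signature.input_specs.remove(spec)
            ep.state_dict.pop(spec.target, None)  # assumes state_dict is a mutable mapping
    ep.graph_module.recompile()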

As a repro (requiring the latest PyTorch):

import torch

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export, export_for_training
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 1024)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1024, 16)
        self.relu2 = torch.nn.ReLU()
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        return x

model = SimpleModel()
inputs = (torch.randn(1, 16),)

pre_autograd_aten_dialect = export_for_training(
    model,
    inputs,
).module()

quantizer = XNNPACKQuantizer()
#qparams = get_symmetric_quantization_config(is_dynamic=True, is_per_channel=True)
qparams = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(qparams)

prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
prepared_graph.to("cpu")

converted_graph = convert_pt2e(prepared_graph)
DuplicateDynamicQuantChainPass()(converted_graph)

ep = export(converted_graph, inputs, strict=False)
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()]
)

When printing the lowered program, note the extra unused f32 weights. You can also observe the PTE size is much larger than expected. Specifically, p_linear1_weight and p_linear2_weight are the original (unquantized) f32 weights and are unused. There is a u8 copy of the weights which is consumed by the delegate as expected.
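
For example, the serialized size can be checked like this (a quick sketch using the lowered manager from the repro above; the exact byte count will vary):

et_program = lowered.to_executorch()
print(f"PTE size: {len(et_program.buffer)} bytes")  # inflated by the unused f32 weights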

print(lowered.exported_program())
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_linear1_weight: "f32[1024, 16]", p_linear2_weight: "f32[16, 1024]", x: "f32[1, 16]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None
            getitem: "f32[1, 16]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear1_weight'), target='linear1.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear2_weight'), target='linear2.weight', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
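
One way to confirm which parameters are dangling is to list placeholder nodes with no users (an ad hoc check, not an official API; reuses the lowered manager from above):

ep = lowered.exported_program()
unused = [
    node.name
    for node in ep.graph_module.graph.nodes
    if node.op == "placeholder" and not node.users
]
print(unused)  # e.g. ['p_linear1_weight', 'p_linear2_weight']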

Versions

All

cc @larryliu0820 @JacobSzwejbka
