diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 19a68f9b108..fa53c1818f5 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -33,6 +33,7 @@
     CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
+    qconfig_A8W8sym,
 )
 from executorch.exir.pass_base import NodeMetadata
 from parameterized import parameterized
@@ -53,7 +54,6 @@
 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
     CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
@@ -137,6 +137,61 @@
         # For add: both inputs are activations
         [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
     ),
+    # CadenceDefaultQuantizer test cases
+    (
+        "default_matmul_A8W8",
+        lambda self: self._build_matmul_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A8W8.output_activation,
+        # For matmul: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_linear_A8W8",
+        lambda self: self._build_linear_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A8W8.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
+    (
+        "default_conv1d_A8W8sym",
+        lambda self: self._build_conv1d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv1d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv1d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_conv2d_A8W8sym",
+        lambda self: self._build_conv2d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv2d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv2d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_bmm_A8W8",
+        lambda self: self._build_bmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.bmm.default,
+        qconfig_A8W8.output_activation,
+        # For bmm: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_relu_A8W8",
+        lambda self: self._build_relu_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8.output_activation,
+        # For relu: only input_activation
+        [qconfig_A8W8.input_activation],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -309,6 +364,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
         return gm, add_nodes[0]
 
+    def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a bmm (batch matrix multiply) operation."""
+        builder = GraphBuilder()
+        # BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
+        x = builder.placeholder("x", torch.randn(2, 4, 8))
+        y = builder.placeholder("y", torch.randn(2, 8, 4))
+        bmm = builder.call_operator(
+            op=torch.ops.aten.bmm.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
+            ),
+        )
+        builder.output([bmm])
+        gm = builder.get_graph_module()
+
+        bmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.bmm.default,
+        )
+        self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node")
+        return gm, bmm_nodes[0]
+
+    def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a relu operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(x,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+        return gm, relu_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,