@@ -65,7 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.relu(a)


-example_args = (torch.randn(1, 3, 256, 256),)
+example_args: tuple[torch.Tensor] = (torch.randn(1, 3, 256, 256),)
 aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)
 print(aten_dialect)

@@ -100,8 +100,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y


-example_args = (torch.randn(3, 3), torch.randn(3, 3))
-aten_dialect: ExportedProgram = export(Basic(), example_args, strict=True)
+example_args_2: tuple[torch.Tensor, torch.Tensor] = (
+    torch.randn(3, 3),
+    torch.randn(3, 3),
+)
+aten_dialect = export(Basic(), example_args_2, strict=True)

 # Works correctly
 print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
@@ -118,20 +121,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

 from torch.export import Dim

-
-class Basic(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return x + y
-
-
-example_args = (torch.randn(3, 3), torch.randn(3, 3))
+example_args_2 = (torch.randn(3, 3), torch.randn(3, 3))
 dim1_x = Dim("dim1_x", min=1, max=10)
 dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
-aten_dialect: ExportedProgram = export(
-    Basic(), example_args, dynamic_shapes=dynamic_shapes, strict=True
+aten_dialect = export(
+    Basic(), example_args_2, dynamic_shapes=dynamic_shapes, strict=True
 )
 print(aten_dialect)

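With dim1_x bounded between 1 and 10, the exported program should accept inputs whose second dimension varies anywhere in that range. A minimal sketch of exercising it after this hunk (the concrete sizes below are illustrative, not part of the patch):

    # Both inputs share the same symbolic dimension, so their shapes must still match.
    print(aten_dialect.module()(torch.randn(3, 7), torch.randn(3, 7)))
    # A second dimension outside [1, 10] would be rejected by the exported program's shape guards.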
@@ -207,13 +201,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 )

 quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
-prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
+prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)  # type: ignore[arg-type]
 # calibrate with a sample dataset
 converted_graph = convert_pt2e(prepared_graph)
 print("Quantized Graph")
 print(converted_graph)

-aten_dialect: ExportedProgram = export(converted_graph, example_args, strict=True)
+aten_dialect = export(converted_graph, example_args, strict=True)
 print("ATen Dialect Graph")
 print(aten_dialect)

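The "# calibrate with a sample dataset" placeholder in this hunk stands for running representative inputs through the prepared module so the observers inserted by prepare_pt2e can record activation ranges before convert_pt2e. A minimal sketch, reusing example_args as a stand-in for a real calibration set:

    # Hypothetical calibration loop; in practice, iterate over a representative dataset.
    for _ in range(10):
        prepared_graph(*example_args)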
@@ -243,7 +237,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 from executorch.exir import EdgeProgramManager, to_edge

 example_args = (torch.randn(1, 3, 256, 256),)
-aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)
+aten_dialect = export(SimpleConv(), example_args, strict=True)

 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 print("Edge Dialect Graph")
@@ -272,9 +266,7 @@ def forward(self, x):
 decode_args = (torch.randn(1, 5),)
 aten_decode: ExportedProgram = export(Decode(), decode_args, strict=True)

-edge_program: EdgeProgramManager = to_edge(
-    {"encode": aten_encode, "decode": aten_decode}
-)
+edge_program = to_edge({"encode": aten_encode, "decode": aten_decode})
 for method in edge_program.methods:
     print(f"Edge Dialect graph of {method}")
     print(edge_program.exported_program(method))
@@ -291,8 +283,8 @@ def forward(self, x):
 # rather than the ``torch.ops.aten`` namespace.

 example_args = (torch.randn(1, 3, 256, 256),)
-aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)
-edge_program: EdgeProgramManager = to_edge(aten_dialect)
+aten_dialect = export(SimpleConv(), example_args, strict=True)
+edge_program = to_edge(aten_dialect)
 print("Edge Dialect Graph")
 print(edge_program.exported_program())

@@ -357,8 +349,8 @@ def forward(self, x):

 # Export and lower the module to Edge Dialect
 example_args = (torch.ones(1),)
-aten_dialect: ExportedProgram = export(LowerableModule(), example_args, strict=True)
-edge_program: EdgeProgramManager = to_edge(aten_dialect)
+aten_dialect = export(LowerableModule(), example_args, strict=True)
+edge_program = to_edge(aten_dialect)
 to_be_lowered_module = edge_program.exported_program()

 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
@@ -369,7 +361,7 @@ def forward(self, x):
 )

 # Lower the module
-lowered_module: LoweredBackendModule = to_backend(
+lowered_module: LoweredBackendModule = to_backend(  # type: ignore[call-arg]
     "BackendWithCompilerDemo", to_be_lowered_module, []
 )
 print(lowered_module)
@@ -423,8 +415,8 @@ def forward(self, x):


 example_args = (torch.ones(1),)
-aten_dialect: ExportedProgram = export(ComposedModule(), example_args, strict=True)
-edge_program: EdgeProgramManager = to_edge(aten_dialect)
+aten_dialect = export(ComposedModule(), example_args, strict=True)
+edge_program = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
 print(exported_program)
@@ -460,16 +452,16 @@ def forward(self, a, x, b):
         return z


-example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-aten_dialect: ExportedProgram = export(Foo(), example_args, strict=True)
-edge_program: EdgeProgramManager = to_edge(aten_dialect)
+example_args_3 = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+aten_dialect = export(Foo(), example_args_3, strict=True)
+edge_program = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
 print(exported_program)

 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

-delegated_program = to_backend(exported_program, AddMulPartitionerDemo())
+delegated_program = to_backend(exported_program, AddMulPartitionerDemo())  # type: ignore[call-arg]
 print("Delegated program")
 print(delegated_program)
 print(delegated_program.graph_module.lowered_module_0.original_module)
@@ -484,19 +476,9 @@ def forward(self, a, x, b):
 # call ``to_backend`` on it:


-class Foo(torch.nn.Module):
-    def forward(self, a, x, b):
-        y = torch.mm(a, x)
-        z = y + b
-        a = z - a
-        y = torch.mm(a, x)
-        z = y + b
-        return z
-
-
-example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-aten_dialect: ExportedProgram = export(Foo(), example_args, strict=True)
-edge_program: EdgeProgramManager = to_edge(aten_dialect)
+example_args_3 = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+aten_dialect = export(Foo(), example_args_3, strict=True)
+edge_program = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 delegated_program = edge_program.to_backend(AddMulPartitionerDemo())

@@ -530,7 +512,6 @@ def forward(self, a, x, b):
 print("ExecuTorch Dialect")
 print(executorch_program.exported_program())

-import executorch.exir as exir

 ######################################################################
 # Notice that in the graph we now see operators like ``torch.ops.aten.sub.out``
@@ -577,13 +558,11 @@ def forward(self, x):
 pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
 # Optionally do quantization:
 # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
-aten_dialect: ExportedProgram = export(
-    pre_autograd_aten_dialect, example_args, strict=True
-)
-edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
+aten_dialect = export(pre_autograd_aten_dialect, example_args, strict=True)
+edge_program = to_edge(aten_dialect)
 # Optionally do delegation:
 # edge_program = edge_program.to_backend(CustomBackendPartitioner)
-executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
+executorch_program = edge_program.to_executorch(
     ExecutorchBackendConfig(
         passes=[],  # User-defined passes
     )
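Downstream of this hunk, the resulting ExecutorchProgramManager is typically serialized so the ExecuTorch runtime can load it. A minimal sketch using the manager's buffer property (the file name is illustrative, not part of the patch):

    # Write the serialized program to a .pte file for the ExecuTorch runtime.
    with open("model.pte", "wb") as f:
        f.write(executorch_program.buffer)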