@@ -126,9 +126,7 @@ def return_wrapper():
126
126
trace_inputs_method = "get_upper_bound_inputs"
127
127
get_trace_inputs = get_inputs_adapter (
128
128
(
129
- # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
130
- # `Union[Module, Tensor]`.
131
- getattr (eager_module , trace_inputs_method )
129
+ getattr (eager_module , trace_inputs_method ) # type: ignore[arg-type]
132
130
if hasattr (eager_module , trace_inputs_method )
133
131
else eager_module .get_random_inputs
134
132
),
@@ -144,18 +142,14 @@ def return_wrapper():
144
142
if hasattr (eager_module , "get_dynamic_shapes" ):
145
143
assert capture_config is not None
146
144
assert capture_config .enable_aot is True
147
- # pyre-fixme[29]: `Union[nn.modules.module.Module,
148
- # torch._tensor.Tensor]` is not a function.
149
- trace_dynamic_shapes = eager_module .get_dynamic_shapes ()
145
+ trace_dynamic_shapes = eager_module .get_dynamic_shapes () # type: ignore[operator]
150
146
method_name_to_dynamic_shapes = {}
151
147
for method in methods :
152
148
method_name_to_dynamic_shapes [method ] = trace_dynamic_shapes
153
149
154
150
memory_planning_pass = MemoryPlanningPass ()
155
151
if hasattr (eager_module , "get_memory_planning_pass" ):
156
- # pyre-fixme[29]: `Union[nn.modules.module.Module,
157
- # torch._tensor.Tensor]` is not a function.
158
- memory_planning_pass = eager_module .get_memory_planning_pass ()
152
+ memory_planning_pass = eager_module .get_memory_planning_pass () # type: ignore[operator]
159
153
160
154
class WrapperModule (nn .Module ):
161
155
def __init__ (self , method ):
@@ -172,7 +166,7 @@ def __init__(self, method):
172
166
assert method_name == "forward"
173
167
ep = _export (
174
168
eager_module ,
175
- method_input ,
169
+ method_input , # type: ignore[arg-type]
176
170
dynamic_shapes = (
177
171
method_name_to_dynamic_shapes [method_name ]
178
172
if method_name_to_dynamic_shapes
@@ -184,7 +178,7 @@ def __init__(self, method):
184
178
else :
185
179
exported_methods [method_name ] = export (
186
180
eager_module ,
187
- method_input ,
181
+ method_input , # type: ignore[arg-type]
188
182
dynamic_shapes = (
189
183
method_name_to_dynamic_shapes [method_name ]
190
184
if method_name_to_dynamic_shapes
@@ -220,9 +214,7 @@ def __init__(self, method):
220
214
221
215
# Get a function that creates random inputs appropriate for testing.
222
216
get_random_inputs_fn = get_inputs_adapter (
223
- # pyre-fixme[6]: For 1st argument expected `(...) -> Any` but got
224
- # `Union[Module, Tensor]`.
225
- eager_module .get_random_inputs ,
217
+ eager_module .get_random_inputs , # type: ignore[arg-type]
226
218
# all exported methods must have the same signature so just pick the first one.
227
219
methods [0 ],
228
220
)
0 commit comments