@@ -200,7 +200,6 @@ def refit_from_dict(self, refit_weights, is_fp16):
200
200
trt_datatype = trt .DataType .HALF
201
201
202
202
# trt.Weight and trt.TensorLocation
203
- refit_weights [trt_weight_name ] = refit_weights [trt_weight_name ].cpu ()
204
203
trt_wt_tensor = trt .Weights (
205
204
trt_datatype ,
206
205
refit_weights [trt_weight_name ].data_ptr (),
@@ -213,15 +212,16 @@ def refit_from_dict(self, refit_weights, is_fp16):
213
212
)
214
213
215
214
# apply refit
216
- # refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location)
217
- refitter .set_named_weights (trt_weight_name , trt_wt_tensor )
215
+ refitter .set_named_weights (trt_weight_name , trt_wt_tensor , trt_wt_location )
218
216
refitted_weights .add (trt_weight_name )
219
217
220
218
assert set (refitted_weights ) == set (refit_weights .keys ())
221
219
if not refitter .refit_cuda_engine ():
222
220
print ("Error: failed to refit new weights." )
223
221
exit (0 )
224
222
223
+ print (f"[I] Total refitted weights { len (refitted_weights )} ." )
224
+
225
225
def build (
226
226
self ,
227
227
onnx_path ,
@@ -240,14 +240,18 @@ def build(
240
240
for _p , i_profile in zip (p , input_profile ):
241
241
for name , dims in i_profile .items ():
242
242
assert len (dims ) == 3
243
- _p .add (name , min = dims [0 ], opt = dims [1 ], max = dims [2 ])
243
+ _p .add (name , min = dims [0 ], opt = dims [1 ], max = dims [2 ])
244
244
245
245
config_kwargs = {}
246
246
if not enable_all_tactics :
247
247
config_kwargs ["tactic_sources" ] = []
248
248
249
249
network = network_from_onnx_path (
250
- onnx_path , flags = [trt .OnnxParserFlag .NATIVE_INSTANCENORM ]
250
+ onnx_path ,
251
+ flags = [
252
+ trt .OnnxParserFlag .NATIVE_INSTANCENORM ,
253
+ trt .NetworkDefinitionCreationFlag .STRONGLY_TYPED ,
254
+ ],
251
255
)
252
256
if update_output_names :
253
257
print (f"Updating network outputs to { update_output_names } " )
@@ -257,7 +261,6 @@ def build(
257
261
config = builder .create_builder_config ()
258
262
config .progress_monitor = TQDMProgressMonitor ()
259
263
260
- config .set_flag (trt .BuilderFlag .STRICT_TYPES )
261
264
config .set_flag (trt .BuilderFlag .FP16 ) if fp16 else None
262
265
config .set_flag (trt .BuilderFlag .REFIT ) if enable_refit else None
263
266
@@ -305,53 +308,52 @@ def load(self):
305
308
print (f"Loading TensorRT engine: { self .engine_path } " )
306
309
self .engine = engine_from_bytes (bytes_from_path (self .engine_path ))
307
310
308
- def activate (self , reuse_device_memory = None ):
311
+ def activate (self , reuse_device_memory = False ):
309
312
if reuse_device_memory :
310
313
self .context = self .engine .create_execution_context_without_device_memory ()
311
- # self.context.device_memory = reuse_device_memory
312
314
else :
313
315
self .context = self .engine .create_execution_context ()
314
316
315
317
def allocate_buffers (self , shape_dict = None , device = "cuda" , additional_shapes = None ):
316
318
nvtx .range_push ("allocate_buffers" )
317
- for idx in range (self .engine .num_io_tensors ):
318
- binding = self .engine [idx ]
319
- if shape_dict and binding in shape_dict :
320
- shape = shape_dict [binding ].shape
321
- elif additional_shapes and binding in additional_shapes :
322
- shape = additional_shapes [binding ]
319
+ for binding in range (self .engine .num_io_tensors ):
320
+ name = self .engine .get_tensor_name (binding )
321
+
322
+ if shape_dict and name in shape_dict :
323
+ shape = shape_dict [name ].shape
324
+ elif additional_shapes and name in additional_shapes :
325
+ shape = additional_shapes [name ]
323
326
else :
324
- shape = self .context .get_binding_shape (idx )
325
- dtype = trt .nptype (self .engine .get_binding_dtype (binding ))
326
- if self .engine .binding_is_input (binding ):
327
- self .context .set_binding_shape (idx , shape )
327
+ shape = self .context .get_tensor_shape (name )
328
+
329
+ dtype = trt .nptype (self .engine .get_tensor_dtype (name ))
330
+ if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .INPUT :
331
+ self .context .set_input_shape (name , shape )
328
332
tensor = torch .zeros (
329
333
tuple (shape ), dtype = numpy_to_torch_dtype_dict [dtype ]
330
334
).to (device = device )
331
- self .tensors [binding ] = tensor
335
+ self .tensors [name ] = tensor
332
336
nvtx .range_pop ()
333
337
334
338
def infer (self , feed_dict , stream , use_cuda_graph = False ):
335
- nvtx . range_push ( "set_tensors" )
339
+
336
340
for name , buf in feed_dict .items ():
337
341
self .tensors [name ].copy_ (buf )
338
342
339
343
for name , tensor in self .tensors .items ():
340
344
self .context .set_tensor_address (name , tensor .data_ptr ())
341
- nvtx .range_pop ()
342
- nvtx .range_push ("execute" )
345
+
343
346
noerror = self .context .execute_async_v3 (stream )
344
347
if not noerror :
345
- raise ValueError ("ERROR: inference failed." )
346
- nvtx . range_pop ()
348
+ raise ValueError ("ERROR: inference failed." )
349
+
347
350
return self .tensors
348
351
349
352
def __str__ (self ):
350
353
out = ""
351
354
for opt_profile in range (self .engine .num_optimization_profiles ):
352
- for binding_idx in range (self .engine .num_bindings ):
353
- name = self .engine .get_binding_name ( binding_idx )
355
+ for binding in range (self .engine .num_io_tensors ):
356
+ name = self .engine .get_tensor_name ( binding )
354
357
shape = self .engine .get_profile_shape (opt_profile , name )
355
358
out += f"\t { name } = { shape } \n "
356
359
return out
357
-
0 commit comments