Commit e56da2e

Transition to trt 10
1 parent b33e611 commit e56da2e

3 files changed: +33 -45 lines changed

install.py

Lines changed: 3 additions & 17 deletions

@@ -2,15 +2,15 @@
 import sys
 
 python = sys.executable
-
+TRT_VERSION="10.0.0b6"
 
 def install():
     if not launch.is_installed("importlib_metadata"):
         launch.run_pip("install importlib_metadata", "importlib_metadata", live=True)
     from importlib_metadata import version
 
     if launch.is_installed("tensorrt"):
-        if not version("tensorrt") == "9.2.0.post12.dev5":
+        if not version("tensorrt") == TRT_VERSION:
             launch.run(
                 ["python", "-m", "pip", "uninstall", "-y", "tensorrt"],
                 "removing old version of tensorrt",
@@ -19,24 +19,10 @@ def install():
     if not launch.is_installed("tensorrt"):
         print("TensorRT is not installed! Installing...")
         launch.run_pip(
-            "install nvidia-cudnn-cu11==8.9.4.25 --no-cache-dir", "nvidia-cudnn-cu11"
-        )
-        launch.run_pip(
-            "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.2.0.post12.dev5 --no-cache-dir",
+            f"install --extra-index-url https://pypi.nvidia.com tensorrt=={TRT_VERSION} --no-cache-dir",
             "tensorrt",
             live=True,
         )
-        launch.run(
-            ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"],
-            "removing nvidia-cudnn-cu11",
-        )
-
-    if launch.is_installed("nvidia-cudnn-cu11"):
-        if version("nvidia-cudnn-cu11") == "8.9.4.25":
-            launch.run(
-                ["python", "-m", "pip", "uninstall", "-y", "nvidia-cudnn-cu11"],
-                "removing nvidia-cudnn-cu11",
-            )
 
     # Polygraphy
     if not launch.is_installed("polygraphy"):
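With this change the installer pins everything to the single TRT_VERSION constant and pulls the TensorRT 10 wheel directly from NVIDIA's package index; the install/uninstall juggling around nvidia-cudnn-cu11 and the --pre flag needed for the 9.2 dev wheel are gone, presumably because the TensorRT 10 wheels manage their cuDNN dependency themselves. The pip command that launch.run_pip assembles here is equivalent to:

    python -m pip install --extra-index-url https://pypi.nvidia.com tensorrt==10.0.0b6 --no-cache-dir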

model_manager.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def __init__(self, model_file=MODEL_FILE) -> None:
         self.update()
 
     @staticmethod
-    def get_onnx_path(model_name, is_contolnet: bool = False):
+    def get_onnx_path(model_name, is_contolnet: bool = True):
         if is_contolnet:
             model_name = f"{model_name}_cnet"
         onnx_filename = f"{model_name}.onnx"
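The only change here flips the default of is_contolnet (the misspelling is in the source) from False to True, so callers that omit the flag now get the ControlNet-suffixed ONNX filename. A minimal sketch of the effect, with the body reduced to the lines shown in the diff (the return line is hypothetical; the real static method goes on to build a full path):

    def get_onnx_path(model_name, is_contolnet: bool = True):
        if is_contolnet:
            model_name = f"{model_name}_cnet"
        return f"{model_name}.onnx"  # hypothetical return for illustration

    get_onnx_path("v1-5")         # "v1-5_cnet.onnx" -- new default
    get_onnx_path("v1-5", False)  # "v1-5.onnx"      -- the old default behavior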

utilities.py

Lines changed: 29 additions & 27 deletions

@@ -200,7 +200,6 @@ def refit_from_dict(self, refit_weights, is_fp16):
                 trt_datatype = trt.DataType.HALF
 
             # trt.Weight and trt.TensorLocation
-            refit_weights[trt_weight_name] = refit_weights[trt_weight_name].cpu()
             trt_wt_tensor = trt.Weights(
                 trt_datatype,
                 refit_weights[trt_weight_name].data_ptr(),
@@ -213,15 +212,16 @@ def refit_from_dict(self, refit_weights, is_fp16):
             )
 
             # apply refit
-            # refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location)
-            refitter.set_named_weights(trt_weight_name, trt_wt_tensor)
+            refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location)
             refitted_weights.add(trt_weight_name)
 
         assert set(refitted_weights) == set(refit_weights.keys())
         if not refitter.refit_cuda_engine():
             print("Error: failed to refit new weights.")
             exit(0)
 
+        print(f"[I] Total refitted weights {len(refitted_weights)}.")
+
     def build(
         self,
         onnx_path,
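Two TensorRT 10 changes drive the refit hunks: weights no longer have to be staged on the CPU before refitting (the .cpu() copy is deleted), and the three-argument set_named_weights overload taking a trt.TensorLocation, previously left commented out, is now used so the refitter can read weights straight from device memory. A minimal sketch of that pattern, assuming refit_weights maps TensorRT weight names to torch tensors already on the GPU:

    import tensorrt as trt
    import torch

    def refit_engine(engine: trt.ICudaEngine, refit_weights: dict) -> None:
        refitter = trt.Refitter(engine, trt.Logger(trt.Logger.ERROR))
        for name, tensor in refit_weights.items():
            dtype = trt.DataType.HALF if tensor.dtype == torch.float16 else trt.DataType.FLOAT
            weights = trt.Weights(dtype, tensor.data_ptr(), torch.numel(tensor))
            # TensorRT 10: passing the location lets device-resident weights refit without a CPU copy.
            refitter.set_named_weights(name, weights, trt.TensorLocation.DEVICE)
        if not refitter.refit_cuda_engine():
            raise RuntimeError("refit failed")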
@@ -240,14 +240,18 @@ def build(
         for _p, i_profile in zip(p, input_profile):
             for name, dims in i_profile.items():
                 assert len(dims) == 3
-                _p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+                _p.add(name, min=dims[0], opt=dims[1], max=dims[2])
 
         config_kwargs = {}
         if not enable_all_tactics:
             config_kwargs["tactic_sources"] = []
 
         network = network_from_onnx_path(
-            onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
+            onnx_path,
+            flags=[
+                trt.OnnxParserFlag.NATIVE_INSTANCENORM,
+                trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED,
+            ],
         )
         if update_output_names:
             print(f"Updating network outputs to {update_output_names}")
@@ -257,7 +261,6 @@ def build(
         config = builder.create_builder_config()
         config.progress_monitor = TQDMProgressMonitor()
 
-        config.set_flag(trt.BuilderFlag.STRICT_TYPES)
         config.set_flag(trt.BuilderFlag.FP16) if fp16 else None
         config.set_flag(trt.BuilderFlag.REFIT) if enable_refit else None
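The build-side edits pair up: the ONNX network is now parsed as strongly typed (trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) and the old trt.BuilderFlag.STRICT_TYPES flag, which TensorRT 10 removes, is dropped; in a strongly typed network the tensor dtypes recorded in the ONNX graph are binding, so precision is no longer steered through builder flags. For reference, a sketch of roughly what polygraphy's network_from_onnx_path sets up under these flags, written against the raw TensorRT API:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    # Strongly typed: dtypes come from the model itself, not from builder precision flags.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )
    parser = trt.OnnxParser(network, logger)
    parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)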

@@ -305,53 +308,52 @@ def load(self):
         print(f"Loading TensorRT engine: {self.engine_path}")
         self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
 
-    def activate(self, reuse_device_memory=None):
+    def activate(self, reuse_device_memory=False):
         if reuse_device_memory:
             self.context = self.engine.create_execution_context_without_device_memory()
-            # self.context.device_memory = reuse_device_memory
         else:
             self.context = self.engine.create_execution_context()
 
     def allocate_buffers(self, shape_dict=None, device="cuda", additional_shapes=None):
         nvtx.range_push("allocate_buffers")
-        for idx in range(self.engine.num_io_tensors):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding].shape
-            elif additional_shapes and binding in additional_shapes:
-                shape = additional_shapes[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name].shape
+            elif additional_shapes and name in additional_shapes:
+                shape = additional_shapes[name]
             else:
-                shape = self.context.get_binding_shape(idx)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.context.get_tensor_shape(name)
+
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
             tensor = torch.zeros(
                 tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
             ).to(device=device)
-            self.tensors[binding] = tensor
+            self.tensors[name] = tensor
         nvtx.range_pop()
 
     def infer(self, feed_dict, stream, use_cuda_graph=False):
-        nvtx.range_push("set_tensors")
+
         for name, buf in feed_dict.items():
             self.tensors[name].copy_(buf)
 
         for name, tensor in self.tensors.items():
             self.context.set_tensor_address(name, tensor.data_ptr())
-        nvtx.range_pop()
-        nvtx.range_push("execute")
+
         noerror = self.context.execute_async_v3(stream)
         if not noerror:
-            raise ValueError("ERROR: inference failed.")
-        nvtx.range_pop()
+            raise ValueError(f"ERROR: inference failed.")
+
         return self.tensors
 
     def __str__(self):
         out = ""
         for opt_profile in range(self.engine.num_optimization_profiles):
-            for binding_idx in range(self.engine.num_bindings):
-                name = self.engine.get_binding_name(binding_idx)
+            for binding in range(self.engine.num_io_tensors):
+                name = self.engine.get_tensor_name(binding)
                 shape = self.engine.get_profile_shape(opt_profile, name)
                 out += f"\t{name} = {shape}\n"
         return out
-
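Everything else in utilities.py is the mechanical move from the binding-index API, removed in TensorRT 10, to the name-based I/O tensor API: get_binding_shape / set_binding_shape / binding_is_input / get_binding_dtype become get_tensor_shape / set_input_shape / get_tensor_mode / get_tensor_dtype, buffers are keyed by tensor name, and execution goes through set_tensor_address plus execute_async_v3. A self-contained sketch of the same enumeration pattern, assuming a deserialized engine, its execution context, a feed dict of CUDA tensors, a torch.cuda.Stream, and float32 I/O (the real code maps engine.get_tensor_dtype(name) to the matching torch dtype):

    import tensorrt as trt
    import torch

    def run(engine, context, feed, stream):
        tensors = {}
        for i in range(engine.num_io_tensors):      # iterate by index...
            name = engine.get_tensor_name(i)        # ...but address every tensor by name
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                context.set_input_shape(name, feed[name].shape)
                tensors[name] = feed[name]
            else:
                # valid once the input shapes above are set; inputs precede outputs here
                shape = tuple(context.get_tensor_shape(name))
                tensors[name] = torch.empty(shape, device="cuda")
            context.set_tensor_address(name, tensors[name].data_ptr())
        if not context.execute_async_v3(stream.cuda_stream):
            raise ValueError("ERROR: inference failed.")
        return tensors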
