Skip to content

Commit b054fbc

Browse files
committed
Fixed bugs after rebase
1 parent 6f3142b commit b054fbc

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def convert_module(
131131
from torch_tensorrt.logging import TRT_LOGGER
132132

133133
runtime = trt.Runtime(TRT_LOGGER)
134-
refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine)
134+
refit_test_engine = runtime.deserialize_cuda_engine(
135+
interpreter_result.serialized_engine
136+
)
135137
weight_name_map: Any = None
136138
# Do the test refit with cached map if make_refitable is enabled
137139
if settings.make_refitable:
@@ -169,5 +171,5 @@ def convert_module(
169171
output_binding_names=list(interpreter_result.output_names),
170172
name=name,
171173
settings=settings,
172-
weight_name_map = weight_name_map
174+
weight_name_map=weight_name_map,
173175
)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from contextlib import nullcontext
55
from typing import Any, Dict, List, Optional, Sequence, Tuple
66

7+
import tensorrt as trt
78
import torch
89
import torch_tensorrt
910
from torch.nn import Module
@@ -18,8 +19,6 @@
1819
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
1920
from torch_tensorrt.logging import TRT_LOGGER
2021

21-
import tensorrt as trt
22-
2322
logger = logging.getLogger(__name__)
2423

2524

@@ -104,7 +103,6 @@ def __init__(
104103
self.settings = settings
105104
self.engine = None
106105
self.weight_name_map = weight_name_map
107-
self._initialize()
108106

109107
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
110108
self.setup_engine()

tests/py/dynamo/models/test_model_refit.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_fast_refit_one_engine():
108108
new_trt_gm = refit_module_weights(
109109
compiled_module=trt_gm,
110110
new_weight_module=exp_program2,
111-
inputs=inputs,
111+
arg_inputs=inputs,
112112
use_weight_map_cache=True,
113113
)
114114

@@ -155,7 +155,7 @@ def test_fast_refit_one_engine_no_map():
155155
new_trt_gm = refit_module_weights(
156156
compiled_module=trt_gm,
157157
new_weight_module=exp_program2,
158-
inputs=inputs,
158+
arg_inputs=inputs,
159159
use_weight_map_cache=True,
160160
)
161161

@@ -206,7 +206,7 @@ def test_fast_refit_one_engine_wrong_map():
206206
new_trt_gm = refit_module_weights(
207207
compiled_module=trt_gm,
208208
new_weight_module=exp_program2,
209-
inputs=inputs,
209+
arg_inputs=inputs,
210210
use_weight_map_cache=True,
211211
)
212212

@@ -253,7 +253,7 @@ def test_fast_refit_one_engine_bert():
253253
new_trt_gm = refit_module_weights(
254254
compiled_module=trt_gm,
255255
new_weight_module=exp_program2,
256-
inputs=inputs,
256+
arg_inputs=inputs,
257257
use_weight_map_cache=True,
258258
)
259259

@@ -303,7 +303,7 @@ def test_fast_refit_one_engine_inline_runtime():
303303
new_trt_gm = refit_module_weights(
304304
compiled_module=trt_gm,
305305
new_weight_module=exp_program2,
306-
inputs=inputs,
306+
arg_inputs=inputs,
307307
use_weight_map_cache=True,
308308
)
309309

@@ -348,7 +348,7 @@ def test_fast_refit_one_engine_python_runtime():
348348
new_trt_gm = refit_module_weights(
349349
compiled_module=trt_gm,
350350
new_weight_module=exp_program2,
351-
inputs=inputs,
351+
arg_inputs=inputs,
352352
use_weight_map_cache=True,
353353
)
354354

@@ -415,7 +415,7 @@ def forward(self, x):
415415
new_trt_gm = refit_module_weights(
416416
compiled_module=trt_gm,
417417
new_weight_module=exp_program2,
418-
inputs=inputs,
418+
arg_inputs=inputs,
419419
use_weight_map_cache=True,
420420
)
421421

@@ -460,7 +460,7 @@ def test_refit_one_engine():
460460
new_trt_gm = refit_module_weights(
461461
compiled_module=trt_gm,
462462
new_weight_module=exp_program2,
463-
inputs=inputs,
463+
arg_inputs=inputs,
464464
use_weight_map_cache=False,
465465
)
466466

@@ -507,7 +507,7 @@ def test_refit_one_engine_bert():
507507
new_trt_gm = refit_module_weights(
508508
compiled_module=trt_gm,
509509
new_weight_module=exp_program2,
510-
inputs=inputs,
510+
arg_inputs=inputs,
511511
use_weight_map_cache=False,
512512
)
513513

@@ -557,7 +557,7 @@ def test_refit_one_engine_inline_runtime():
557557
new_trt_gm = refit_module_weights(
558558
compiled_module=trt_gm,
559559
new_weight_module=exp_program2,
560-
inputs=inputs,
560+
arg_inputs=inputs,
561561
use_weight_map_cache=False,
562562
)
563563

@@ -602,7 +602,7 @@ def test_refit_one_engine_python_runtime():
602602
new_trt_gm = refit_module_weights(
603603
compiled_module=trt_gm,
604604
new_weight_module=exp_program2,
605-
inputs=inputs,
605+
arg_inputs=inputs,
606606
use_weight_map_cache=False,
607607
)
608608

@@ -669,7 +669,7 @@ def forward(self, x):
669669
new_trt_gm = refit_module_weights(
670670
compiled_module=trt_gm,
671671
new_weight_module=exp_program2,
672-
inputs=inputs,
672+
arg_inputs=inputs,
673673
use_weight_map_cache=False,
674674
)
675675

0 commit comments

Comments
 (0)