Skip to content

Commit 6f3142b

Browse files
committed
Fixed a bug of regular engine compilation
1 parent b51b0f4 commit 6f3142b

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,20 @@ def convert_module(
132132

133133
runtime = trt.Runtime(TRT_LOGGER)
134134
refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine)
135-
weight_name_map: Any = interpreter_result.weight_name_map
136-
try:
137-
_refit_single_trt_engine_with_gm(
138-
new_gm=module,
139-
old_engine=refit_test_engine,
140-
input_list=inputs,
141-
settings=settings,
142-
weight_name_map=weight_name_map,
143-
)
144-
except AssertionError:
145-
logger.warning("Fast refit test failed. Removing the weight map caching.")
146-
weight_name_map = None
135+
weight_name_map: Any = None
136+
# Do the test refit with cached map if make_refitable is enabled
137+
if settings.make_refitable:
138+
weight_name_map = interpreter_result.weight_name_map
139+
try:
140+
_refit_single_trt_engine_with_gm(
141+
new_gm=module,
142+
old_engine=refit_test_engine,
143+
input_list=inputs,
144+
settings=settings,
145+
weight_name_map=interpreter_result.weight_name_map,
146+
)
147+
except AssertionError:
148+
logger.warning("Fast refit test failed. Removing the weight map caching.")
147149

148150
rt_cls = PythonTorchTensorRTModule
149151

0 commit comments

Comments
 (0)