File tree 1 file changed +14
-12
lines changed
py/torch_tensorrt/dynamo/conversion
1 file changed +14
-12
lines changed Original file line number Diff line number Diff line change @@ -132,18 +132,20 @@ def convert_module(
132
132
133
133
runtime = trt .Runtime (TRT_LOGGER )
134
134
refit_test_engine = runtime .deserialize_cuda_engine (interpreter_result .engine )
135
- weight_name_map : Any = interpreter_result .weight_name_map
136
- try :
137
- _refit_single_trt_engine_with_gm (
138
- new_gm = module ,
139
- old_engine = refit_test_engine ,
140
- input_list = inputs ,
141
- settings = settings ,
142
- weight_name_map = weight_name_map ,
143
- )
144
- except AssertionError :
145
- logger .warning ("Fast refit test failed. Removing the weight map caching." )
146
- weight_name_map = None
135
+ weight_name_map : Any = None
136
+ # Do the test refit with cached map if make_refitable is enabled
137
+ if settings .make_refitable :
138
+ weight_name_map = interpreter_result .weight_name_map
139
+ try :
140
+ _refit_single_trt_engine_with_gm (
141
+ new_gm = module ,
142
+ old_engine = refit_test_engine ,
143
+ input_list = inputs ,
144
+ settings = settings ,
145
+ weight_name_map = interpreter_result .weight_name_map ,
146
+ )
147
+ except AssertionError :
148
+ logger .warning ("Fast refit test failed. Removing the weight map caching." )
147
149
148
150
rt_cls = PythonTorchTensorRTModule
149
151
You can’t perform that action at this time.
0 commit comments