@@ -108,7 +108,7 @@ def test_fast_refit_one_engine():
108
108
new_trt_gm = refit_module_weights (
109
109
compiled_module = trt_gm ,
110
110
new_weight_module = exp_program2 ,
111
- inputs = inputs ,
111
+ arg_inputs = inputs ,
112
112
use_weight_map_cache = True ,
113
113
)
114
114
@@ -155,7 +155,7 @@ def test_fast_refit_one_engine_no_map():
155
155
new_trt_gm = refit_module_weights (
156
156
compiled_module = trt_gm ,
157
157
new_weight_module = exp_program2 ,
158
- inputs = inputs ,
158
+ arg_inputs = inputs ,
159
159
use_weight_map_cache = True ,
160
160
)
161
161
@@ -206,7 +206,7 @@ def test_fast_refit_one_engine_wrong_map():
206
206
new_trt_gm = refit_module_weights (
207
207
compiled_module = trt_gm ,
208
208
new_weight_module = exp_program2 ,
209
- inputs = inputs ,
209
+ arg_inputs = inputs ,
210
210
use_weight_map_cache = True ,
211
211
)
212
212
@@ -253,7 +253,7 @@ def test_fast_refit_one_engine_bert():
253
253
new_trt_gm = refit_module_weights (
254
254
compiled_module = trt_gm ,
255
255
new_weight_module = exp_program2 ,
256
- inputs = inputs ,
256
+ arg_inputs = inputs ,
257
257
use_weight_map_cache = True ,
258
258
)
259
259
@@ -303,7 +303,7 @@ def test_fast_refit_one_engine_inline_runtime():
303
303
new_trt_gm = refit_module_weights (
304
304
compiled_module = trt_gm ,
305
305
new_weight_module = exp_program2 ,
306
- inputs = inputs ,
306
+ arg_inputs = inputs ,
307
307
use_weight_map_cache = True ,
308
308
)
309
309
@@ -348,7 +348,7 @@ def test_fast_refit_one_engine_python_runtime():
348
348
new_trt_gm = refit_module_weights (
349
349
compiled_module = trt_gm ,
350
350
new_weight_module = exp_program2 ,
351
- inputs = inputs ,
351
+ arg_inputs = inputs ,
352
352
use_weight_map_cache = True ,
353
353
)
354
354
@@ -415,7 +415,7 @@ def forward(self, x):
415
415
new_trt_gm = refit_module_weights (
416
416
compiled_module = trt_gm ,
417
417
new_weight_module = exp_program2 ,
418
- inputs = inputs ,
418
+ arg_inputs = inputs ,
419
419
use_weight_map_cache = True ,
420
420
)
421
421
@@ -460,7 +460,7 @@ def test_refit_one_engine():
460
460
new_trt_gm = refit_module_weights (
461
461
compiled_module = trt_gm ,
462
462
new_weight_module = exp_program2 ,
463
- inputs = inputs ,
463
+ arg_inputs = inputs ,
464
464
use_weight_map_cache = False ,
465
465
)
466
466
@@ -507,7 +507,7 @@ def test_refit_one_engine_bert():
507
507
new_trt_gm = refit_module_weights (
508
508
compiled_module = trt_gm ,
509
509
new_weight_module = exp_program2 ,
510
- inputs = inputs ,
510
+ arg_inputs = inputs ,
511
511
use_weight_map_cache = False ,
512
512
)
513
513
@@ -557,7 +557,7 @@ def test_refit_one_engine_inline_runtime():
557
557
new_trt_gm = refit_module_weights (
558
558
compiled_module = trt_gm ,
559
559
new_weight_module = exp_program2 ,
560
- inputs = inputs ,
560
+ arg_inputs = inputs ,
561
561
use_weight_map_cache = False ,
562
562
)
563
563
@@ -602,7 +602,7 @@ def test_refit_one_engine_python_runtime():
602
602
new_trt_gm = refit_module_weights (
603
603
compiled_module = trt_gm ,
604
604
new_weight_module = exp_program2 ,
605
- inputs = inputs ,
605
+ arg_inputs = inputs ,
606
606
use_weight_map_cache = False ,
607
607
)
608
608
@@ -669,7 +669,7 @@ def forward(self, x):
669
669
new_trt_gm = refit_module_weights (
670
670
compiled_module = trt_gm ,
671
671
new_weight_module = exp_program2 ,
672
- inputs = inputs ,
672
+ arg_inputs = inputs ,
673
673
use_weight_map_cache = False ,
674
674
)
675
675
0 commit comments