@@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
75
75
logger .info ("No onnxruntime provider specified, using CPUExecutionProvider" )
76
76
provider = "CPUExecutionProvider"
77
77
78
+ if provider_options is None :
79
+ provider_options = []
80
+ elif not isinstance (provider_options , list ):
81
+ provider_options = [provider_options ]
82
+
78
83
return ort .InferenceSession (
79
84
path , providers = [provider ], sess_options = sess_options , provider_options = provider_options
80
85
)
@@ -174,7 +179,10 @@ def _from_pretrained(
174
179
# load model from local directory
175
180
if os .path .isdir (model_id ):
176
181
model = OnnxRuntimeModel .load_model (
177
- Path (model_id , model_file_name ).as_posix (), provider = provider , sess_options = sess_options
182
+ Path (model_id , model_file_name ).as_posix (),
183
+ provider = provider ,
184
+ sess_options = sess_options ,
185
+ provider_options = kwargs .pop ("provider_options" ),
178
186
)
179
187
kwargs ["model_save_dir" ] = Path (model_id )
180
188
# load model from hub
@@ -190,7 +198,12 @@ def _from_pretrained(
190
198
)
191
199
kwargs ["model_save_dir" ] = Path (model_cache_path ).parent
192
200
kwargs ["latest_model_name" ] = Path (model_cache_path ).name
193
- model = OnnxRuntimeModel .load_model (model_cache_path , provider = provider , sess_options = sess_options )
201
+ model = OnnxRuntimeModel .load_model (
202
+ model_cache_path ,
203
+ provider = provider ,
204
+ sess_options = sess_options ,
205
+ provider_options = kwargs .pop ("provider_options" ),
206
+ )
194
207
return cls (model = model , ** kwargs )
195
208
196
209
@classmethod
0 commit comments