Skip to content

Commit 4a9ab65

Browse files
urpetkov-amdUros Petkovicsayakpaulyiyixuxugithub-actions[bot]
authored
Fixing missing provider options argument (#11397)
* Fixing missing provider options argument * Adding if else for provider options * Apply suggestions from code review Co-authored-by: YiYi Xu <[email protected]> * Apply style fixes * Update src/diffusers/pipelines/onnx_utils.py Co-authored-by: YiYi Xu <[email protected]> * Update src/diffusers/pipelines/onnx_utils.py Co-authored-by: YiYi Xu <[email protected]> --------- Co-authored-by: Uros Petkovic <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0ac1d5b commit 4a9ab65

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/diffusers/pipelines/onnx_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
7575
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
7676
provider = "CPUExecutionProvider"
7777

78+
if provider_options is None:
79+
provider_options = []
80+
elif not isinstance(provider_options, list):
81+
provider_options = [provider_options]
82+
7883
return ort.InferenceSession(
7984
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
8085
)
@@ -174,7 +179,10 @@ def _from_pretrained(
174179
# load model from local directory
175180
if os.path.isdir(model_id):
176181
model = OnnxRuntimeModel.load_model(
177-
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
182+
Path(model_id, model_file_name).as_posix(),
183+
provider=provider,
184+
sess_options=sess_options,
185+
provider_options=kwargs.pop("provider_options"),
178186
)
179187
kwargs["model_save_dir"] = Path(model_id)
180188
# load model from hub
@@ -190,7 +198,12 @@ def _from_pretrained(
190198
)
191199
kwargs["model_save_dir"] = Path(model_cache_path).parent
192200
kwargs["latest_model_name"] = Path(model_cache_path).name
193-
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
201+
model = OnnxRuntimeModel.load_model(
202+
model_cache_path,
203+
provider=provider,
204+
sess_options=sess_options,
205+
provider_options=kwargs.pop("provider_options"),
206+
)
194207
return cls(model=model, **kwargs)
195208

196209
@classmethod

0 commit comments

Comments
 (0)