Skip to content

Commit bf92e74

Browse files
authored
fix StableDiffusionTensorRT super args error (#6009)
1 parent b785a15 commit bf92e74

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

examples/community/stable_diffusion_tensorrt_img2img.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
save_engine,
4242
)
4343
from polygraphy.backend.trt import util as trt_util
44-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
44+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
4545

4646
from diffusers.models import AutoencoderKL, UNet2DConditionModel
4747
from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ def __init__(
709709
scheduler: DDIMScheduler,
710710
safety_checker: StableDiffusionSafetyChecker,
711711
feature_extractor: CLIPFeatureExtractor,
712+
image_encoder: CLIPVisionModelWithProjection = None,
712713
requires_safety_checker: bool = True,
713714
stages=["clip", "unet", "vae", "vae_encoder"],
714715
image_height: int = 512,
@@ -724,7 +725,15 @@ def __init__(
724725
timing_cache: str = "timing_cache",
725726
):
726727
super().__init__(
727-
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
728+
vae,
729+
text_encoder,
730+
tokenizer,
731+
unet,
732+
scheduler,
733+
safety_checker=safety_checker,
734+
feature_extractor=feature_extractor,
735+
image_encoder=image_encoder,
736+
requires_safety_checker=requires_safety_checker,
728737
)
729738

730739
self.vae.forward = self.vae.decode

examples/community/stable_diffusion_tensorrt_inpaint.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
save_engine,
4242
)
4343
from polygraphy.backend.trt import util as trt_util
44-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
44+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
4545

4646
from diffusers.models import AutoencoderKL, UNet2DConditionModel
4747
from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ def __init__(
710710
scheduler: DDIMScheduler,
711711
safety_checker: StableDiffusionSafetyChecker,
712712
feature_extractor: CLIPFeatureExtractor,
713+
image_encoder: CLIPVisionModelWithProjection = None,
713714
requires_safety_checker: bool = True,
714715
stages=["clip", "unet", "vae", "vae_encoder"],
715716
image_height: int = 512,
@@ -725,7 +726,15 @@ def __init__(
725726
timing_cache: str = "timing_cache",
726727
):
727728
super().__init__(
728-
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
729+
vae,
730+
text_encoder,
731+
tokenizer,
732+
unet,
733+
scheduler,
734+
safety_checker=safety_checker,
735+
feature_extractor=feature_extractor,
736+
image_encoder=image_encoder,
737+
requires_safety_checker=requires_safety_checker,
729738
)
730739

731740
self.vae.forward = self.vae.decode

examples/community/stable_diffusion_tensorrt_txt2img.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
save_engine,
4141
)
4242
from polygraphy.backend.trt import util as trt_util
43-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
43+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
4444

4545
from diffusers.models import AutoencoderKL, UNet2DConditionModel
4646
from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ def __init__(
624624
scheduler: DDIMScheduler,
625625
safety_checker: StableDiffusionSafetyChecker,
626626
feature_extractor: CLIPFeatureExtractor,
627+
image_encoder: CLIPVisionModelWithProjection = None,
627628
requires_safety_checker: bool = True,
628629
stages=["clip", "unet", "vae"],
629630
image_height: int = 768,
@@ -639,7 +640,15 @@ def __init__(
639640
timing_cache: str = "timing_cache",
640641
):
641642
super().__init__(
642-
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
643+
vae,
644+
text_encoder,
645+
tokenizer,
646+
unet,
647+
scheduler,
648+
safety_checker=safety_checker,
649+
feature_extractor=feature_extractor,
650+
image_encoder=image_encoder,
651+
requires_safety_checker=requires_safety_checker,
643652
)
644653

645654
self.vae.forward = self.vae.decode

0 commit comments

Comments
 (0)