@@ -242,7 +242,7 @@ def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents,
    def set_text_mode(self):
        self.mode = "text"

-    def set_img_mode(self):
+    def set_image_mode(self):
        self.mode = "img"

    def set_text_to_image_mode(self):
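A hedged usage sketch of the renamed setter: callers that previously used `set_img_mode()` would switch to `set_image_mode()`. The pipeline variable `pipe` below is illustrative, not taken from this diff.

```python
# Assuming `pipe` is an instance of the pipeline class these setters belong to:
pipe.set_image_mode()        # unconditional ("marginal") image generation
assert pipe.mode == "img"    # the internal mode string is unchanged by the rename

pipe.set_text_mode()         # unconditional text generation
assert pipe.mode == "text"
```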
@@ -276,7 +276,8 @@ def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples):
            batch_size = num_samples
        return batch_size

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # self.tokenizer => self.clip_tokenizer
    def _encode_prompt(
        self,
        prompt,
@@ -319,25 +320,25 @@ def _encode_prompt(
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
-            text_inputs = self.tokenizer(
+            text_inputs = self.clip_tokenizer(
                prompt,
                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
+                max_length=self.clip_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                removed_text = self.clip_tokenizer.batch_decode(
+                    untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                    f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
                )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -380,7 +381,7 @@ def _encode_prompt(
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
+            uncond_input = self.clip_tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
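For reference, a standalone sketch of the truncation check this method performs with the renamed `clip_tokenizer`. The checkpoint name below is an assumption chosen only to make the sketch runnable; the logic mirrors the hunks above rather than adding anything new.

```python
from transformers import CLIPTokenizer

# Assumed checkpoint for illustration; any CLIP tokenizer behaves the same way.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a watercolor painting of " + "a very " * 100 + "long scene"
text_inputs = clip_tokenizer(
    prompt,
    padding="max_length",
    max_length=clip_tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors="pt",
)
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# Anything past position model_max_length - 1 is what the pipeline's warning reports as removed.
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, clip_tokenizer.model_max_length - 1 : -1])
print(removed_text)
```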
@@ -480,24 +481,21 @@ def encode_image_clip_latents(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

-        image = image.to(device=device, dtype=dtype)
+        preprocessed_image = self.image_processor.preprocess(
+            image,
+            do_center_crop=True,
+            crop_size=resolution,
+            return_tensors="pt",
+        )
+        preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)

        if isinstance(generator, list):
            image_latents = [
-                self.image_encoder(
-                    **self.image_processor.preprocess(
-                        image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt"
-                    )
-                )
-                for i in range(batch_size)
+                self.image_encoder(**preprocessed_image[i : i + 1]).pooler_output for i in range(batch_size)
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
-            # TODO: figure out self.image_processor.preprocess kwargs
-            inputs = self.image_processor.preprocess(
-                image, do_center_crop=True, crop_size=resolution, return_tensors="pt"
-            )
-            image_latents = self.image_encoder(**inputs)
+            image_latents = self.image_encoder(**preprocessed_image).pooler_output

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand image_latents for batch_size
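A minimal standalone sketch of the refactored encoding path: preprocess once, move the batch to the target device/dtype, then take the CLIP vision encoder's `pooler_output`. The processor/model checkpoints and the blank placeholder image are assumptions for illustration only.

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Assumed checkpoints, chosen only to make the sketch runnable.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (512, 512))  # stand-in for a real conditioning image

# Preprocess once (resize/center-crop to the model resolution), then encode the whole batch.
preprocessed = image_processor.preprocess(image, return_tensors="pt")
preprocessed = preprocessed.to(device="cpu", dtype=torch.float32)

with torch.no_grad():
    image_latents = image_encoder(**preprocessed).pooler_output  # (batch_size, hidden_size)
print(image_latents.shape)
```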
@@ -659,7 +657,7 @@ def get_noise_pred(
        prompt_embeds,
        img_vae,
        img_clip,
-        timesteps,
+        max_timestep,
        guidance_scale,
        generator,
        device,
@@ -689,17 +687,15 @@ def get_noise_pred(
            img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
            img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
            text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps[0]
-            t_text_uncond = torch.ones_like(t) * timesteps[0]

            # print(f"t_img_uncond: {t_img_uncond}")
            # print(f"t_img_uncond shape: {t_img_uncond.shape}")

            # print("Running unconditional U-Net call 1 for CFG...")
-            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t)
+            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
            # print("Running unconditional U-Net call 2 for CFG...")
            img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
            )

            x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
@@ -708,10 +704,9 @@ def get_noise_pred(
        elif mode == "text2img":
            # Text-conditioned image generation
            img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.zeros(t.size(0), dtype=torch.int, device=device)

            img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=0
            )

            img_out = self._combine(img_vae_out, img_clip_out)
@@ -721,48 +716,41 @@ def get_noise_pred(

            # Classifier-free guidance
            text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_text_uncond = torch.ones_like(t) * timesteps

            img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
            )

            img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)

            return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
        elif mode == "img2text":
            # Image-conditioned text generation
-            t_img = torch.zeros(t.size(0), dtype=torch.int, device=device)
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=0, t_text=t)

            if guidance_scale <= 1.0:
                return text_out

            # Classifier-free guidance
            img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
            img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps

            img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps
+                img_vae_T, img_clip_T, latents, t_img=max_timestep, t_text=t
            )

            return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
        elif mode == "text":
            # Unconditional ("marginal") text generation (no CFG)
-            t_img = torch.ones_like(t) * timesteps
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=max_timestep, t_text=t)

            return text_out
        elif mode == "img":
            # Unconditional ("marginal") image generation (no CFG)
            img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.ones_like(t) * timesteps

            img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=max_timestep
            )

            img_out = self._combine(img_vae_out, img_clip_out)
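The pattern repeated across these branches: the unconditional prediction is obtained by feeding noise for the conditioning modality at `max_timestep`, and the conditional and unconditional outputs are blended with the same formula in every mode. A minimal sketch of just that blend follows; the tensor shapes are illustrative placeholders, not taken from this diff.

```python
import torch

def apply_guidance(cond_out: torch.Tensor, uncond_out: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Equivalent to uncond_out + guidance_scale * (cond_out - uncond_out),
    # i.e. the classifier-free guidance blend used in each branch above.
    return guidance_scale * cond_out + (1.0 - guidance_scale) * uncond_out

# Illustrative usage with dummy predictions.
cond = torch.randn(1, 4, 64, 64)
uncond = torch.randn(1, 4, 64, 64)
guided = apply_guidance(cond, uncond, guidance_scale=8.0)
```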
@@ -980,7 +968,7 @@ def __call__(
            assert image is not None
            # Encode image using VAE
            image_vae = preprocess(image)
-            height, width = image.shape[-2:]
+            height, width = image_vae.shape[-2:]
            image_vae_latents = self.encode_image_vae_latents(
                image_vae,
                batch_size,
@@ -1001,6 +989,8 @@ def __call__(
                device,
                generator,
            )
+            # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+            image_clip_latents = image_clip_latents.unsqueeze(1)
        else:
            # 4.2. Prepare image latent variables, if input not available
            # Prepare image VAE latents
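A quick shape check of what the added `unsqueeze(1)` does; the sizes below are illustrative placeholders.

```python
import torch

batch_size, clip_hidden_size = 2, 768  # illustrative sizes
image_clip_latents = torch.randn(batch_size, clip_hidden_size)

# (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
image_clip_latents = image_clip_latents.unsqueeze(1)
assert image_clip_latents.shape == (batch_size, 1, clip_hidden_size)
```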
@@ -1030,6 +1020,7 @@ def __call__(
        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
+        max_timestep = timesteps[0]
        # print(f"Timesteps: {timesteps}")
        # print(f"Timesteps shape: {timesteps.shape}")
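`scheduler.timesteps` is ordered from the noisiest timestep down to the final denoising step, so `timesteps[0]` is the largest value; that is what `max_timestep` feeds to the unconditional U-Net calls above. A small sketch with a stand-in scheduler (the concrete scheduler class here is an assumption, used only to show the ordering):

```python
from diffusers import DDPMScheduler

# Stand-in scheduler; diffusers schedulers expose timesteps in descending order.
scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

timesteps = scheduler.timesteps   # descending: noisiest step first
max_timestep = timesteps[0]       # largest timestep
assert max_timestep == timesteps.max()
```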
@@ -1062,7 +1053,7 @@ def __call__(
                    prompt_embeds,
                    image_vae_latents,
                    image_clip_latents,
-                    timesteps,
+                    max_timestep,
                    guidance_scale,
                    generator,
                    device,