@@ -73,7 +73,14 @@ def __init__(
        requires_safety_checker: bool = True,
    ):
        super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker,
+            feature_extractor,
+            requires_safety_checker,
        )
        self.register_modules(
            vae=vae,
@@ -102,22 +109,22 @@ def __call__(
        return_dict: bool = True,
        rp_args: Dict[str, str] = None,
    ):
-        active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt  # noqa: E721
+        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        if negative_prompt is None:
-            negative_prompt = "" if type(prompt) == str else [""] * len(prompt)  # noqa: E721
+            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.power = int(rp_args["power"]) if "power" in rp_args else 1

-        prompts = prompt if type(prompt) == list else [prompt]  # noqa: E721
-        n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt]  # noqa: E721
+        prompts = prompt if isinstance(prompt, list) else [prompt]
+        n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

-        cn = len(all_prompts_cn) == len(all_n_prompts_cn)
+        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
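The `isinstance` rewrites in this hunk also drop the `# noqa: E721` markers, since ruff's E721 only fires on exact-type comparisons such as `type(prompt) == list`. A minimal sketch (not part of this PR) of why `isinstance` is generally preferred; `PromptList` is a hypothetical subclass used only for illustration:

```python
class PromptList(list):
    """Hypothetical list subclass, only for illustration."""


prompt = ["a cat", "a dog"]
subclassed = PromptList(prompt)

print(type(prompt) == list)          # True, but flagged by ruff as E721
print(type(subclassed) == list)      # False: an exact-type check ignores subclasses
print(isinstance(subclassed, list))  # True: isinstance covers subclasses as well
```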
@@ -129,15 +136,15 @@ def getcompelembs(prps):
                return torch.cat(embl)

            conds = getcompelembs(all_prompts_cn)
-            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
+            unconds = getcompelembs(all_n_prompts_cn)
            embs = getcompelembs(prompts)
            n_embs = getcompelembs(n_prompts)
            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(prompts, device, 1, True)[0]
            unconds = (
                self.encode_prompt(n_prompts, device, 1, True)[0]
-                if cn
+                if equal
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )
            embs = n_embs = None
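For readers unfamiliar with the Compel branch touched above: Compel wraps the pipeline's tokenizer and text encoder and turns (optionally weighted) prompt strings into conditioning tensors. A rough usage sketch under the assumption of a Stable Diffusion 1.x checkpoint; the model id below is only an example:

```python
import torch
from compel import Compel
from diffusers import StableDiffusionPipeline

# Example checkpoint id; any SD 1.x pipeline with a CLIP text encoder works the same way.
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# Same construction as in the pipeline code above.
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# getcompelembs() concatenates per-prompt embeddings; this mirrors that pattern.
prompts = ["a red sports car++", "a snowy mountain"]
conds = torch.cat([compel.build_conditioning_tensor(p) for p in prompts])
print(conds.shape)  # typically (2, 77, 768) for SD 1.x text encoders
```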
@@ -206,8 +213,11 @@ def forward(
            else:
                px, nx = hidden_states.chunk(2)

-            if cn:
-                hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
+            if equal:
+                hidden_states = torch.cat(
+                    [px for i in range(regions)] + [nx for i in range(regions)],
+                    0,
+                )
                encoder_hidden_states = torch.cat([conds] + [unconds])
            else:
                hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
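The multi-line `torch.cat` above is only a reformat of the same region-duplication logic: when the positive and negative prompt counts match (`equal`), both halves of the batch are repeated once per region. A toy illustration with made-up shapes:

```python
import torch

regions = 3
batch = 2
px = torch.randn(batch, 64, 320)  # hypothetical positive half of hidden_states
nx = torch.randn(batch, 64, 320)  # hypothetical negative half

# "equal" branch: every region gets a copy of both halves.
stacked = torch.cat([px for _ in range(regions)] + [nx for _ in range(regions)], 0)
print(stacked.shape)  # torch.Size([12, 64, 320]) == (2 * regions * batch, 64, 320)
```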
@@ -289,9 +299,9 @@ def forward(
            if any(x in mode for x in ["COL", "ROW"]):
                reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
-                outs = [px, nx] if cn else [px]
+                px = reshaped[0:center] if equal else reshaped[0:-batch]
+                nx = reshaped[center:] if equal else reshaped[-batch:]
+                outs = [px, nx] if equal else [px]
                for out in outs:
                    c = 0
                    for i, ocell in enumerate(ocells):
@@ -321,15 +331,16 @@ def forward(
                                :,
                            ]
                        c += 1
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                hidden_states = hidden_states.reshape(xshape)

            #### Regional Prompting Prompt mode
            elif "PRO" in mode:
-                center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
+                px, nx = (
+                    torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
+                    hidden_states[-batch:],
+                )

                if (h, w) in self.attnmasks and self.maskready:

@@ -340,8 +351,8 @@ def mask(input):
                            out[b] = out[b] + out[r * batch + b]
                    return out

-                px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
            return hidden_states

@@ -378,7 +389,15 @@ def hook_forwards(root_module: torch.nn.Module):
            save_mask = False

        if mode == "PROMPT" and save_mask:
-            saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
+            saveattnmaps(
+                self,
+                output,
+                height,
+                width,
+                thresholds,
+                num_inference_steps // 2,
+                regions,
+            )

        return output

@@ -437,7 +456,11 @@ def startend(cells, array):
def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
-            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+            prompts,
+            max_length=self.tokenizer.model_max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs
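For context, `make_emblist` is the usual tokenize-then-encode pattern for a CLIP text encoder. A standalone sketch of the same calls outside the pipeline; the model id is only an example, and the `.to(...)` device/dtype handling is omitted:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# SD 1.x pipelines typically ship this text encoder; used here purely as an example.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a red sports car", "a snowy mountain"]
with torch.no_grad():
    tokens = tokenizer(
        prompts,
        max_length=tokenizer.model_max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    embs = text_encoder(tokens, output_hidden_states=True).last_hidden_state

print(embs.shape)  # (2, padded_sequence_length, 768) for CLIP ViT-L/14
```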
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):


def scaled_dot_product_attention(
-    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
+    self,
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    getattn=False,
) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
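The `# Efficient implementation equivalent to the following:` comment refers to the reference scaled-dot-product attention from the PyTorch documentation. A trimmed sketch of that baseline, without this file's extra `getattn` and `is_causal` bookkeeping:

```python
import math
import torch


def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0, scale=None):
    # softmax(Q @ K^T / sqrt(d_head)) @ V, with an optional additive mask.
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


q = k = v = torch.randn(2, 8, 64, 40)  # (batch, heads, tokens, head_dim)
print(sdpa_reference(q, k, v).shape)   # torch.Size([2, 8, 64, 40])
```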