Commit ff43dba

[Fix] Fix Regional Prompting Pipeline (#6188)
* Update regional_prompting_stable_diffusion.py * reformat (series of formatting-only commits) * Update regional_prompting_stable_diffusion.py

Co-authored-by: Sayak Paul <[email protected]>
1 parent 5433962 commit ff43dba

File tree

1 file changed: +53 −22

examples/community/regional_prompting_stable_diffusion.py

Lines changed: 53 additions & 22 deletions
@@ -73,7 +73,14 @@ def __init__(
         requires_safety_checker: bool = True,
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker,
+            feature_extractor,
+            requires_safety_checker,
         )
         self.register_modules(
             vae=vae,
@@ -102,22 +109,22 @@ def __call__(
         return_dict: bool = True,
         rp_args: Dict[str, str] = None,
     ):
-        active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt  # noqa: E721
+        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
         if negative_prompt is None:
-            negative_prompt = "" if type(prompt) == str else [""] * len(prompt)  # noqa: E721
+            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

         device = self._execution_device
         regions = 0

         self.power = int(rp_args["power"]) if "power" in rp_args else 1

-        prompts = prompt if type(prompt) == list else [prompt]  # noqa: E721
-        n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt]  # noqa: E721
+        prompts = prompt if isinstance(prompt, list) else [prompt]
+        n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
         self.batch = batch = num_images_per_prompt * len(prompts)
         all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
         all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

-        cn = len(all_prompts_cn) == len(all_n_prompts_cn)
+        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

         if Compel:
             compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
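Editorial note on the hunk above (not part of the diff): the `# noqa: E721` suppressions become unnecessary because `isinstance()` replaces the exact-type comparisons. `type(x) == list` is False for subclasses of `list`, while `isinstance` respects inheritance, which is why flake8 flags the former. A minimal standalone sketch of the difference; `PromptList` is a hypothetical subclass used only for illustration:

# Why isinstance() replaces type(x) == list (flake8 E721).
# PromptList is a hypothetical list subclass, used only for this demo.
class PromptList(list):
    pass

prompt = PromptList(["a cat BREAK a dog"])

print(type(prompt) == list)      # False: exact-type check rejects subclasses
print(isinstance(prompt, list))  # True: subclasses count as lists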
@@ -129,15 +136,15 @@ def getcompelembs(prps):
                 return torch.cat(embl)

             conds = getcompelembs(all_prompts_cn)
-            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
+            unconds = getcompelembs(all_n_prompts_cn)
             embs = getcompelembs(prompts)
             n_embs = getcompelembs(n_prompts)
             prompt = negative_prompt = None
         else:
             conds = self.encode_prompt(prompts, device, 1, True)[0]
             unconds = (
                 self.encode_prompt(n_prompts, device, 1, True)[0]
-                if cn
+                if equal
                 else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
             )
             embs = n_embs = None
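Editorial note on the hunk above (not part of the diff): the renamed `equal` flag appears to record whether the regional positive prompts and the regional negative prompts split into the same number of regions; when they do not, the unconditional embeddings are built from the plain negative prompts instead. A rough standalone illustration of the idea, assuming a BREAK-style delimiter; `split_regions` is a hypothetical stand-in, not the pipeline's actual `promptsmaker`:

# Hypothetical sketch of the `equal` flag; split_regions is NOT promptsmaker,
# just a stand-in that splits a prompt into per-region prompts on "BREAK".
def split_regions(text):
    return [part.strip() for part in text.split("BREAK")]

all_prompts_cn = split_regions("a red ball BREAK a blue cube")  # 2 regional prompts
all_n_prompts_cn = split_regions("blurry")                      # 1 negative prompt

equal = len(all_prompts_cn) == len(all_n_prompts_cn)
print(equal)  # False -> one negative prompt has to be shared across regions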
@@ -206,8 +213,11 @@ def forward(
             else:
                 px, nx = hidden_states.chunk(2)

-            if cn:
-                hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
+            if equal:
+                hidden_states = torch.cat(
+                    [px for i in range(regions)] + [nx for i in range(regions)],
+                    0,
+                )
                 encoder_hidden_states = torch.cat([conds] + [unconds])
             else:
                 hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
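Editorial note on the hunk above (not part of the diff): in this branch of the hooked attention forward, the positive and negative halves of the hidden states are each repeated once per region, so a single batched attention call can pair every copy with its own regional text embedding. A minimal shape sketch of that duplication, with made-up sizes:

import torch

# Made-up sizes: 2 images per half, 3 regions, 64 tokens, 320 channels.
batch, regions, seq, dim = 2, 3, 64, 320
px = torch.randn(batch, seq, dim)  # conditioned half of the hidden states
nx = torch.randn(batch, seq, dim)  # unconditioned half

# Mirrors the `if equal:` branch: duplicate both halves once per region.
hidden_states = torch.cat([px for _ in range(regions)] + [nx for _ in range(regions)], 0)
print(hidden_states.shape)  # torch.Size([12, 64, 320]) == 2 * regions * batch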
@@ -289,9 +299,9 @@ def forward(
             if any(x in mode for x in ["COL", "ROW"]):
                 reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                 center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
-                outs = [px, nx] if cn else [px]
+                px = reshaped[0:center] if equal else reshaped[0:-batch]
+                nx = reshaped[center:] if equal else reshaped[-batch:]
+                outs = [px, nx] if equal else [px]
                 for out in outs:
                     c = 0
                     for i, ocell in enumerate(ocells):
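Editorial note on the hunk above (not part of the diff): for "COL"/"ROW" modes the flattened token sequence is reshaped back into a (batch, height, width, channels) grid so each region's output can be copied into its own column or row span; the pipeline derives those spans (`ocells`) from the requested region ratios. A small standalone sketch of the reshape-and-slice idea, with made-up sizes and a fixed 50/50 column split:

import torch

# Made-up sizes: an 8x8 latent token grid with 320 channels, one image.
h, w, dim = 8, 8, 320
hidden_states = torch.randn(1, h * w, dim)      # flattened token sequence
reshaped = hidden_states.reshape(1, h, w, dim)  # back to a spatial grid

# Two-region column split; the real spans come from the region ratios.
left = reshaped[:, :, : w // 2, :]   # region 1 owns the left columns
right = reshaped[:, :, w // 2 :, :]  # region 2 owns the right columns
merged = torch.cat([left, right], dim=2).reshape(1, h * w, dim)
print(merged.shape)  # torch.Size([1, 64, 320])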
@@ -321,15 +331,16 @@ def forward(
                                 :,
                             ]
                             c += 1
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                 hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                 hidden_states = hidden_states.reshape(xshape)

             #### Regional Prompting Prompt mode
             elif "PRO" in mode:
-                center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
+                px, nx = (
+                    torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
+                    hidden_states[-batch:],
+                )

                 if (h, w) in self.attnmasks and self.maskready:

@@ -340,8 +351,8 @@ def mask(input):
                         out[b] = out[b] + out[r * batch + b]
                     return out

-                px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                 hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                 return hidden_states

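Editorial note on the hunk above (not part of the diff): in prompt ("PRO") mode each region produces its own copy of the batch, and `mask` folds those copies back onto the first `batch` entries via `out[b] = out[b] + out[r * batch + b]`. A tiny standalone sketch of that fold, with made-up sizes and an assumed loop over regions 1..regions-1:

import torch

# Made-up sizes: 2 images, 3 regions, each region holding its own batch copy.
batch, regions, seq, dim = 2, 3, 4, 8
out = torch.randn(regions * batch, seq, dim)

# Fold every region's copy of image b back onto entry b.
for b in range(batch):
    for r in range(1, regions):
        out[b] = out[b] + out[r * batch + b]

print(out[0:batch].shape)  # torch.Size([2, 4, 8]) -> one aggregated tensor per image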
@@ -378,7 +389,15 @@ def hook_forwards(root_module: torch.nn.Module):
             save_mask = False

         if mode == "PROMPT" and save_mask:
-            saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
+            saveattnmaps(
+                self,
+                output,
+                height,
+                width,
+                thresholds,
+                num_inference_steps // 2,
+                regions,
+            )

         return output

@@ -437,7 +456,11 @@ def startend(cells, array):
 def make_emblist(self, prompts):
     with torch.no_grad():
         tokens = self.tokenizer(
-            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+            prompts,
+            max_length=self.tokenizer.model_max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
         ).input_ids.to(self.device)
         embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
         return embs
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):


 def scaled_dot_product_attention(
-    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
+    self,
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    getattn=False,
 ) -> torch.Tensor:
     # Efficient implementation equivalent to the following:
     L, S = query.size(-2), key.size(-2)
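Editorial note on the hunk above (not part of the diff): the reformatted `scaled_dot_product_attention` helper re-implements attention in Python so the pipeline can optionally return the attention weights (`getattn`) for building region masks. A minimal reference sketch of the underlying computation, softmax(Q K^T / sqrt(d) + mask) V, following the standard formulation rather than the pipeline's exact code:

import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value, attn_mask=None, scale=None, return_weights=False):
    # Standard scaled dot-product attention: softmax(Q K^T / sqrt(d) + mask) V.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = F.softmax(attn_weight, dim=-1)
    out = attn_weight @ value
    return (out, attn_weight) if return_weights else out


# Made-up shapes: batch 2, 8 heads, 64 tokens, head dim 40.
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 64, 40)
v = torch.randn(2, 8, 64, 40)
out, weights = reference_sdpa(q, k, v, return_weights=True)
print(out.shape, weights.shape)  # (2, 8, 64, 40) and (2, 8, 64, 64)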
