Skip to content

[Fix] Fix Regional Prompting Pipeline #6188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Dec 20, 2023
75 changes: 53 additions & 22 deletions examples/community/regional_prompting_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ def __init__(
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
Expand Down Expand Up @@ -102,22 +109,22 @@ def __call__(
return_dict: bool = True,
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
if negative_prompt is None:
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

device = self._execution_device
regions = 0

self.power = int(rp_args["power"]) if "power" in rp_args else 1

prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

cn = len(all_prompts_cn) == len(all_n_prompts_cn)
equal = len(all_prompts_cn) == len(all_n_prompts_cn)

if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
Expand All @@ -129,15 +136,15 @@ def getcompelembs(prps):
return torch.cat(embl)

conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
Expand Down Expand Up @@ -206,8 +213,11 @@ def forward(
else:
px, nx = hidden_states.chunk(2)

if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
if equal:
hidden_states = torch.cat(
[px for i in range(regions)] + [nx for i in range(regions)],
0,
)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
Expand Down Expand Up @@ -289,9 +299,9 @@ def forward(
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
px = reshaped[0:center] if equal else reshaped[0:-batch]
nx = reshaped[center:] if equal else reshaped[-batch:]
outs = [px, nx] if equal else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
Expand Down Expand Up @@ -321,15 +331,16 @@ def forward(
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)

#### Regional Prompting Prompt mode
elif "PRO" in mode:
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
px, nx = (
torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
hidden_states[-batch:],
)

if (h, w) in self.attnmasks and self.maskready:

Expand All @@ -340,8 +351,8 @@ def mask(input):
out[b] = out[b] + out[r * batch + b]
return out

px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states

Expand Down Expand Up @@ -378,7 +389,15 @@ def hook_forwards(root_module: torch.nn.Module):
save_mask = False

if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
saveattnmaps(
self,
output,
height,
width,
thresholds,
num_inference_steps // 2,
regions,
)

return output

Expand Down Expand Up @@ -437,7 +456,11 @@ def startend(cells, array):
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
Expand Down Expand Up @@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):


def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
self,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
getattn=False,
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
Expand Down