Skip to content

Commit 4e3ddd5

Browse files
authored
fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings (#11012)
small fix on generating time_ids & embeddings
1 parent 9add071 commit 4e3ddd5

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

examples/community/mixture_tiling_sdxl.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 The HuggingFace Team. All rights reserved.
1+
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -1070,32 +1070,32 @@ def __call__(
10701070
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
10711071
else:
10721072
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1073-
add_time_ids = self._get_add_time_ids(
1074-
original_size,
1075-
crops_coords_top_left[row][col],
1076-
target_size,
1073+
add_time_ids = self._get_add_time_ids(
1074+
original_size,
1075+
crops_coords_top_left[row][col],
1076+
target_size,
1077+
dtype=prompt_embeds.dtype,
1078+
text_encoder_projection_dim=text_encoder_projection_dim,
1079+
)
1080+
if negative_original_size is not None and negative_target_size is not None:
1081+
negative_add_time_ids = self._get_add_time_ids(
1082+
negative_original_size,
1083+
negative_crops_coords_top_left[row][col],
1084+
negative_target_size,
10771085
dtype=prompt_embeds.dtype,
10781086
text_encoder_projection_dim=text_encoder_projection_dim,
10791087
)
1080-
if negative_original_size is not None and negative_target_size is not None:
1081-
negative_add_time_ids = self._get_add_time_ids(
1082-
negative_original_size,
1083-
negative_crops_coords_top_left[row][col],
1084-
negative_target_size,
1085-
dtype=prompt_embeds.dtype,
1086-
text_encoder_projection_dim=text_encoder_projection_dim,
1087-
)
1088-
else:
1089-
negative_add_time_ids = add_time_ids
1088+
else:
1089+
negative_add_time_ids = add_time_ids
10901090

1091-
if self.do_classifier_free_guidance:
1092-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1093-
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1094-
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1091+
if self.do_classifier_free_guidance:
1092+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1093+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1094+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
10951095

1096-
prompt_embeds = prompt_embeds.to(device)
1097-
add_text_embeds = add_text_embeds.to(device)
1098-
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1096+
prompt_embeds = prompt_embeds.to(device)
1097+
add_text_embeds = add_text_embeds.to(device)
1098+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
10991099
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
11001100
embeddings_and_added_time.append(addition_embed_type_row)
11011101

0 commit comments

Comments
 (0)