@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1070,32 +1070,32 @@ def __call__(
                     text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
                 else:
                     text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-                    add_time_ids = self._get_add_time_ids(
-                        original_size,
-                        crops_coords_top_left[row][col],
-                        target_size,
+                add_time_ids = self._get_add_time_ids(
+                    original_size,
+                    crops_coords_top_left[row][col],
+                    target_size,
+                    dtype=prompt_embeds.dtype,
+                    text_encoder_projection_dim=text_encoder_projection_dim,
+                )
+                if negative_original_size is not None and negative_target_size is not None:
+                    negative_add_time_ids = self._get_add_time_ids(
+                        negative_original_size,
+                        negative_crops_coords_top_left[row][col],
+                        negative_target_size,
                         dtype=prompt_embeds.dtype,
                         text_encoder_projection_dim=text_encoder_projection_dim,
                     )
-                    if negative_original_size is not None and negative_target_size is not None:
-                        negative_add_time_ids = self._get_add_time_ids(
-                            negative_original_size,
-                            negative_crops_coords_top_left[row][col],
-                            negative_target_size,
-                            dtype=prompt_embeds.dtype,
-                            text_encoder_projection_dim=text_encoder_projection_dim,
-                        )
-                    else:
-                        negative_add_time_ids = add_time_ids
+                else:
+                    negative_add_time_ids = add_time_ids
 
-                    if self.do_classifier_free_guidance:
-                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-                        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-                        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+                if self.do_classifier_free_guidance:
+                    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+                    add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
-                    prompt_embeds = prompt_embeds.to(device)
-                    add_text_embeds = add_text_embeds.to(device)
-                    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+                prompt_embeds = prompt_embeds.to(device)
+                add_text_embeds = add_text_embeds.to(device)
+                add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
                 addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
             embeddings_and_added_time.append(addition_embed_type_row)
 
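For context, a minimal standalone sketch of the control-flow effect of the dedent above: after the change, the per-tile `add_time_ids` block runs on both branches of the `text_encoder_2 is None` check, whereas before it was nested under `else:` and skipped whenever the projection dim came from `pooled_prompt_embeds`. This is not the pipeline code itself; `get_add_time_ids`, the 1280 projection dim, and the tile grid below are simplified stand-ins.

# Simplified illustration of the corrected control flow (assumed helper names, not the real pipeline API).
def get_add_time_ids(original_size, crop_top_left, target_size):
    # SDXL-style micro-conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    return list(original_size) + list(crop_top_left) + list(target_size)

def build_tile_conditioning(grid_rows, grid_cols, crops_coords_top_left,
                            original_size=(1024, 1024), target_size=(1024, 1024),
                            text_encoder_2=None):
    embeddings_and_added_time = []
    for row in range(grid_rows):
        row_entries = []
        for col in range(grid_cols):
            if text_encoder_2 is None:
                projection_dim = 1280  # stand-in for pooled_prompt_embeds.shape[-1]
            else:
                projection_dim = text_encoder_2["projection_dim"]
            # With the fixed indentation this runs for *both* branches above; with the
            # old indentation it only ran on the `else:` path.
            add_time_ids = get_add_time_ids(
                original_size, crops_coords_top_left[row][col], target_size
            )
            row_entries.append((projection_dim, add_time_ids))
        embeddings_and_added_time.append(row_entries)
    return embeddings_and_added_time

# Example: a 1x2 tile grid with per-tile crop offsets.
print(build_tile_conditioning(1, 2, [[(0, 0), (0, 896)]]))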