Skip to content

Commit 3d574b3

Browse files
authored
Fix a bug of flip in SDXL training script (#6547)
* Update train_text_to_image_sdxl.py * Update train_text_to_image_lora_sdxl.py
1 parent 0990377 commit 3d574b3

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -836,17 +836,16 @@ def preprocess_train(examples):
836836
for image in images:
837837
original_sizes.append((image.height, image.width))
838838
image = train_resize(image)
839+
if args.random_flip and random.random() < 0.5:
840+
# flip
841+
image = train_flip(image)
839842
if args.center_crop:
840843
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
841844
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
842845
image = train_crop(image)
843846
else:
844847
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
845848
image = crop(image, y1, x1, h, w)
846-
if args.random_flip and random.random() < 0.5:
847-
# flip
848-
x1 = image.width - x1
849-
image = train_flip(image)
850849
crop_top_left = (y1, x1)
851850
crop_top_lefts.append(crop_top_left)
852851
image = train_transforms(image)

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,17 +839,16 @@ def preprocess_train(examples):
839839
for image in images:
840840
original_sizes.append((image.height, image.width))
841841
image = train_resize(image)
842+
if args.random_flip and random.random() < 0.5:
843+
# flip
844+
image = train_flip(image)
842845
if args.center_crop:
843846
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
844847
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
845848
image = train_crop(image)
846849
else:
847850
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
848851
image = crop(image, y1, x1, h, w)
849-
if args.random_flip and random.random() < 0.5:
850-
# flip
851-
x1 = image.width - x1
852-
image = train_flip(image)
853852
crop_top_left = (y1, x1)
854853
crop_top_lefts.append(crop_top_left)
855854
image = train_transforms(image)

0 commit comments

Comments
 (0)