Skip to content

Commit 723dbdd

Browse files
asomozasayakpaulhlky
authored
[Training] Better image interpolation in training scripts (#11206)
* initial * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: hlky <[email protected]> * update --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: hlky <[email protected]>
1 parent fbf61f4 commit 723dbdd

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
669669
),
670670
)
671671

672+
parser.add_argument(
673+
"--image_interpolation_mode",
674+
type=str,
675+
default="lanczos",
676+
choices=[
677+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
678+
],
679+
help="The image interpolation method to use for resizing images.",
680+
)
681+
672682
if input_args is not None:
673683
args = parser.parse_args(input_args)
674684
else:
@@ -790,7 +800,12 @@ def __init__(
790800
self.original_sizes = []
791801
self.crop_top_lefts = []
792802
self.pixel_values = []
793-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
803+
804+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
805+
if interpolation is None:
806+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
807+
train_resize = transforms.Resize(size, interpolation=interpolation)
808+
794809
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
795810
train_flip = transforms.RandomHorizontalFlip(p=1.0)
796811
train_transforms = transforms.Compose(

0 commit comments

Comments
 (0)