Skip to content

Commit 58431f1

Browse files
authored
Set LANCZOS as the default interpolation for image resizing in ControlNet training (#11449)
Set LANCZOS as the default interpolation for image resizing
1 parent 4a9ab65 commit 58431f1

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

examples/controlnet/train_controlnet_flux.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,15 @@ def parse_args(input_args=None):
639639
action="store_true",
640640
help="Enable model cpu offload and save memory.",
641641
)
642+
parser.add_argument(
643+
"--image_interpolation_mode",
644+
type=str,
645+
default="lanczos",
646+
choices=[
647+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
648+
],
649+
help="The image interpolation method to use for resizing images.",
650+
)
642651

643652
if input_args is not None:
644653
args = parser.parse_args(input_args)
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
736745

737746

738747
def prepare_train_dataset(dataset, accelerator):
748+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
749+
if interpolation is None:
750+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
751+
739752
image_transforms = transforms.Compose(
740753
[
741-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
754+
transforms.Resize(args.resolution, interpolation=interpolation),
742755
transforms.CenterCrop(args.resolution),
743756
transforms.ToTensor(),
744757
transforms.Normalize([0.5], [0.5]),
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
747760

748761
conditioning_image_transforms = transforms.Compose(
749762
[
750-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
763+
transforms.Resize(args.resolution, interpolation=interpolation),
751764
transforms.CenterCrop(args.resolution),
752765
transforms.ToTensor(),
753766
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)