@@ -639,6 +639,15 @@ def parse_args(input_args=None):
639
639
action = "store_true" ,
640
640
help = "Enable model cpu offload and save memory." ,
641
641
)
642
+ parser .add_argument (
643
+ "--image_interpolation_mode" ,
644
+ type = str ,
645
+ default = "lanczos" ,
646
+ choices = [
647
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
648
+ ],
649
+ help = "The image interpolation method to use for resizing images." ,
650
+ )
642
651
643
652
if input_args is not None :
644
653
args = parser .parse_args (input_args )
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
736
745
737
746
738
747
def prepare_train_dataset (dataset , accelerator ):
748
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
749
+ if interpolation is None :
750
+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
751
+
739
752
image_transforms = transforms .Compose (
740
753
[
741
- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
754
+ transforms .Resize (args .resolution , interpolation = interpolation ),
742
755
transforms .CenterCrop (args .resolution ),
743
756
transforms .ToTensor (),
744
757
transforms .Normalize ([0.5 ], [0.5 ]),
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
747
760
748
761
conditioning_image_transforms = transforms .Compose (
749
762
[
750
- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
763
+ transforms .Resize (args .resolution , interpolation = interpolation ),
751
764
transforms .CenterCrop (args .resolution ),
752
765
transforms .ToTensor (),
753
766
transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments