@@ -470,6 +470,15 @@ def parse_args(input_args=None):
470
470
"--enable_xformers_memory_efficient_attention" , action = "store_true" , help = "Whether or not to use xformers."
471
471
)
472
472
parser .add_argument ("--noise_offset" , type = float , default = 0 , help = "The scale of noise offset." )
473
+ parser .add_argument (
474
+ "--image_interpolation_mode" ,
475
+ type = str ,
476
+ default = "lanczos" ,
477
+ choices = [
478
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
479
+ ],
480
+ help = "The image interpolation method to use for resizing images." ,
481
+ )
473
482
474
483
if input_args is not None :
475
484
args = parser .parse_args (input_args )
@@ -861,7 +870,10 @@ def load_model_hook(models, input_dir):
861
870
)
862
871
863
872
# Preprocessing the datasets.
864
- train_resize = transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR )
873
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
874
+ if interpolation is None :
875
+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
876
+ train_resize = transforms .Resize (args .resolution , interpolation = interpolation )
865
877
train_crop = transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution )
866
878
train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
867
879
train_transforms = transforms .Compose ([transforms .ToTensor (), transforms .Normalize ([0.5 ], [0.5 ])])
0 commit comments