@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
     extract_ema: bool = False,
     scheduler_type: str = "ddim",
     num_in_channels: int = None,
+    model_channels: int = None,
+    num_head_channels: int = None,
     device: str = None,
     from_safetensors: bool = False,
 ) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
     global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
-    :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
-        corresponding to the original architecture.
-        If `None`, will be automatically instantiated based on default values.
-    :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
-        prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
-        AudioLDM checkpoints.
-    :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
-        inferred.
-    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
-        "euler-ancestral", "dpm", "ddim"]`.
-    :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
-        the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
-        yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
-    :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
-        `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
-        instead of PyTorch.
-    :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    Args:
+        checkpoint_path (`str`): Path to `.ckpt` file.
+        original_config_file (`str`):
+            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+            set to the audioldm-s-full-v2 config.
+        image_size (`int`, *optional*, defaults to 512):
+            The image size that the model was trained on.
+        prediction_type (`str`, *optional*):
+            The prediction type that the model was trained on. If `None`, will be automatically
+            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
+        num_in_channels (`int`, *optional*, defaults to None):
+            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
+        model_channels (`int`, *optional*, defaults to None):
+            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
+            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
+        num_head_channels (`int`, *optional*, defaults to None):
+            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
+            to 32 for the small and medium checkpoints, and 64 for the large.
+        scheduler_type (`str`, *optional*, defaults to 'pndm'):
+            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+            "ddim"]`.
+        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+            inference. Non-EMA weights are usually better to continue fine-tuning.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+        return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
 
     if not is_omegaconf_available():
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
 
+    if model_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
+
+    if num_head_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
@@ -960,6 +981,20 @@ def load_pipeline_from_original_audioldm_ckpt(
         type=int,
         help="The number of input channels. If `None` number of input channels will be automatically inferred.",
     )
+    parser.add_argument(
+        "--model_channels",
+        default=None,
+        type=int,
+        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
+    )
+    parser.add_argument(
+        "--num_head_channels",
+        default=None,
+        type=int,
+        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 32 for the small and medium checkpoints, and 64 for the large.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="ddim",
@@ -1009,6 +1044,8 @@ def load_pipeline_from_original_audioldm_ckpt(
         extract_ema=args.extract_ema,
         scheduler_type=args.scheduler_type,
         num_in_channels=args.num_in_channels,
+        model_channels=args.model_channels,
+        num_head_channels=args.num_head_channels,
         from_safetensors=args.from_safetensors,
         device=args.device,
     )