Skip to content

Commit abd86d1

Browse files
[AudioLDM] Generalise conversion script (#3328)
Co-authored-by: Patrick von Platen <[email protected]>
1 parent e9aa092 commit abd86d1

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

scripts/convert_original_audioldm_to_diffusers.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
774774
extract_ema: bool = False,
775775
scheduler_type: str = "ddim",
776776
num_in_channels: int = None,
777+
model_channels: int = None,
778+
num_head_channels: int = None,
777779
device: str = None,
778780
from_safetensors: bool = False,
779781
) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
784786
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
785787
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
786788
787-
:param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
788-
corresponding to the original architecture.
789-
If `None`, will be automatically instantiated based on default values.
790-
:param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
791-
prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
792-
AudioLDM checkpoints.
793-
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
794-
inferred.
795-
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
796-
"euler-ancestral", "dpm", "ddim"]`.
797-
:param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
798-
the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
799-
yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
800-
:param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
801-
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
802-
instead of PyTorch.
803-
:return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
789+
Args:
790+
checkpoint_path (`str`): Path to `.ckpt` file.
791+
original_config_file (`str`):
792+
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
793+
set to the audioldm-s-full-v2 config.
794+
image_size (`int`, *optional*, defaults to 512):
795+
The image size that the model was trained on.
796+
prediction_type (`str`, *optional*):
797+
The prediction type that the model was trained on. If `None`, will be automatically
798+
inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
799+
num_in_channels (`int`, *optional*, defaults to None):
800+
The number of UNet input channels. If `None`, it will be automatically inferred from the config.
801+
model_channels (`int`, *optional*, defaults to None):
802+
The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
803+
to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
804+
num_head_channels (`int`, *optional*, defaults to None):
805+
The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
806+
to 32 for the small and medium checkpoints, and 64 for the large.
807+
scheduler_type (`str`, *optional*, defaults to 'ddim'):
808+
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
809+
"ddim"]`.
810+
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
811+
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
812+
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
813+
inference. Non-EMA weights are usually better to continue fine-tuning.
814+
device (`str`, *optional*, defaults to `None`):
815+
The device to use. Pass `None` to determine automatically.
816+
from_safetensors (`bool`, *optional*, defaults to `False`):
817+
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
818+
Returns: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
804819
"""
805820

806821
if not is_omegaconf_available():
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
837852
if num_in_channels is not None:
838853
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
839854

855+
if model_channels is not None:
856+
original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
857+
858+
if num_head_channels is not None:
859+
original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
860+
840861
if (
841862
"parameterization" in original_config["model"]["params"]
842863
and original_config["model"]["params"]["parameterization"] == "v"
@@ -960,6 +981,20 @@ def load_pipeline_from_original_audioldm_ckpt(
960981
type=int,
961982
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
962983
)
984+
parser.add_argument(
985+
"--model_channels",
986+
default=None,
987+
type=int,
988+
help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
989+
" to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
990+
)
991+
parser.add_argument(
992+
"--num_head_channels",
993+
default=None,
994+
type=int,
995+
help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
996+
" to 32 for the small and medium checkpoints, and 64 for the large.",
997+
)
963998
parser.add_argument(
964999
"--scheduler_type",
9651000
default="ddim",
@@ -1009,6 +1044,8 @@ def load_pipeline_from_original_audioldm_ckpt(
10091044
extract_ema=args.extract_ema,
10101045
scheduler_type=args.scheduler_type,
10111046
num_in_channels=args.num_in_channels,
1047+
model_channels=args.model_channels,
1048+
num_head_channels=args.num_head_channels,
10121049
from_safetensors=args.from_safetensors,
10131050
device=args.device,
10141051
)

0 commit comments

Comments
 (0)