[AudioLDM] Generalise conversion script #3328

Merged
merged 2 commits into from
May 6, 2023
71 changes: 54 additions & 17 deletions scripts/convert_original_audioldm_to_diffusers.py
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
extract_ema: bool = False,
scheduler_type: str = "ddim",
num_in_channels: int = None,
model_channels: int = None,
num_head_channels: int = None,
device: str = None,
from_safetensors: bool = False,
) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
recommended that you override the default values and/or supply an `original_config_file` wherever possible.

:param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
corresponding to the original architecture.
If `None`, will be automatically instantiated based on default values.
:param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
AudioLDM checkpoints.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`.
:param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
instead of PyTorch.
:return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
Args:
checkpoint_path (`str`): Path to `.ckpt` file.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
set to the audioldm-s-full-v2 config.
image_size (`int`, *optional*, defaults to 512):
The image size that the model was trained on.
prediction_type (`str`, *optional*):
The prediction type that the model was trained on. If `None`, will be automatically
inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
num_in_channels (`int`, *optional*, defaults to None):
The number of UNet input channels. If `None`, it will be automatically inferred from the config.
model_channels (`int`, *optional*, defaults to None):
The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
num_head_channels (`int`, *optional*, defaults to None):
The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
to 32 for the small and medium checkpoints, and 64 for the large.
scheduler_type (`str`, *optional*, defaults to 'ddim'):
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
"ddim"]`.
extract_ema (`bool`, *optional*, defaults to `False`):
Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or
not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality samples for
inference. Non-EMA weights are usually better to continue fine-tuning.
device (`str`, *optional*, defaults to `None`):
The device to use. Pass `None` to determine automatically.
from_safetensors (`bool`, *optional*, defaults to `False`):
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""

if not is_omegaconf_available():
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

if model_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels

if num_head_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels

if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
@@ -960,6 +981,20 @@ def load_pipeline_from_original_audioldm_ckpt(
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--model_channels",
default=None,
type=int,
help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
" to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
)
parser.add_argument(
"--num_head_channels",
default=None,
type=int,
help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
" to 32 for the small and medium checkpoints, and 64 for the large.",
)
parser.add_argument(
"--scheduler_type",
default="ddim",
@@ -1009,6 +1044,8 @@ def load_pipeline_from_original_audioldm_ckpt(
extract_ema=args.extract_ema,
scheduler_type=args.scheduler_type,
num_in_channels=args.num_in_channels,
model_channels=args.model_channels,
num_head_channels=args.num_head_channels,
from_safetensors=args.from_safetensors,
device=args.device,
)
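
The override logic in this change follows one pattern: each argument left as `None` keeps the value from the config, and any other value replaces the matching key under `model.params.unet_config.params`. A minimal sketch of that pattern on a plain nested dict is shown below. The helper `apply_unet_overrides`, the `SIZE_PRESETS` table and the placeholder `default_config` are hypothetical names introduced for illustration and are not part of the script. The channel values in `SIZE_PRESETS` are the ones recommended in the docstring above.

```python
from copy import deepcopy

# Channel settings recommended in the docstring for each checkpoint size.
# The table itself is a hypothetical convenience, not part of the script.
SIZE_PRESETS = {
    "small": {"model_channels": 128, "num_head_channels": 32},
    "medium": {"model_channels": 192, "num_head_channels": 32},
    "large": {"model_channels": 256, "num_head_channels": 64},
}


def apply_unet_overrides(original_config, num_in_channels=None, model_channels=None, num_head_channels=None):
    """Return a copy of the config with every non-None override applied."""
    config = deepcopy(original_config)
    unet_params = config["model"]["params"]["unet_config"]["params"]
    overrides = {
        "in_channels": num_in_channels,
        "model_channels": model_channels,
        "num_head_channels": num_head_channels,
    }
    for key, value in overrides.items():
        # None means "keep whatever the config already says".
        if value is not None:
            unet_params[key] = value
    return config


# Placeholder config for illustration only; the real defaults live in the script.
default_config = {
    "model": {"params": {"unet_config": {"params": {
        "in_channels": 8,
        "model_channels": 128,
        "num_head_channels": 32,
    }}}}
}

large_config = apply_unet_overrides(default_config, **SIZE_PRESETS["large"])
print(large_config["model"]["params"]["unet_config"]["params"])
# {'in_channels': 8, 'model_channels': 256, 'num_head_channels': 64}
```

The sketch copies the config before modifying it, whereas the script assigns into `original_config` directly. The copy is only there to keep the example free of side effects.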
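
The two new command-line flags can be exercised on their own to see how they parse. The sketch below rebuilds only the `--model_channels` and `--num_head_channels` arguments from this diff in a standalone parser. `build_parser` is a hypothetical helper, and the help strings are shortened. The real script defines many more arguments that are omitted here.

```python
import argparse


def build_parser():
    # Hypothetical standalone parser holding only the two flags added in this PR.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_channels",
        default=None,
        type=int,
        help="The number of UNet model channels. If `None`, inferred from the config.",
    )
    parser.add_argument(
        "--num_head_channels",
        default=None,
        type=int,
        help="The number of UNet head channels. If `None`, inferred from the config.",
    )
    return parser


# Settings the docstring recommends for the large checkpoints.
args = build_parser().parse_args(["--model_channels", "256", "--num_head_channels", "64"])
print(args.model_channels, args.num_head_channels)

# With no flags, both stay None, so the config values are left untouched.
defaults = build_parser().parse_args([])
print(defaults.model_channels, defaults.num_head_channels)
```

Because both flags default to `None`, existing invocations of the conversion script should behave as before unless a flag is passed explicitly.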