Skip to content

[OmegaConf] replace it with yaml #6488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions scripts/conversion_ldm_uncond.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import argparse

import OmegaConf
import torch
import yaml

from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path)
config = yaml.safe_load(config_path)
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
keys = list(state_dict.keys())

Expand All @@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params
vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
unet_init_args = config["model"]["params"]["unet_config"]["params"]

vqvae = VQModel(**vqvae_init_args).eval()
vqvae.load_state_dict(first_stage_dict)
Expand All @@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
unet.load_state_dict(unet_state_dict)

noise_scheduler = DDIMScheduler(
timesteps=config.model.params.timesteps,
timesteps=config["model"]["params"]["timesteps"],
beta_schedule="scaled_linear",
beta_start=config.model.params.linear_start,
beta_end=config.model.params.linear_end,
beta_start=config["model"]["params"]["linear_start"],
beta_end=config["model"]["params"]["linear_end"],
clip_sample=False,
)

Expand Down
52 changes: 23 additions & 29 deletions scripts/convert_gligen_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re

import torch
import yaml
from transformers import (
CLIPProcessor,
CLIPTextModel,
Expand All @@ -28,8 +29,6 @@
textenc_conversion_map,
textenc_pattern,
)
from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING


def convert_open_clip_checkpoint(checkpoint):
Expand Down Expand Up @@ -370,64 +369,64 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa


def create_vae_config(original_config, image_size: int):
vae_params = original_config.autoencoder.params.ddconfig
_ = original_config.autoencoder.params.embed_dim
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
_ = original_config["autoencoder"]["params"]["embed_dim"]

block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}

return config


def create_unet_config(original_config, image_size: int, attention_type):
unet_params = original_config.model.params
vae_params = original_config.autoencoder.params.ddconfig
unet_params = original_config["model"]["params"]
vae_params = original_config["autoencoder"]["params"]["ddconfig"]

block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2

up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2

vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
if head_dim is None:
head_dim = [5, 10, 20, 20]

config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": unet_params.context_dim,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": unet_params["context_dim"],
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"attention_type": attention_type,
Expand All @@ -445,11 +444,6 @@ def convert_gligen_to_diffusers(
num_in_channels: int = None,
device: str = None,
):
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

from omegaconf import OmegaConf

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
Expand All @@ -461,14 +455,14 @@ def convert_gligen_to_diffusers(
else:
print("global_step key not found in model")

original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)

if num_in_channels is not None:
original_config["model"]["params"]["in_channels"] = num_in_channels

num_train_timesteps = original_config.diffusion.params.timesteps
beta_start = original_config.diffusion.params.linear_start
beta_end = original_config.diffusion.params.linear_end
num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
beta_start = original_config["diffusion"]["params"]["linear_start"]
beta_end = original_config["diffusion"]["params"]["linear_end"]

scheduler = DDIMScheduler(
beta_end=beta_end,
Expand Down
Loading