
Commit cb4b3f0

[OmegaConf] replace it with yaml (#6488)
* remove omegaconf from convert_from_ckpt.
* remove from single_file.
* change to string-based subscription.
* style
* okay
* fix: vae_param
* no . indexing.
* style
* style
* turn getattrs into explicit if/else
* style
* propagate changes to ldm_uncond.
* propagate to gligen
* propagate to if.
* fix: quotes.
* propagate to audioldm.
* propagate to audioldm2
* propagate to musicldm.
* propagate to vq_diffusion
* propagate to zero123.
* remove omegaconf from diffusers codebase.
1 parent 3d574b3 commit cb4b3f0

15 files changed (+358, -423 lines)
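The mechanical core of the change: configs that OmegaConf loaded as attribute-accessible DictConfig objects are now read with yaml.safe_load, which returns plain nested dicts, so every config.a.b.c access becomes string-based subscription. A minimal before/after sketch (the YAML fragment below is hypothetical, for illustration only, not taken from the diff):

import yaml

# Hypothetical LDM-style config fragment used only for illustration.
raw = """
model:
  params:
    timesteps: 1000
    linear_start: 0.00085
"""

# Before: config = OmegaConf.load(path); config.model.params.timesteps
# After: yaml.safe_load returns plain dicts/lists, so the same value is
# read with string-based subscription.
config = yaml.safe_load(raw)
print(config["model"]["params"]["timesteps"])  # 1000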

scripts/conversion_ldm_uncond.py

Lines changed: 7 additions & 7 deletions
@@ -1,13 +1,13 @@
 import argparse
 
-import OmegaConf
 import torch
+import yaml
 
 from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
 
 
 def convert_ldm_original(checkpoint_path, config_path, output_path):
-    config = OmegaConf.load(config_path)
+    config = yaml.safe_load(config_path)
     state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
     keys = list(state_dict.keys())
 
@@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
 
-    vqvae_init_args = config.model.params.first_stage_config.params
-    unet_init_args = config.model.params.unet_config.params
+    vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
+    unet_init_args = config["model"]["params"]["unet_config"]["params"]
 
     vqvae = VQModel(**vqvae_init_args).eval()
     vqvae.load_state_dict(first_stage_dict)
@@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
     unet.load_state_dict(unet_state_dict)
 
     noise_scheduler = DDIMScheduler(
-        timesteps=config.model.params.timesteps,
+        timesteps=config["model"]["params"]["timesteps"],
         beta_schedule="scaled_linear",
-        beta_start=config.model.params.linear_start,
-        beta_end=config.model.params.linear_end,
+        beta_start=config["model"]["params"]["linear_start"],
+        beta_end=config["model"]["params"]["linear_end"],
         clip_sample=False,
     )
 
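One caveat when reusing this script's pattern: yaml.safe_load parses its argument as YAML, accepting a string or an open stream but not a filesystem path. Handed a bare path string, it returns the path text as a YAML scalar rather than the file's contents. A minimal sketch of the stream-based form (load_yaml_config is a hypothetical helper, not part of the diff):

import yaml

def load_yaml_config(config_path):
    # Open the file and hand the stream to safe_load; passing config_path
    # directly would parse the path string itself as YAML.
    with open(config_path) as f:
        return yaml.safe_load(f)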

scripts/convert_gligen_to_diffusers.py

Lines changed: 23 additions & 29 deletions
@@ -2,6 +2,7 @@
 import re
 
 import torch
+import yaml
 from transformers import (
     CLIPProcessor,
     CLIPTextModel,
@@ -28,8 +29,6 @@
     textenc_conversion_map,
     textenc_pattern,
 )
-from diffusers.utils import is_omegaconf_available
-from diffusers.utils.import_utils import BACKENDS_MAPPING
 
 
 def convert_open_clip_checkpoint(checkpoint):
@@ -370,64 +369,64 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa
 
 
 def create_vae_config(original_config, image_size: int):
-    vae_params = original_config.autoencoder.params.ddconfig
-    _ = original_config.autoencoder.params.embed_dim
+    vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+    _ = original_config["autoencoder"]["params"]["embed_dim"]
 
-    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
 
     config = {
         "sample_size": image_size,
-        "in_channels": vae_params.in_channels,
-        "out_channels": vae_params.out_ch,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "latent_channels": vae_params.z_channels,
-        "layers_per_block": vae_params.num_res_blocks,
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
     }
 
     return config
 
 
 def create_unet_config(original_config, image_size: int, attention_type):
-    unet_params = original_config.model.params
-    vae_params = original_config.autoencoder.params.ddconfig
+    unet_params = original_config["model"]["params"]
+    vae_params = original_config["autoencoder"]["params"]["ddconfig"]
 
-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
 
     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
         down_block_types.append(block_type)
         if i != len(block_out_channels) - 1:
            resolution *= 2
 
     up_block_types = []
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
 
-    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
 
-    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
     use_linear_projection = (
-        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
     )
     if use_linear_projection:
        if head_dim is None:
            head_dim = [5, 10, 20, 20]
 
     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels": unet_params.in_channels,
+        "in_channels": unet_params["in_channels"],
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "layers_per_block": unet_params["num_res_blocks"],
+        "cross_attention_dim": unet_params["context_dim"],
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "attention_type": attention_type,
@@ -445,11 +444,6 @@ def convert_gligen_to_diffusers(
     num_in_channels: int = None,
     device: str = None,
 ):
-    if not is_omegaconf_available():
-        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
-
-    from omegaconf import OmegaConf
-
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -461,14 +455,14 @@
     else:
         print("global_step key not found in model")
 
-    original_config = OmegaConf.load(original_config_file)
+    original_config = yaml.safe_load(original_config_file)
 
     if num_in_channels is not None:
         original_config["model"]["params"]["in_channels"] = num_in_channels
 
-    num_train_timesteps = original_config.diffusion.params.timesteps
-    beta_start = original_config.diffusion.params.linear_start
-    beta_end = original_config.diffusion.params.linear_end
+    num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
+    beta_start = original_config["diffusion"]["params"]["linear_start"]
+    beta_end = original_config["diffusion"]["params"]["linear_end"]
 
     scheduler = DDIMScheduler(
         beta_end=beta_end,