
Commit 31e8c58

[wip] kakaobrain unCLIP convert script
1 parent 9428ea4 commit 31e8c58

1 file changed
Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
import argparse
import tempfile

import torch

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import UnCLIPPipeline, UNet2DConditionModel


# decoder model


def decoder_model_from_original_config():
    # We are hardcoding the model configuration for now. If we need to generalize to more model configurations,
    # we can update it then.
    model = UNet2DConditionModel(
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
        layers_per_block=3,
        resnet_time_scale_shift="scale_shift",
        block_out_channels=(320, 640, 960, 1280),
        downsample_resnet=True,
        up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
        upsample_resnet=True,
        up_block_layers_per_block=3,
        in_channels=3,
        out_channels=6,
    )

    return model


# done decoder model
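
# A quick sanity-check sketch, not part of the original script: instantiate the hardcoded config on
# the meta device and inspect its block layout before wiring up the rest of the checkpoint mapping.
# It only reuses names that are already imported or defined in this file.
#
# with init_empty_weights():
#     _decoder = decoder_model_from_original_config()
#     print(sum(p.numel() for p in _decoder.parameters()))
#     print([type(block).__name__ for block in _decoder.down_blocks])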

# decoder checkpoint

DECODER_ORIGINAL_PREFIX = "model"
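
# A small helper sketch, not part of the original script: count the distinct `input_blocks` indices
# in the original state dict so the hardcoded block layout above can be compared against the
# checkpoint by hand. It assumes only the key format already used below,
# i.e. f"{DECODER_ORIGINAL_PREFIX}.input_blocks.<idx>.<submodule>.<param>".
#
# def count_original_input_blocks(checkpoint):
#     indices = {
#         int(key.split(".")[2])
#         for key in checkpoint
#         if key.startswith(f"{DECODER_ORIGINAL_PREFIX}.input_blocks.")
#     }
#     return len(indices)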


def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # TODO: the original checkpoint keys referenced below do not have a diffusers mapping yet; the
    # bare indexing only asserts that they exist in the checkpoint.

    # Linear
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"]
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"]
    # Norm
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"]
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"]

    # Linear
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"]
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"]

    # Linear
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"]
    checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"]

    # TODO There's also a `clip_emb_mult` that's a scalar and not a model parameter
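
    # A hedged sketch of the shape these missing mappings could take once the diffusers model exposes
    # corresponding modules. The diffusers-side names used here ("text_seq_proj", ...) are
    # placeholders, not confirmed attributes of UNet2DConditionModel.
    #
    # diffusers_checkpoint.update(
    #     {
    #         "text_seq_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
    #         "text_seq_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
    #     }
    # )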

    # input_blocks.0 -> conv_in

    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.input_blocks.0.0.weight"],
            "conv_in.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.input_blocks.0.0.bias"],
        }
    )

    # DownBlock2D
    # input_blocks.[1, 2, 3, 4] -> down_blocks.0

    # CrossAttnDownBlock2D
    # input_blocks.[5, 6, 7, 8] -> down_blocks.1

    # CrossAttnDownBlock2D
    # input_blocks.[9, 10, 11, 12] -> down_blocks.2

    # CrossAttnDownBlock2D
    # input_blocks.[13, 14, 15] -> down_blocks.3
    # TODO bug here: this block only covers three original entries (it appears to have no
    # downsampler), but the loop below assumes the same number of entries for every down block.

    resnets_per_down_block = len(model.down_blocks[0].resnets)

    # plus one entry for the downsampler at the end of each block
    resnets_per_down_block += 1

    for down_block_idx in range(len(model.down_blocks)):
        original_resnet_idx = 1 + resnets_per_down_block * down_block_idx

        diffusers_checkpoint.update(
            decoder_downblock_to_diffusers_checkpoint(
                model, checkpoint, diffusers_down_block_idx=down_block_idx, original_resnet_idx=original_resnet_idx
            )
        )
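
    # A hedged sketch of one way the TODO above might be resolved, not part of the original script:
    # accumulate the original index offset per block instead of assuming a fixed stride, so the last
    # block stays aligned. This assumes the diffusers down blocks expose a `downsamplers` attribute
    # that is None when a block has no downsampler; `decoder_downblock_to_diffusers_checkpoint`
    # would need the same per-block check.
    #
    # original_resnet_idx = 1
    # for down_block_idx, down_block in enumerate(model.down_blocks):
    #     diffusers_checkpoint.update(
    #         decoder_downblock_to_diffusers_checkpoint(
    #             model, checkpoint, diffusers_down_block_idx=down_block_idx, original_resnet_idx=original_resnet_idx
    #         )
    #     )
    #     original_resnet_idx += len(down_block.resnets) + (1 if down_block.downsamplers else 0)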

    # middle_block -> mid_block

    # output_blocks -> up_blocks

    return diffusers_checkpoint


# TODO add transformers
def decoder_downblock_to_diffusers_checkpoint(model, checkpoint, *, diffusers_down_block_idx, original_resnet_idx):
    diffusers_checkpoint = {}

    diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
    resnet_prefix = f"{DECODER_ORIGINAL_PREFIX}.input_blocks"

    num_resnets = len(model.down_blocks[diffusers_down_block_idx].resnets)

    # The last downsample block is also a resnet
    num_resnets = num_resnets + 1

    for resnet_idx_inc in range(num_resnets):
        full_resnet_prefix = f"{resnet_prefix}.{original_resnet_idx + resnet_idx_inc}.0"

        if resnet_idx_inc == num_resnets - 1:
            # this is a downsample block
            full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
        else:
            # this is a regular resnet block
            full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"

        diffusers_checkpoint.update(
            decoder_resnet_to_diffusers_checkpoint(
                checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
            )
        )

    return diffusers_checkpoint
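
# A hedged note on the `# TODO add transformers` above, not from the original script: the resnet
# mapping takes sub-module `.0` of each original `input_blocks` entry. If the original layout pairs
# each resnet with its attention/transformer block at sub-module `.1` (an assumption, not verified
# against the checkpoint), a transformer mapping could reuse the same loop with something like:
#
# full_attention_prefix = f"{resnet_prefix}.{original_resnet_idx + resnet_idx_inc}.1"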


def decoder_resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    diffusers_checkpoint = {
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
        f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
        f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
    }

    return diffusers_checkpoint
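
# A hedged note, not from the original script: in checkpoints that use this in_layers/out_layers
# layout, resnets whose channel count changes usually also carry a `skip_connection` conv. If the
# Karlo decoder follows that convention, those weights would map to the diffusers `conv_shortcut`:
#
# f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip_connection.weight"],
# f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.skip_connection.bias"],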

# done decoder checkpoint

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--decoder_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder checkpoint to convert.",
    )

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    args = parser.parse_args()

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    # decoder_model

    print("loading decoder")

    decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location)
    decoder_checkpoint = decoder_checkpoint["state_dict"]

    # instantiate the diffusers model without allocating weight memory
    with init_empty_weights():
        decoder_model = decoder_model_from_original_config()

    decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint(
        decoder_model, decoder_checkpoint
    )

    # round-trip through a temporary file so the converted checkpoint can eventually be loaded from
    # disk via `load_checkpoint_and_dispatch` (see the TODO below)
    with tempfile.NamedTemporaryFile() as decoder_diffusers_checkpoint_file:
        torch.save(decoder_diffusers_checkpoint, decoder_diffusers_checkpoint_file.name)
        del decoder_diffusers_checkpoint
        del decoder_checkpoint
        # TODO use load_checkpoint_and_dispatch
        # load_checkpoint_and_dispatch(decoder_model, decoder_diffusers_checkpoint_file.name, device_map="auto")
        # strict=False because the converted checkpoint only covers part of the model so far
        decoder_model.load_state_dict(
            torch.load(decoder_diffusers_checkpoint_file.name, map_location=checkpoint_map_location), strict=False
        )

    print("done loading decoder")

    # done decoder_model

    print(f"saving Kakao Brain unCLIP to {args.dump_path}")

    pipe = UnCLIPPipeline(decoder=decoder_model)
    pipe.save_pretrained(args.dump_path)

    print("done writing Kakao Brain unCLIP")

0 commit comments
