Commit 2858d7e

[From ckpt] Fix from_ckpt (#3466)
* Correct from_ckpt
* make style
1 parent 88295f9 commit 2858d7e

File tree

2 files changed: +14 -10 lines changed

src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion
@@ -1326,7 +1326,7 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
 
-        if from_safetensors and use_safetensors is True:
+        if from_safetensors and use_safetensors is False:
             raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
 
         # TODO: For now we only support stable diffusion
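The one-word change inverts a backwards check: previously the error fired when a `.safetensors` file was loaded with `use_safetensors=True` (the matching case); after the fix it fires only when the caller explicitly passes `use_safetensors=False` for a `.safetensors` checkpoint. A minimal, standalone sketch of the corrected guard; the helper name and example paths are hypothetical, only the condition mirrors the diff above.

# Sketch of the corrected guard in isolation; check_safetensors_flag and the
# example paths are made up for illustration.
def check_safetensors_flag(pretrained_model_link_or_path, use_safetensors=None):
    file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
    from_safetensors = file_extension == "safetensors"

    # Error only when the caller explicitly disables safetensors for a
    # .safetensors file; use_safetensors=None (this sketch's default) passes through.
    if from_safetensors and use_safetensors is False:
        raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
    return from_safetensors


check_safetensors_flag("v1-5-pruned.safetensors", use_safetensors=True)  # fine after the fix
check_safetensors_flag("v1-5-pruned.ckpt")                               # fine, not a safetensors file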

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 13 additions & 9 deletions
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
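The renaming above tracks diffusers' attention refactor: converted VAE attention weights now target the `to_q`/`to_k`/`to_v`/`to_out.0` parameter names of the `Attention` module instead of the older `query`/`key`/`value`/`proj_attn` names. A minimal, standalone sketch of the resulting key mapping; the helper name and sample keys are hypothetical, the replace pairs mirror the hunk above.

# Hypothetical helper that applies the same replace pairs as the hunk above.
def rename_vae_attention_key(key: str) -> str:
    replacements = [
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    ]
    for old, new in replacements:
        key = key.replace(old, new)
    return key


print(rename_vae_attention_key("mid.attn_1.q.weight"))       # mid.attn_1.to_q.weight
print(rename_vae_attention_key("mid.attn_1.proj_out.bias"))  # mid.attn_1.to_out.0.bias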

@@ -204,8 +204,12 @@ def assign_to_checkpoint(
                 new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
             checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
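With the rename, `assign_to_checkpoint` must also squeeze the newly targeted attention weights: the original VAE checkpoint stores its attention projections as 1x1 conv kernels, while the `to_q`/`to_k`/`to_v`/`to_out.0` destinations are linear layers, and the pre-existing `proj_attn` branch handled only the 3-D case. A small, hypothetical illustration of the two squeezes (the shapes are examples, not taken from the repo):

import torch

# Hypothetical shapes: a 3-D "conv 1D"-style weight and a 4-D 1x1 Conv2d weight,
# both of which must end up as a 2-D linear weight of shape (512, 512).
conv1d_style = torch.randn(512, 512, 1)
conv2d_style = torch.randn(512, 512, 1, 1)

linear_from_3d = conv1d_style[:, :, 0]        # mirrors the len(shape) == 3 branch
linear_from_4d = conv2d_style[:, :, 0, 0]     # mirrors the new len(shape) == 4 branch

assert linear_from_3d.shape == linear_from_4d.shape == (512, 512)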
