@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
140
140
new_item = new_item .replace ("norm.weight" , "group_norm.weight" )
141
141
new_item = new_item .replace ("norm.bias" , "group_norm.bias" )
142
142
143
- new_item = new_item .replace ("q.weight" , "query .weight" )
144
- new_item = new_item .replace ("q.bias" , "query .bias" )
143
+ new_item = new_item .replace ("q.weight" , "to_q .weight" )
144
+ new_item = new_item .replace ("q.bias" , "to_q .bias" )
145
145
146
- new_item = new_item .replace ("k.weight" , "key .weight" )
147
- new_item = new_item .replace ("k.bias" , "key .bias" )
146
+ new_item = new_item .replace ("k.weight" , "to_k .weight" )
147
+ new_item = new_item .replace ("k.bias" , "to_k .bias" )
148
148
149
- new_item = new_item .replace ("v.weight" , "value .weight" )
150
- new_item = new_item .replace ("v.bias" , "value .bias" )
149
+ new_item = new_item .replace ("v.weight" , "to_v .weight" )
150
+ new_item = new_item .replace ("v.bias" , "to_v .bias" )
151
151
152
- new_item = new_item .replace ("proj_out.weight" , "proj_attn .weight" )
153
- new_item = new_item .replace ("proj_out.bias" , "proj_attn .bias" )
152
+ new_item = new_item .replace ("proj_out.weight" , "to_out.0 .weight" )
153
+ new_item = new_item .replace ("proj_out.bias" , "to_out.0 .bias" )
154
154
155
155
new_item = shave_segments (new_item , n_shave_prefix_segments = n_shave_prefix_segments )
156
156
@@ -204,8 +204,12 @@ def assign_to_checkpoint(
204
204
new_path = new_path .replace (replacement ["old" ], replacement ["new" ])
205
205
206
206
# proj_attn.weight has to be converted from conv 1D to linear
207
- if "proj_attn.weight" in new_path :
207
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path )
208
+ shape = old_checkpoint [path ["old" ]].shape
209
+ if is_attn_weight and len (shape ) == 3 :
208
210
checkpoint [new_path ] = old_checkpoint [path ["old" ]][:, :, 0 ]
211
+ elif is_attn_weight and len (shape ) == 4 :
212
+ checkpoint [new_path ] = old_checkpoint [path ["old" ]][:, :, 0 , 0 ]
209
213
else :
210
214
checkpoint [new_path ] = old_checkpoint [path ["old" ]]
211
215
0 commit comments