@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
74
74
"last_scale_shift_table" : "scale_shift_table" ,
75
75
}
76
76
77
+ VAE_095_RENAME_DICT = {
78
+ # decoder
79
+ "up_blocks.0" : "mid_block" ,
80
+ "up_blocks.1" : "up_blocks.0.upsamplers.0" ,
81
+ "up_blocks.2" : "up_blocks.0" ,
82
+ "up_blocks.3" : "up_blocks.1.upsamplers.0" ,
83
+ "up_blocks.4" : "up_blocks.1" ,
84
+ "up_blocks.5" : "up_blocks.2.upsamplers.0" ,
85
+ "up_blocks.6" : "up_blocks.2" ,
86
+ "up_blocks.7" : "up_blocks.3.upsamplers.0" ,
87
+ "up_blocks.8" : "up_blocks.3" ,
88
+ # encoder
89
+ "down_blocks.0" : "down_blocks.0" ,
90
+ "down_blocks.1" : "down_blocks.0.downsamplers.0" ,
91
+ "down_blocks.2" : "down_blocks.1" ,
92
+ "down_blocks.3" : "down_blocks.1.downsamplers.0" ,
93
+ "down_blocks.4" : "down_blocks.2" ,
94
+ "down_blocks.5" : "down_blocks.2.downsamplers.0" ,
95
+ "down_blocks.6" : "down_blocks.3" ,
96
+ "down_blocks.7" : "down_blocks.3.downsamplers.0" ,
97
+ "down_blocks.8" : "mid_block" ,
98
+ # common
99
+ "last_time_embedder" : "time_embedder" ,
100
+ "last_scale_shift_table" : "scale_shift_table" ,
101
+ }
102
+
77
103
VAE_SPECIAL_KEYS_REMAP = {
78
104
"per_channel_statistics.channel" : remove_keys_ ,
79
105
"per_channel_statistics.mean-of-means" : remove_keys_ ,
80
106
"per_channel_statistics.mean-of-stds" : remove_keys_ ,
81
107
"model.diffusion_model" : remove_keys_ ,
82
108
}
83
109
84
- VAE_091_SPECIAL_KEYS_REMAP = {
85
- "timestep_scale_multiplier" : remove_keys_ ,
86
- }
87
-
88
110
89
111
def get_state_dict (saved_dict : Dict [str , Any ]) -> Dict [str , Any ]:
90
112
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
104
126
def convert_transformer (
105
127
ckpt_path : str ,
106
128
dtype : torch .dtype ,
129
+ version : str = "0.9.0" ,
107
130
):
108
131
PREFIX_KEY = "model.diffusion_model."
109
132
110
133
original_state_dict = get_state_dict (load_file (ckpt_path ))
134
+ config = {}
135
+ if version == "0.9.5" :
136
+ config ["_use_causal_rope_fix" ] = True
111
137
with init_empty_weights ():
112
- transformer = LTXVideoTransformer3DModel ()
138
+ transformer = LTXVideoTransformer3DModel (** config )
113
139
114
140
for key in list (original_state_dict .keys ()):
115
141
new_key = key [:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161
187
"out_channels" : 3 ,
162
188
"latent_channels" : 128 ,
163
189
"block_out_channels" : (128 , 256 , 512 , 512 ),
190
+ "down_block_types" : (
191
+ "LTXVideoDownBlock3D" ,
192
+ "LTXVideoDownBlock3D" ,
193
+ "LTXVideoDownBlock3D" ,
194
+ "LTXVideoDownBlock3D" ,
195
+ ),
164
196
"decoder_block_out_channels" : (128 , 256 , 512 , 512 ),
165
197
"layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
166
198
"decoder_layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
167
199
"spatio_temporal_scaling" : (True , True , True , False ),
168
200
"decoder_spatio_temporal_scaling" : (True , True , True , False ),
169
201
"decoder_inject_noise" : (False , False , False , False , False ),
202
+ "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
170
203
"upsample_residual" : (False , False , False , False ),
171
204
"upsample_factor" : (1 , 1 , 1 , 1 ),
172
205
"patch_size" : 4 ,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183
216
"out_channels" : 3 ,
184
217
"latent_channels" : 128 ,
185
218
"block_out_channels" : (128 , 256 , 512 , 512 ),
219
+ "down_block_types" : (
220
+ "LTXVideoDownBlock3D" ,
221
+ "LTXVideoDownBlock3D" ,
222
+ "LTXVideoDownBlock3D" ,
223
+ "LTXVideoDownBlock3D" ,
224
+ ),
186
225
"decoder_block_out_channels" : (256 , 512 , 1024 ),
187
226
"layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
188
227
"decoder_layers_per_block" : (5 , 6 , 7 , 8 ),
189
228
"spatio_temporal_scaling" : (True , True , True , False ),
190
229
"decoder_spatio_temporal_scaling" : (True , True , True ),
191
230
"decoder_inject_noise" : (True , True , True , False ),
231
+ "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
192
232
"upsample_residual" : (True , True , True ),
193
233
"upsample_factor" : (2 , 2 , 2 ),
194
234
"timestep_conditioning" : True ,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
200
240
"decoder_causal" : False ,
201
241
}
202
242
VAE_KEYS_RENAME_DICT .update (VAE_091_RENAME_DICT )
203
- VAE_SPECIAL_KEYS_REMAP .update (VAE_091_SPECIAL_KEYS_REMAP )
243
+ elif version == "0.9.5" :
244
+ config = {
245
+ "in_channels" : 3 ,
246
+ "out_channels" : 3 ,
247
+ "latent_channels" : 128 ,
248
+ "block_out_channels" : (128 , 256 , 512 , 1024 , 2048 ),
249
+ "down_block_types" : (
250
+ "LTXVideo095DownBlock3D" ,
251
+ "LTXVideo095DownBlock3D" ,
252
+ "LTXVideo095DownBlock3D" ,
253
+ "LTXVideo095DownBlock3D" ,
254
+ ),
255
+ "decoder_block_out_channels" : (256 , 512 , 1024 ),
256
+ "layers_per_block" : (4 , 6 , 6 , 2 , 2 ),
257
+ "decoder_layers_per_block" : (5 , 5 , 5 , 5 ),
258
+ "spatio_temporal_scaling" : (True , True , True , True ),
259
+ "decoder_spatio_temporal_scaling" : (True , True , True ),
260
+ "decoder_inject_noise" : (False , False , False , False ),
261
+ "downsample_type" : ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
262
+ "upsample_residual" : (True , True , True ),
263
+ "upsample_factor" : (2 , 2 , 2 ),
264
+ "timestep_conditioning" : True ,
265
+ "patch_size" : 4 ,
266
+ "patch_size_t" : 1 ,
267
+ "resnet_norm_eps" : 1e-6 ,
268
+ "scaling_factor" : 1.0 ,
269
+ "encoder_causal" : True ,
270
+ "decoder_causal" : False ,
271
+ "spatial_compression_ratio" : 32 ,
272
+ "temporal_compression_ratio" : 8 ,
273
+ }
274
+ VAE_KEYS_RENAME_DICT .update (VAE_095_RENAME_DICT )
204
275
return config
205
276
206
277
@@ -223,7 +294,7 @@ def get_args():
223
294
parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
224
295
parser .add_argument ("--dtype" , default = "fp32" , help = "Torch dtype to save the model in." )
225
296
parser .add_argument (
226
- "--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ], help = "Version of the LTX model"
297
+ "--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" , "0.9.5" ], help = "Version of the LTX model"
227
298
)
228
299
return parser .parse_args ()
229
300
@@ -277,14 +348,17 @@ def get_args():
277
348
for param in text_encoder .parameters ():
278
349
param .data = param .data .contiguous ()
279
350
280
- scheduler = FlowMatchEulerDiscreteScheduler (
281
- use_dynamic_shifting = True ,
282
- base_shift = 0.95 ,
283
- max_shift = 2.05 ,
284
- base_image_seq_len = 1024 ,
285
- max_image_seq_len = 4096 ,
286
- shift_terminal = 0.1 ,
287
- )
351
+ if args .version == "0.9.5" :
352
+ scheduler = FlowMatchEulerDiscreteScheduler (use_dynamic_shifting = False )
353
+ else :
354
+ scheduler = FlowMatchEulerDiscreteScheduler (
355
+ use_dynamic_shifting = True ,
356
+ base_shift = 0.95 ,
357
+ max_shift = 2.05 ,
358
+ base_image_seq_len = 1024 ,
359
+ max_image_seq_len = 4096 ,
360
+ shift_terminal = 0.1 ,
361
+ )
288
362
289
363
pipe = LTXPipeline (
290
364
scheduler = scheduler ,
0 commit comments