Skip to content

Commit 2e83cbb

Browse files
a-r-r-o-wyiyixuxuhlky
authored
LTX 0.9.5 (#10968)
* update --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: hlky <[email protected]>
1 parent 33d10af commit 2e83cbb

File tree

13 files changed

+1865
-54
lines changed

13 files changed

+1865
-54
lines changed

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
196196
- all
197197
- __call__
198198

199+
## LTXConditionPipeline
200+
201+
[[autodoc]] LTXConditionPipeline
202+
- all
203+
- __call__
204+
199205
## LTXPipelineOutput
200206

201207
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

scripts/convert_ltx_to_diffusers.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
7474
"last_scale_shift_table": "scale_shift_table",
7575
}
7676

77+
VAE_095_RENAME_DICT = {
78+
# decoder
79+
"up_blocks.0": "mid_block",
80+
"up_blocks.1": "up_blocks.0.upsamplers.0",
81+
"up_blocks.2": "up_blocks.0",
82+
"up_blocks.3": "up_blocks.1.upsamplers.0",
83+
"up_blocks.4": "up_blocks.1",
84+
"up_blocks.5": "up_blocks.2.upsamplers.0",
85+
"up_blocks.6": "up_blocks.2",
86+
"up_blocks.7": "up_blocks.3.upsamplers.0",
87+
"up_blocks.8": "up_blocks.3",
88+
# encoder
89+
"down_blocks.0": "down_blocks.0",
90+
"down_blocks.1": "down_blocks.0.downsamplers.0",
91+
"down_blocks.2": "down_blocks.1",
92+
"down_blocks.3": "down_blocks.1.downsamplers.0",
93+
"down_blocks.4": "down_blocks.2",
94+
"down_blocks.5": "down_blocks.2.downsamplers.0",
95+
"down_blocks.6": "down_blocks.3",
96+
"down_blocks.7": "down_blocks.3.downsamplers.0",
97+
"down_blocks.8": "mid_block",
98+
# common
99+
"last_time_embedder": "time_embedder",
100+
"last_scale_shift_table": "scale_shift_table",
101+
}
102+
77103
VAE_SPECIAL_KEYS_REMAP = {
78104
"per_channel_statistics.channel": remove_keys_,
79105
"per_channel_statistics.mean-of-means": remove_keys_,
80106
"per_channel_statistics.mean-of-stds": remove_keys_,
81107
"model.diffusion_model": remove_keys_,
82108
}
83109

84-
VAE_091_SPECIAL_KEYS_REMAP = {
85-
"timestep_scale_multiplier": remove_keys_,
86-
}
87-
88110

89111
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
90112
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
104126
def convert_transformer(
105127
ckpt_path: str,
106128
dtype: torch.dtype,
129+
version: str = "0.9.0",
107130
):
108131
PREFIX_KEY = "model.diffusion_model."
109132

110133
original_state_dict = get_state_dict(load_file(ckpt_path))
134+
config = {}
135+
if version == "0.9.5":
136+
config["_use_causal_rope_fix"] = True
111137
with init_empty_weights():
112-
transformer = LTXVideoTransformer3DModel()
138+
transformer = LTXVideoTransformer3DModel(**config)
113139

114140
for key in list(original_state_dict.keys()):
115141
new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161187
"out_channels": 3,
162188
"latent_channels": 128,
163189
"block_out_channels": (128, 256, 512, 512),
190+
"down_block_types": (
191+
"LTXVideoDownBlock3D",
192+
"LTXVideoDownBlock3D",
193+
"LTXVideoDownBlock3D",
194+
"LTXVideoDownBlock3D",
195+
),
164196
"decoder_block_out_channels": (128, 256, 512, 512),
165197
"layers_per_block": (4, 3, 3, 3, 4),
166198
"decoder_layers_per_block": (4, 3, 3, 3, 4),
167199
"spatio_temporal_scaling": (True, True, True, False),
168200
"decoder_spatio_temporal_scaling": (True, True, True, False),
169201
"decoder_inject_noise": (False, False, False, False, False),
202+
"downsample_type": ("conv", "conv", "conv", "conv"),
170203
"upsample_residual": (False, False, False, False),
171204
"upsample_factor": (1, 1, 1, 1),
172205
"patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183216
"out_channels": 3,
184217
"latent_channels": 128,
185218
"block_out_channels": (128, 256, 512, 512),
219+
"down_block_types": (
220+
"LTXVideoDownBlock3D",
221+
"LTXVideoDownBlock3D",
222+
"LTXVideoDownBlock3D",
223+
"LTXVideoDownBlock3D",
224+
),
186225
"decoder_block_out_channels": (256, 512, 1024),
187226
"layers_per_block": (4, 3, 3, 3, 4),
188227
"decoder_layers_per_block": (5, 6, 7, 8),
189228
"spatio_temporal_scaling": (True, True, True, False),
190229
"decoder_spatio_temporal_scaling": (True, True, True),
191230
"decoder_inject_noise": (True, True, True, False),
231+
"downsample_type": ("conv", "conv", "conv", "conv"),
192232
"upsample_residual": (True, True, True),
193233
"upsample_factor": (2, 2, 2),
194234
"timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
200240
"decoder_causal": False,
201241
}
202242
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
203-
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
243+
elif version == "0.9.5":
244+
config = {
245+
"in_channels": 3,
246+
"out_channels": 3,
247+
"latent_channels": 128,
248+
"block_out_channels": (128, 256, 512, 1024, 2048),
249+
"down_block_types": (
250+
"LTXVideo095DownBlock3D",
251+
"LTXVideo095DownBlock3D",
252+
"LTXVideo095DownBlock3D",
253+
"LTXVideo095DownBlock3D",
254+
),
255+
"decoder_block_out_channels": (256, 512, 1024),
256+
"layers_per_block": (4, 6, 6, 2, 2),
257+
"decoder_layers_per_block": (5, 5, 5, 5),
258+
"spatio_temporal_scaling": (True, True, True, True),
259+
"decoder_spatio_temporal_scaling": (True, True, True),
260+
"decoder_inject_noise": (False, False, False, False),
261+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
262+
"upsample_residual": (True, True, True),
263+
"upsample_factor": (2, 2, 2),
264+
"timestep_conditioning": True,
265+
"patch_size": 4,
266+
"patch_size_t": 1,
267+
"resnet_norm_eps": 1e-6,
268+
"scaling_factor": 1.0,
269+
"encoder_causal": True,
270+
"decoder_causal": False,
271+
"spatial_compression_ratio": 32,
272+
"temporal_compression_ratio": 8,
273+
}
274+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
204275
return config
205276

206277

@@ -223,7 +294,7 @@ def get_args():
223294
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
224295
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
225296
parser.add_argument(
226-
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
297+
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
227298
)
228299
return parser.parse_args()
229300

@@ -277,14 +348,17 @@ def get_args():
277348
for param in text_encoder.parameters():
278349
param.data = param.data.contiguous()
279350

280-
scheduler = FlowMatchEulerDiscreteScheduler(
281-
use_dynamic_shifting=True,
282-
base_shift=0.95,
283-
max_shift=2.05,
284-
base_image_seq_len=1024,
285-
max_image_seq_len=4096,
286-
shift_terminal=0.1,
287-
)
351+
if args.version == "0.9.5":
352+
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
353+
else:
354+
scheduler = FlowMatchEulerDiscreteScheduler(
355+
use_dynamic_shifting=True,
356+
base_shift=0.95,
357+
max_shift=2.05,
358+
base_image_seq_len=1024,
359+
max_image_seq_len=4096,
360+
shift_terminal=0.1,
361+
)
288362

289363
pipe = LTXPipeline(
290364
scheduler=scheduler,

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@
402402
"LDMTextToImagePipeline",
403403
"LEditsPPPipelineStableDiffusion",
404404
"LEditsPPPipelineStableDiffusionXL",
405+
"LTXConditionPipeline",
405406
"LTXImageToVideoPipeline",
406407
"LTXPipeline",
407408
"Lumina2Pipeline",
@@ -947,6 +948,7 @@
947948
LDMTextToImagePipeline,
948949
LEditsPPPipelineStableDiffusion,
949950
LEditsPPPipelineStableDiffusionXL,
951+
LTXConditionPipeline,
950952
LTXImageToVideoPipeline,
951953
LTXPipeline,
952954
Lumina2Pipeline,

0 commit comments

Comments
 (0)