
[WIP] test prepare_latents for ltx0.95 #10976

Merged 18 commits on Mar 14, 2025
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/ltx_video.md
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__

## LTXConditionPipeline

[[autodoc]] LTXConditionPipeline
- all
- __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
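Since this hunk only registers the new pipeline in the docs, a short usage sketch may help. Everything below is an assumption modeled on the existing LTX text-to-video example in this file (the `export_to_video(video, "ship.mp4", fps=24)` snippet above), not the final `LTXConditionPipeline` signature; the checkpoint id is hypothetical.

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Hypothetical checkpoint id; conditioning-specific arguments are omitted
# because this diff does not show the pipeline's call signature.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

video = pipe(
    prompt="A ship sailing through a storm at night",
    width=768,
    height=512,
    num_frames=161,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ship.mp4", fps=24)
```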
23 changes: 15 additions & 8 deletions scripts/convert_ltx_to_diffusers.py
@@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
"decoder.timestep_scale_multiplier": remove_keys_,
}


@@ -268,6 +269,9 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
"timestep_scale_multiplier": 1000.0,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
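The explicit ratios above override values that were previously derived from the architecture. A minimal sketch of the fallback rule this enables (see the matching `__init__` change in `autoencoder_kl_ltx.py` further down):

```python
from typing import Optional, Tuple

# Fallback rule: derive the ratio from the patch size and the number of
# active downsampling stages unless the config pins it explicitly.
def spatial_compression_ratio(
    patch_size: int,
    spatio_temporal_scaling: Tuple[bool, ...],
    explicit: Optional[int] = None,
) -> int:
    return patch_size * 2 ** sum(spatio_temporal_scaling) if explicit is None else explicit

# e.g. patch_size=4 with three active stages: 4 * 2**3 == 32, matching the
# explicit value the 0.9.5 config sets above.
assert spatial_compression_ratio(4, (True, True, True, False)) == 32
```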
@@ -346,14 +350,17 @@ def get_args():
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)

pipe = LTXPipeline(
scheduler=scheduler,
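A sketch of why the branch matters: with `use_dynamic_shifting=True` the pipeline has to compute a resolution-dependent shift `mu` and pass it to `set_timesteps`, whereas the 0.9.5 scheduler runs without dynamic shifting. The helper below mirrors the `calculate_shift` pattern used elsewhere in diffusers, with the numbers configured above; treat it as illustrative rather than the pipeline's exact code.

```python
# Linear interpolation of the timestep shift between the base and max
# sequence lengths; only relevant when use_dynamic_shifting=True.
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 1024,
    max_seq_len: int = 4096,
    base_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

mu = calculate_shift(2048)  # ~1.32 for a mid-sized latent sequence
```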
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -347,6 +347,7 @@
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Text2ImgPipeline",
@@ -857,6 +858,7 @@
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Text2ImgPipeline,
22 changes: 17 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -921,12 +921,14 @@ def __init__(
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
timestep_scale_multiplier: float = 1.0,
) -> None:
super().__init__()

self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.out_channels = out_channels * patch_size**2
self.timestep_scale_multiplier = timestep_scale_multiplier

block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
@@ -981,9 +983,7 @@ def __init__(
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -992,7 +992,7 @@ def __init__(
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)

if self.timestep_scale_multiplier is not None:
if temb is not None:
temb = temb * self.timestep_scale_multiplier

if torch.is_grad_enabled() and self.gradient_checkpointing:
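A tiny sketch of the decoder's timestep path after this change: the multiplier is now a plain config float (1.0 by default, 1000.0 in the converted 0.9.5 config) rather than a learned `nn.Parameter`, and it is applied whenever a timestep embedding is provided:

```python
import torch

# The multiplier is now a plain float from the config instead of a learned
# nn.Parameter; 1000.0 mirrors the converted 0.9.5 VAE config.
timestep_scale_multiplier = 1000.0
temb = torch.tensor([0.05])              # normalized decode timestep
temb = temb * timestep_scale_multiplier  # tensor([50.]) fed to time_embedder
```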
@@ -1105,6 +1105,9 @@ def __init__(
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
timestep_scale_multiplier: float = 1.0,
) -> None:
super().__init__()

@@ -1135,15 +1138,24 @@ def __init__(
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
timestep_scale_multiplier=timestep_scale_multiplier,
)

latents_mean = torch.zeros((latent_channels,), requires_grad=False)
latents_std = torch.ones((latent_channels,), requires_grad=False)
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)

# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
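The trailing comment refers to batch-sliced decoding. Assuming this VAE exposes the standard diffusers slicing toggle (not shown in this diff), a caller would opt in roughly like this; the checkpoint id and latent shape are assumptions:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Assumptions: the checkpoint id, the slicing toggle, and the latent shape
# (batch, channels, frames, height, width) are not confirmed by this diff.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video-0.9.5", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae.enable_slicing()  # decode one latent video at a time

latents = torch.randn(1, 128, 8, 16, 16, dtype=torch.bfloat16, device="cuda")
frames = vae.decode(latents).sample  # same result, lower peak memory
```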
99 changes: 66 additions & 33 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -115,47 +115,77 @@ def __init__(
self.theta = theta
self._causal_rope_fix = _causal_rope_fix

def forward(
def _prepare_video_coords(
self,
hidden_states: torch.Tensor,
batch_size: int,
num_frames: int,
height: int,
width: int,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)

rope_interpolation_scale: Tuple[torch.Tensor, float, float],
frame_rate: float,
device: torch.device,
) -> torch.Tensor:
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

if rope_interpolation_scale is not None:
if isinstance(rope_interpolation_scale, tuple):
# This will be deprecated in v0.34.0
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
if isinstance(rope_interpolation_scale, tuple):
# This will be deprecated in v0.34.0
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
else:
if not self._causal_rope_fix:
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
else:
if not self._causal_rope_fix:
grid[:, 0:1] = (
grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
)
else:
grid[:, 0:1] = (
((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
* self.patch_size_t
/ self.base_num_frames
)
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
grid[:, 0:1] = (
((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
* self.patch_size_t
/ self.base_num_frames
)
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

grid = grid.flatten(2, 4).transpose(1, 2)

return grid

def forward(
self,
hidden_states: torch.Tensor,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
video_coords: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)

if video_coords is None:
grid = self._prepare_video_coords(
batch_size,
num_frames,
height,
width,
rope_interpolation_scale=rope_interpolation_scale,
frame_rate=frame_rate,
device=hidden_states.device,
)
else:
grid = torch.stack(
[
video_coords[:, 0] / self.base_num_frames,
video_coords[:, 1] / self.base_height,
video_coords[:, 2] / self.base_width,
],
dim=-1,
)

start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
@@ -387,11 +417,12 @@ def forward(
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
frame_rate: int,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
video_coords: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -414,7 +445,9 @@ def forward(
msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
deprecate("rope_interpolation_scale", "0.34.0", msg)

image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)
image_rotary_emb = self.rope(
hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
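A hedged sketch of the new `video_coords` path: instead of describing tokens with a uniform `(num_frames, height, width)` grid, callers can pass explicit per-token coordinates, which the RoPE module normalizes by its base sizes. The `(batch, 3, num_tokens)` layout and the base-size defaults below are inferred from the stacking code above, not confirmed elsewhere in the diff:

```python
import torch

# Explicit per-token coordinates: dim 1 holds (frame, row, col) values.
batch_size, num_tokens = 1, 6
video_coords = torch.zeros(batch_size, 3, num_tokens)
video_coords[:, 0] = 7.0                       # anchor every token at latent frame 7
video_coords[:, 1] = torch.arange(num_tokens)  # row coordinate per token
video_coords[:, 2] = torch.arange(num_tokens)  # column coordinate per token

# Inside `forward` this becomes a normalized (batch, num_tokens, 3) grid.
base_num_frames, base_height, base_width = 20, 2048, 2048  # assumed defaults
grid = torch.stack(
    [
        video_coords[:, 0] / base_num_frames,
        video_coords[:, 1] / base_height,
        video_coords[:, 2] / base_width,
    ],
    dim=-1,
)
print(grid.shape)  # torch.Size([1, 6, 3])
```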
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -260,7 +260,7 @@
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -610,7 +610,7 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
from .lumina2 import Lumina2Text2ImgPipeline
from .marigold import (
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/ltx/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline

else: