Skip to content

Commit c55a50a

Browse files
committed
update docs
1 parent 661a0b3 commit c55a50a

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

docs/source/en/api/pipelines/animatediff.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,64 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
822822
</tr>
823823
</table>
824824

825+
## Using FreeNoise
826+
827+
[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu.
828+
829+
FreeNoise is a sampling mechanism that allows the generation of longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper.
830+
831+
```python
832+
import torch
833+
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter
834+
from diffusers.utils import export_to_video, load_image
835+
836+
# Load pipeline
837+
dtype = torch.float16
838+
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype)
839+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
840+
841+
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype)
842+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
843+
844+
pipe.load_lora_weights(
845+
"wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora"
846+
)
847+
pipe.set_adapters(["lcm_lora"], [0.8])
848+
849+
# Enable FreeNoise for long prompt generation
850+
pipe.enable_free_noise(context_length=16, context_stride=4)
851+
pipe.to("cuda")
852+
853+
# Optionally, enable memory efficient inference
854+
pipe.enable_free_noise_split_inference()
855+
pipe.unet.enable_forward_chunking(16)
856+
857+
# Can be a single prompt, or a dictionary with frame timesteps
858+
prompt = {
859+
0: "A caterpillar on a leaf, high quality, photorealistic",
860+
40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic",
861+
80: "A cocoon on a leaf, flowers in the backgrond, photorealistic",
862+
120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic",
863+
160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic",
864+
200: "A beautiful butterfly, flying away in a forest, photorealistic",
865+
240: "A cyberpunk butterfly, neon lights, glowing",
866+
}
867+
negative_prompt = "bad quality, worst quality, jpeg artifacts"
868+
869+
# Run inference
870+
output = pipe(
871+
prompt=prompt,
872+
negative_prompt=negative_prompt,
873+
num_frames=256,
874+
guidance_scale=2.5,
875+
num_inference_steps=10,
876+
generator=torch.Generator("cpu").manual_seed(0),
877+
)
878+
879+
# Save video
880+
frames = output.frames[0]
881+
export_to_video(frames, "output.mp4", fps=16)
882+
```
825883

826884
## Using `from_single_file` with the MotionAdapter
827885

src/diffusers/pipelines/free_noise_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,21 @@ def _enable_split_inference_samplers_(
549549
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
550550

551551
def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
552+
r"""
553+
Enable FreeNoise memory optimizations by utilizing
554+
[`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
555+
556+
Args:
557+
spatial_split_size (`int`, defaults to `256`):
558+
The split size across spatial dimensions for internal blocks. This is used in facilitating split
559+
inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
560+
modeling blocks.
561+
temporal_split_size (`int`, defaults to `16`):
562+
The split size across temporal dimensions for internal blocks. This is used in facilitating split
563+
inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
564+
attention, resnets, downsampling and upsampling blocks.
565+
"""
566+
# TODO(aryan): Discuss on what's the best way to provide more control to users
552567
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
553568
for block in blocks:
554569
if getattr(block, "motion_modules", None) is not None:

0 commit comments

Comments
 (0)