Commit 54b21c0

Merge branch 'main' into animatediff-img2video

2 parents 1c645ed + a9288b4

File tree

11 files changed: +1600, -404 lines


.github/workflows/pr_test_peft_backend.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -59,7 +59,7 @@ jobs:
 
       - name: Run fast PyTorch LoRA CPU tests with PEFT backend
         run: |
-          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
             -s -v \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/lora/test_lora_layers_peft.py
```

docs/source/en/api/models/autoencoderkl.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url)
 ## AutoencoderKL
 
 [[autodoc]] AutoencoderKL
+    - decode
+    - encode
+    - all
 
 ## AutoencoderKLOutput
 
```
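
The three entries added above expose the `encode` and `decode` methods (plus the full class) in the rendered API reference. For orientation, a minimal sketch of how those two methods are typically called; the checkpoint name is just the standard Stable Diffusion 1.5 VAE and is not part of this commit:

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# encode expects pixel values scaled to [-1, 1] and returns a latent distribution
image = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
    # decode maps latents back to pixel space; .sample holds the image tensor
    reconstruction = vae.decode(latents).sample
```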

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 14 additions & 9 deletions

```diff
@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
```
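
Why these two one-line fixes matter: `param = param.to(dtype=torch.float32)` only rebinds the loop variable to an upcast copy, so the text encoder keeps its fp16 weights while the optimizer trains a tensor detached from the model. Assigning to `param.data` upcasts the parameter's storage in place. A standalone sketch (not from the training script) illustrating the difference:

```py
import torch

linear = torch.nn.Linear(2, 2).to(dtype=torch.float16)
for name, param in linear.named_parameters():
    # rebinding the loop variable leaves the module's weights untouched
    param = param.to(dtype=torch.float32)
print(linear.weight.dtype)  # torch.float16 -- the cast went nowhere

linear = torch.nn.Linear(2, 2).to(dtype=torch.float16)
for name, param in linear.named_parameters():
    # assigning to .data upcasts the parameter's storage in place
    param.data = param.to(dtype=torch.float32)
print(linear.weight.dtype)  # torch.float32
```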

```diff
@@ -1725,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
-
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True
 
             else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
@@ -1747,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
@@ -1885,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
```
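
Two things happen in the hunks above. First, the `pivoted` flag defers zeroing the text encoder learning rates from the pivot epoch boundary to the top of every subsequent training step. Second, the final hunk drops a redundant loop: `retract_embeddings()` never used the loop variable, so a single call suffices. A minimal sketch of the pivot pattern with dummy parameter groups; only the group ordering (indices 1 and 2 holding the text encoder params) comes from the script, the rest is hypothetical scaffolding:

```py
import torch

# hypothetical stand-ins for the unet and two text encoder parameter sets
unet_params = [torch.nn.Parameter(torch.randn(4, 4))]
te1_params = [torch.nn.Parameter(torch.randn(4, 4))]
te2_params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = torch.optim.AdamW(
    [
        {"params": unet_params, "lr": 1e-4},
        {"params": te1_params, "lr": 5e-5},
        {"params": te2_params, "lr": 5e-5},
    ]
)

pivoted = False
num_train_epochs, pivot_epoch = 4, 2
for epoch in range(num_train_epochs):
    if epoch == pivot_epoch:
        # epoch-level decision: from here on, only the unet should train
        pivoted = True
    for step in range(10):  # stand-in for iterating the dataloader
        if pivoted:
            # zeroing the learning rates freezes the text encoders while
            # keeping the optimizer object and its state intact
            optimizer.param_groups[1]["lr"] = 0.0
            optimizer.param_groups[2]["lr"] = 0.0
        # forward / backward / optimizer.step() would go here
```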

examples/community/README.md

Lines changed: 28 additions & 1 deletion

```diff
@@ -58,6 +58,8 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
 | Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
 | Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
 | StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+
 | IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```

```diff
@@ -2990,7 +2992,7 @@ pipe = DiffusionPipeline.from_pretrained(
     custom_pipeline="pipeline_animatediff_controlnet",
 ).to(device="cuda", dtype=torch.float16)
 pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+    model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
 )
 pipe.enable_vae_slicing()
```

````diff
@@ -3409,7 +3411,32 @@ images = pipe(
 pipe.disable_style_aligned()
 ```
 
+### AnimateDiff Image-To-Video Pipeline
+
+This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
+
+```py
+import torch
+from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif, load_image
+
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
+pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace")
+
+image = load_image("snail.png")
+output = pipe(
+    image=image,
+    prompt="A snail moving on the ground",
+    strength=0.8,
+    latent_interpolation_method="slerp",  # can be lerp, slerp, or your own callback
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
 ### IP Adapter Face ID
+
 IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
 You need to install `insightface` and all its requirements to use this model.
 You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
````
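
The example above selects `latent_interpolation_method="slerp"`. For reference, a sketch of spherical linear interpolation between two latent tensors, which is the math behind that option; the exact callback signature the community pipeline expects for the "your own callback" case is best confirmed in PR #6328, so treat this as illustrative only:

```py
import torch

def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, dot_threshold: float = 0.9995) -> torch.Tensor:
    """Spherical linear interpolation from v0 (t=0) to v1 (t=1)."""
    v0_flat, v1_flat = v0.flatten().float(), v1.flatten().float()
    # cosine of the angle between the two latents, computed on unit vectors
    dot = torch.dot(v0_flat / v0_flat.norm(), v1_flat / v1_flat.norm())
    if dot.abs() > dot_threshold:
        # nearly parallel vectors: slerp is ill-conditioned, fall back to lerp
        return (1 - t) * v0 + t * v1
    theta = torch.acos(dot)
    return (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)
```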
