
Commit 6620eda

Standardise outputs for video pipelines (#6626)
* update * update * update * update * update * update * update * clean up * clean up
1 parent 1f0705a commit 6620eda

7 files changed: +91 −77 lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 11 additions & 9 deletions
@@ -67,10 +67,7 @@
     """


-def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -79,6 +76,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

         outputs.append(batch_output)

+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs


@@ -805,11 +811,7 @@ def _retrieve_video_frames(self, latents, output_type, return_dict):
             return AnimateDiffPipelineOutput(frames=latents)

         video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         if not return_dict:
             return (video,)

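A minimal sketch (not part of the diff) of what the standardized helper is expected to return for each output_type. It assumes tensor2vid is importable from the pipeline module and that VaeImageProcessor.postprocess yields (frames, H, W, C) arrays for "np" and (frames, C, H, W) tensors for "pt"; the dummy tensor below is a stand-in for decode_latents output.

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.animatediff.pipeline_animatediff import tensor2vid

processor = VaeImageProcessor(vae_scale_factor=8)
# Stand-in for decode_latents output: (batch, channels, frames, height, width), roughly in [-1, 1]
video = torch.rand(2, 3, 4, 32, 32) * 2 - 1

frames_np = tensor2vid(video, processor, output_type="np")    # np.ndarray, (batch, frames, H, W, C)
frames_pt = tensor2vid(video, processor, output_type="pt")    # torch.Tensor, (batch, frames, C, H, W)
frames_pil = tensor2vid(video, processor, output_type="pil")  # list of per-video lists of PIL images
print(frames_np.shape, frames_pt.shape, len(frames_pil[0]))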
src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

Lines changed: 9 additions & 5 deletions
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
     return x[(...,) + (None,) * dims_to_append]


-def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)

     if output_type == "np":
-        return np.stack(outputs)
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs


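Before this change, the Stable Video Diffusion copy of tensor2vid only stacked "np" outputs; "pt" and "pil" fell through and returned an un-stacked Python list. A hedged usage sketch of the new behavior follows; the checkpoint id, placeholder image, and call parameters are assumptions based on the existing SVD pipeline, not part of the diff.

import torch
from PIL import Image
from diffusers import StableVideoDiffusionPipeline

# Assumed checkpoint; any Stable Video Diffusion weights should behave the same way.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

image = Image.new("RGB", (1024, 576))  # placeholder conditioning image
frames = pipe(image, num_frames=14, decode_chunk_size=4, output_type="pt").frames
# With the standardized helper, "pt" now yields a stacked tensor, e.g. (batch, frames, C, H, W),
# instead of an un-stacked list of per-video tensors.
print(frames.shape)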
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py

Lines changed: 23 additions & 21 deletions
@@ -19,6 +19,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
@@ -58,22 +59,26 @@
     """


-def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
-    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-    # reshape to ncfhw
-    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
-    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
-    # unnormalize back to [0,1]
-    video = video.mul_(std).add_(mean)
-    video.clamp_(0, 1)
-    # prepare the final outputs
-    i, c, f, h, w = video.shape
-    images = video.permute(2, 3, 0, 4, 1).reshape(
-        f, h, i * w, c
-    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
-    images = images.unbind(dim=0)  # prepare a list of indvidual (consecutive frames)
-    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
-    return images
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+    return outputs


 class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
@@ -122,6 +127,7 @@ def __init__(
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -717,11 +723,7 @@ def __call__(
             return TextToVideoSDPipelineOutput(frames=latents)

         video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor)
+        video = tensor2vid(video_tensor, self.image_processor, output_type)

         # Offload all models
         self.maybe_free_model_hooks()

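The pipeline now owns a VaeImageProcessor and routes every output_type through it. The user-visible effect, consistent with the updated tests below, is that "np" frames come back as one batched float array in [0, 1] rather than a list of uint8 images. A hedged sketch, with the model id and prompt taken from the tests and the dtype/precision settings assumed:

import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

frames = pipe("Spiderman is surfing", num_inference_steps=25, output_type="np").frames
# Expected after this change: a single batched array with float values in [0, 1],
# indexed as frames[batch, frame, height, width, channel].
print(frames.shape, frames.dtype, frames.min(), frames.max())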
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py

Lines changed: 26 additions & 21 deletions
@@ -20,6 +20,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
@@ -93,22 +94,26 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
-    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-    # reshape to ncfhw
-    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
-    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
-    # unnormalize back to [0,1]
-    video = video.mul_(std).add_(mean)
-    video.clamp_(0, 1)
-    # prepare the final outputs
-    i, c, f, h, w = video.shape
-    images = video.permute(2, 3, 0, 4, 1).reshape(
-        f, h, i * w, c
-    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
-    images = images.unbind(dim=0)  # prepare a list of indvidual (consecutive frames)
-    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
-    return images
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+    return outputs


 def preprocess_video(video):
@@ -198,6 +203,7 @@ def __init__(
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -812,12 +818,11 @@ def __call__(
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")

-        video_tensor = self.decode_latents(latents)
+        if output_type == "latent":
+            return TextToVideoSDPipelineOutput(frames=latents)

-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor)
+        video_tensor = self.decode_latents(latents)
+        video = tensor2vid(video_tensor, self.image_processor, output_type)

         # Offload all models
         self.maybe_free_model_hooks()

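Besides the shared helper, this pipeline gains an output_type="latent" early return, so a low-resolution pass can hand raw latents to a follow-up run without decoding them. A hedged sketch of that usage; the checkpoint id, placeholder clip, and strength value are assumptions, while the .frames-holds-latents behavior comes from the diff above.

import torch
from PIL import Image
from diffusers import VideoToVideoSDPipeline

# Assumed checkpoint; any weights compatible with VideoToVideoSDPipeline should work the same way.
pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16).to("cuda")

video = [Image.new("RGB", (576, 320))] * 8  # placeholder input clip
out = pipe("Spiderman is surfing", video=video, strength=0.6, output_type="latent")
latents = out.frames  # undecoded latents: no VAE decode, no tensor2vid post-processing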
tests/pipelines/animatediff/test_animatediff.py

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ def test_free_init(self):
         sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
         max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
         self.assertGreater(
-            sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results"
+            sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
         )
         self.assertLess(
             max_diff_disabled,

tests/pipelines/text_to_video_synthesis/test_text_to_video.py

Lines changed: 11 additions & 12 deletions
@@ -29,6 +29,7 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     load_numpy,
+    numpy_cosine_similarity_distance,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -141,10 +142,11 @@ def test_text_to_video_default_case(self):
         inputs = self.get_dummy_inputs(device)
         inputs["output_type"] = "np"
         frames = sd_pipe(**inputs).frames
-        image_slice = frames[0][-3:, -3:, -1]

-        assert frames[0].shape == (32, 32, 3)
-        expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0])
+        image_slice = frames[0][0][-3:, -3:, -1]
+
+        assert frames[0][0].shape == (32, 32, 3)
+        expected_slice = np.array([0.7537, 0.1752, 0.6157, 0.5508, 0.4240, 0.4110, 0.4838, 0.5648, 0.5094])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -183,7 +185,7 @@ def test_progress_bar(self):
 class TextToVideoSDPipelineSlowTests(unittest.TestCase):
     def test_two_step_model(self):
         expected_video = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
         )

         pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
@@ -192,10 +194,8 @@ def test_two_step_model(self):
         prompt = "Spiderman is surfing"
         generator = torch.Generator(device="cpu").manual_seed(0)

-        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
-        video = video_frames.cpu().numpy()
-
-        assert np.abs(expected_video - video).mean() < 5e-2
+        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
+        assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4

     def test_two_step_model_with_freeu(self):
         expected_video = []
@@ -207,10 +207,9 @@ def test_two_step_model_with_freeu(self):
         generator = torch.Generator(device="cpu").manual_seed(0)

         pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
-        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
-        video = video_frames.cpu().numpy()
-        video = video[0, 0, -3:, -3:, -1].flatten()
+        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
+        video = video_frames[0, 0, -3:, -3:, -1].flatten()

-        expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]
+        expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]

         assert np.abs(expected_video - video).mean() < 5e-2

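The slow tests above switch from raw per-pixel comparisons on "pt" outputs to a cosine-similarity distance over flattened "np" outputs. A rough, self-contained sketch of that metric; the helper name below is my own and is only assumed to mirror what numpy_cosine_similarity_distance computes.

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 for identical directions; grows as the two outputs diverge.
    a, b = a.flatten().astype(np.float64), b.flatten().astype(np.float64)
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

assert cosine_similarity_distance(np.ones(9), np.ones(9)) < 1e-6
assert cosine_similarity_distance(np.ones(9), -np.ones(9)) > 1.0  # opposite directions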
tests/pipelines/text_to_video_synthesis/test_video_to_video.py

Lines changed: 10 additions & 8 deletions
@@ -157,10 +157,10 @@ def test_text_to_video_default_case(self):
         inputs = self.get_dummy_inputs(device)
         inputs["output_type"] = "np"
         frames = sd_pipe(**inputs).frames
-        image_slice = frames[0][-3:, -3:, -1]
+        image_slice = frames[0][0][-3:, -3:, -1]

-        assert frames[0].shape == (32, 32, 3)
-        expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0])
+        assert frames[0][0].shape == (32, 32, 3)
+        expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -214,9 +214,11 @@ def test_two_step_model(self):

         prompt = "Spiderman is surfing"

-        video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames
-
-        expected_array = np.array([-0.9770508, -0.8027344, -0.62646484, -0.8334961, -0.7573242])
-        output_array = video_frames.cpu().numpy()[0, 0, 0, 0, -5:]
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames

-        assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-2
+        expected_array = np.array(
+            [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
+        )
+        output_array = video_frames[0, 0, :3, :3, 0].flatten()
+        assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3

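Both fast tests now index frames through an extra leading batch dimension. A small sketch of the assumed indexing convention for "np" output, using random data as a stand-in for pipeline frames:

import numpy as np

frames = np.random.rand(1, 8, 32, 32, 3)  # stand-in for pipe(..., output_type="np").frames

image_slice = frames[0][0][-3:, -3:, -1]    # bottom-right 3x3 corner of the last channel, first frame
corner = frames[0, 0, :3, :3, 0].flatten()  # top-left 3x3 patch of the first channel, as in the test
assert image_slice.shape == (3, 3) and corner.shape == (9,)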