Commit 9abf5ed

Anonymous authored and Anonymous committed
add freenoise
1 parent 0a8d8ef commit 9abf5ed
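
In short: this commit wires FreeNoise-style long-video generation into AnimateDiff. The motion module gains a sliding-window temporal-attention path whose per-window outputs are fused with triangular weights, prepare_latents gains noise rescheduling that builds long noise sequences from shuffled copies of early frames, and scripts/animate.py exposes a --use_freenoise flag with the default length raised to 64 frames.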

5 files changed: +100 -11 lines changed

animatediff/models/motion_module.py

Lines changed: 79 additions & 7 deletions
@@ -17,6 +17,26 @@
 import math
 
 
+def get_views(video_length, window_size=16, stride=4):
+    num_blocks_time = (video_length - window_size) // stride + 1
+    views = []
+    for i in range(num_blocks_time):
+        t_start = int(i * stride)
+        t_end = t_start + window_size
+        views.append((t_start, t_end))
+    return views
+
+
+def generate_weight_sequence(n):
+    if n % 2 == 0:
+        max_weight = n // 2
+        weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
+    else:
+        max_weight = (n + 1) // 2
+        weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
+    return weight_sequence
+
+
 def zero_module(module):
     # Zero out the parameters of a module and return it.
     for p in module.parameters():
@@ -46,6 +66,16 @@ def get_motion_module(
     else:
         raise ValueError
 
+def get_window_motion_module(
+    in_channels,
+    motion_module_type: str,
+    motion_module_kwargs: dict
+):
+    if motion_module_type == "Vanilla":
+        return VanillaTemporalModule(in_channels=in_channels, local_window=True, **motion_module_kwargs,)
+    else:
+        raise ValueError
+
 
 class VanillaTemporalModule(nn.Module):
     def __init__(
@@ -59,6 +89,7 @@ def __init__(
         temporal_position_encoding_max_len = 24,
         temporal_attention_dim_div = 1,
         zero_initialize = True,
+        **kwargs,
     ):
         super().__init__()
 
@@ -71,6 +102,7 @@ def __init__(
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+           **kwargs,
        )
 
        if zero_initialize:
@@ -103,6 +135,7 @@ def __init__(
         cross_frame_attention_mode = None,
         temporal_position_encoding = False,
         temporal_position_encoding_max_len = 24,
+        **kwargs,
     ):
         super().__init__()
 
@@ -127,6 +160,7 @@ def __init__(
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                   **kwargs,
                )
                for d in range(num_layers)
            ]
@@ -176,6 +210,7 @@ def __init__(
         cross_frame_attention_mode = None,
         temporal_position_encoding = False,
         temporal_position_encoding_max_len = 24,
+        local_window = False,
     ):
         super().__init__()
 
@@ -208,15 +243,52 @@ def __init__(
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
         self.ff_norm = nn.LayerNorm(dim)
 
+        self.local_window = local_window
+
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
-        for attention_block, norm in zip(self.attention_blocks, self.norms):
-            norm_hidden_states = norm(hidden_states)
-            hidden_states = attention_block(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
-                video_length=video_length,
-            ) + hidden_states
+
+        if not self.local_window:
+            for attention_block, norm in zip(self.attention_blocks, self.norms):
+                norm_hidden_states = norm(hidden_states)
+                hidden_states = attention_block(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+                    video_length=video_length,
+                ) + hidden_states
+        else:
+            views = get_views(video_length)
+            hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length)
+            count = torch.zeros_like(hidden_states)
+            value = torch.zeros_like(hidden_states)
+            for t_start, t_end in views:
+                weight_sequence = generate_weight_sequence(t_end - t_start)
+                weight_tensor = torch.ones_like(count[:, t_start:t_end])
+                weight_tensor = weight_tensor * torch.Tensor(weight_sequence).to(hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+
+                sub_hidden_states = rearrange(hidden_states[:, t_start:t_end], "b f d c -> (b f) d c")
+                for attention_block, norm in zip(self.attention_blocks, self.norms):
+                    norm_hidden_states = norm(sub_hidden_states)
+                    sub_hidden_states = attention_block(
+                        norm_hidden_states,
+                        encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+                        video_length=t_end - t_start,
+                    ) + sub_hidden_states
+                sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=t_end - t_start)
+
+                value[:, t_start:t_end] += sub_hidden_states * weight_tensor
+                count[:, t_start:t_end] += weight_tensor
+
+            hidden_states = torch.where(count > 0, value / count, value)
+            hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c")
+
+        # for attention_block, norm in zip(self.attention_blocks, self.norms):
+        #     norm_hidden_states = norm(hidden_states)
+        #     hidden_states = attention_block(
+        #         norm_hidden_states,
+        #         encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+        #         video_length=video_length,
+        #     ) + hidden_states
 
         hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
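
The two helpers added at the top of motion_module.py are easy to sanity-check in isolation. A standalone sketch (the functions are copied from the hunk above) of what they produce for the default window_size=16, stride=4 configuration:

    def get_views(video_length, window_size=16, stride=4):
        # overlapping temporal windows covering the clip
        num_blocks_time = (video_length - window_size) // stride + 1
        return [(i * stride, i * stride + window_size) for i in range(num_blocks_time)]

    def generate_weight_sequence(n):
        # symmetric triangular weights: 1, 2, ..., peak, ..., 2, 1
        if n % 2 == 0:
            max_weight = n // 2
            return list(range(1, max_weight + 1)) + list(range(max_weight, 0, -1))
        max_weight = (n + 1) // 2
        return list(range(1, max_weight)) + [max_weight] + list(range(max_weight - 1, 0, -1))

    print(get_views(64))                 # 13 windows: [(0, 16), (4, 20), ..., (48, 64)]
    print(generate_weight_sequence(16))  # [1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1]

Frames near a window's edge get small weights, so overlapping windows blend smoothly instead of producing seams at window boundaries.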

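To see why the local_window branch normalizes by count, here is a self-contained toy sketch of the same fusion arithmetic, reusing get_views and generate_weight_sequence from the sketch above; the attention stack is replaced by a fixed doubling so the result is checkable:

    import torch

    video_length, dim = 24, 8
    hidden = torch.randn(1, video_length, dim)       # (b, f, c); the spatial axis is omitted
    value = torch.zeros_like(hidden)
    count = torch.zeros_like(hidden)

    for t_start, t_end in get_views(video_length):
        w = torch.tensor(generate_weight_sequence(t_end - t_start), dtype=hidden.dtype)
        w = w.view(1, -1, 1)                         # broadcast over batch and channels
        window_out = hidden[:, t_start:t_end] * 2.0  # stand-in for the attention blocks
        value[:, t_start:t_end] += window_out * w
        count[:, t_start:t_end] += w

    fused = torch.where(count > 0, value / count, value)
    # each frame is a weighted average of the windows covering it; since every
    # window produced the same answer here, fusion reproduces it exactly
    assert torch.allclose(fused, hidden * 2.0, atol=1e-6)
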
animatediff/models/unet.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ def __init__(
         motion_module_kwargs = {},
         unet_use_cross_frame_attention = None,
         unet_use_temporal_attention = None,
+        **kwargs,
     ):
         super().__init__()

animatediff/models/unet_blocks.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@
 
 from .attention import Transformer3DModel
 from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
-from .motion_module import get_motion_module
+# from .motion_module import get_motion_module
+from .motion_module import get_window_motion_module as get_motion_module
 
 import pdb
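
The import alias is the whole switch: rebinding get_window_motion_module to the name get_motion_module routes every motion-module construction in unet_blocks.py through the windowed variant (which sets local_window=True) without touching any call site, and the commented-out original import keeps rollback a one-line change.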

animatediff/pipelines/pipeline_animation.py

Lines changed: 15 additions & 2 deletions
@@ -28,6 +28,7 @@
 from einops import rearrange
 
 from ..models.unet import UNet3DConditionModel
+import random
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -283,8 +284,18 @@ def check_inputs(self, prompt, height, width, callback_steps):
                 f" {type(callback_steps)}."
             )
 
-    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, use_freenoise=False):
         shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+
+        if use_freenoise:
+            window_size = 16
+            window_stride = 4
+            latents = torch.randn(shape)
+            for frame_index in range(window_size, video_length, window_stride):
+                list_index = list(range(frame_index - window_size, frame_index + window_stride - window_size))
+                random.shuffle(list_index)
+                latents[:, :, frame_index:frame_index + window_stride] = latents[:, :, list_index]
+
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -330,6 +341,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
+        use_freenoise: Optional[bool] = False,
         **kwargs,
     ):
         # Default height and width to unet
@@ -377,6 +389,7 @@
             device,
             generator,
             latents,
+            use_freenoise = use_freenoise,
         )
         latents_dtype = latents.dtype
@@ -392,7 +405,7 @@
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, use_freenoise=use_freenoise).sample.to(dtype=latents_dtype)
                 # noise_pred = []
                 # import pdb
                 # pdb.set_trace()
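
The prepare_latents change is FreeNoise's noise rescheduling: each block of frames past the first window reuses a shuffled copy of the noise from exactly one window earlier, so, transitively, every frame's initial noise is a permutation of the first window_size frames. A self-contained sketch with toy spatial dimensions (the loop is copied from the diff; the final check is illustrative):

    import random
    import torch

    window_size, window_stride = 16, 4
    video_length = 64
    latents = torch.randn(1, 4, video_length, 8, 8)  # (b, c, f, h, w), toy 8x8 latents

    for frame_index in range(window_size, video_length, window_stride):
        # the window_stride source frames sitting exactly one window back
        list_index = list(range(frame_index - window_size, frame_index + window_stride - window_size))
        random.shuffle(list_index)
        latents[:, :, frame_index:frame_index + window_stride] = latents[:, :, list_index]

    # every frame is now an exact copy of one of the first window_size frames
    base = latents[:, :, :window_size]
    for f in range(video_length):
        assert any(torch.equal(latents[:, :, f], base[:, :, j]) for j in range(window_size))

One side effect visible in the diff: this branch calls torch.randn without the pipeline's generator, device, or dtype arguments, so on the FreeNoise path the initial noise is not controlled by a passed-in seed.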

scripts/animate.py

Lines changed: 3 additions & 1 deletion
@@ -94,6 +94,7 @@ def main(args):
                 width               = args.W,
                 height              = args.H,
                 video_length        = args.L,
+                use_freenoise       = args.use_freenoise,
             ).videos
             samples.append(sample)
@@ -115,9 +116,10 @@
     parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
     parser.add_argument("--config", type=str, required=True)
 
-    parser.add_argument("--L", type=int, default=16 )
+    parser.add_argument("--L", type=int, default=64 )
     parser.add_argument("--W", type=int, default=512)
     parser.add_argument("--H", type=int, default=512)
+    parser.add_argument("--use_freenoise", type=bool, default=True)
 
     args = parser.parse_args()
     main(args)
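
One caveat in the new flag: argparse's type=bool does not parse booleans. bool() of any non-empty string is True, so passing --use_freenoise False still yields True, and with default=True the flag cannot be switched off from the command line. A common alternative, shown here as an illustrative sketch rather than part of the commit:

    import argparse

    parser = argparse.ArgumentParser()
    # type=bool is a footgun: bool("False") == True. Paired store_true/store_false
    # flags give an explicit on/off switch instead.
    parser.add_argument("--use_freenoise", dest="use_freenoise", action="store_true")
    parser.add_argument("--no_freenoise", dest="use_freenoise", action="store_false")
    parser.set_defaults(use_freenoise=True)

    args = parser.parse_args(["--no_freenoise"])
    print(args.use_freenoise)  # False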
