import math


+# Split the temporal axis into overlapping [t_start, t_end) windows.
+def get_views(video_length, window_size=16, stride=4):
+    num_blocks_time = (video_length - window_size) // stride + 1
+    views = []
+    for i in range(num_blocks_time):
+        t_start = int(i * stride)
+        t_end = t_start + window_size
+        views.append((t_start, t_end))
+    return views
+
+
+# Symmetric (triangular) weight ramp used to blend the overlapping windows.
+def generate_weight_sequence(n):
+    if n % 2 == 0:
+        max_weight = n // 2
+        weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
+    else:
+        max_weight = (n + 1) // 2
+        weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
+    return weight_sequence
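+# Example, with the defaults above:
+#   get_views(24)                -> [(0, 16), (4, 20), (8, 24)]
+#   generate_weight_sequence(16) -> [1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1]
+#   generate_weight_sequence(5)  -> [1, 2, 3, 2, 1]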
+
+
def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
@@ -46,6 +66,16 @@ def get_motion_module(
    else:
        raise ValueError

+def get_window_motion_module(
+    in_channels,
+    motion_module_type: str,
+    motion_module_kwargs: dict
+):
+    if motion_module_type == "Vanilla":
+        return VanillaTemporalModule(in_channels=in_channels, local_window=True, **motion_module_kwargs,)
+    else:
+        raise ValueError
+

class VanillaTemporalModule(nn.Module):
    def __init__(
@@ -59,6 +89,7 @@ def __init__(
        temporal_position_encoding_max_len = 24,
        temporal_attention_dim_div = 1,
        zero_initialize = True,
+        **kwargs,
    ):
        super().__init__()

@@ -71,6 +102,7 @@ def __init__(
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+            **kwargs,
        )

        if zero_initialize:
@@ -103,6 +135,7 @@ def __init__(
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
+        **kwargs,
    ):
        super().__init__()

@@ -127,6 +160,7 @@ def __init__(
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                    **kwargs,
                )
                for d in range(num_layers)
            ]
@@ -176,6 +210,7 @@ def __init__(
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
+        local_window = False,
    ):
        super().__init__()

@@ -208,15 +243,52 @@ def __init__(
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

+        self.local_window = local_window
+

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
-        for attention_block, norm in zip(self.attention_blocks, self.norms):
-            norm_hidden_states = norm(hidden_states)
-            hidden_states = attention_block(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
-                video_length=video_length,
-            ) + hidden_states
+
+        if not self.local_window:
+            # Original behavior: temporal attention over the full clip at once.
+            for attention_block, norm in zip(self.attention_blocks, self.norms):
+                norm_hidden_states = norm(hidden_states)
+                hidden_states = attention_block(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+                    video_length=video_length,
+                ) + hidden_states
+        else:
+            # Local-window path: run the attention blocks on overlapping temporal
+            # windows and blend the per-window outputs with the triangular weights.
+            views = get_views(video_length)
+            hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length)
+            count = torch.zeros_like(hidden_states)
+            value = torch.zeros_like(hidden_states)
+            for t_start, t_end in views:
+                weight_sequence = generate_weight_sequence(t_end - t_start)
+                weight_tensor = torch.ones_like(count[:, t_start:t_end])
+                weight_tensor = weight_tensor * torch.Tensor(weight_sequence).to(hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+
+                sub_hidden_states = rearrange(hidden_states[:, t_start:t_end], "b f d c -> (b f) d c")
+                for attention_block, norm in zip(self.attention_blocks, self.norms):
+                    norm_hidden_states = norm(sub_hidden_states)
+                    sub_hidden_states = attention_block(
+                        norm_hidden_states,
+                        encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+                        video_length=t_end - t_start,
+                    ) + sub_hidden_states
+                sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=t_end - t_start)
+
+                value[:, t_start:t_end] += sub_hidden_states * weight_tensor
+                count[:, t_start:t_end] += weight_tensor
+
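+            # Each frame is normalized by the total weight it received, so frames
+            # covered by several windows become a weighted average of those windows
+            # (e.g. with video_length=24, frame 10 lies in windows (0, 16), (4, 20)
+            # and (8, 24) with weights 6, 7 and 3). The torch.where guard only
+            # matters for frames that no window covers.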
+            hidden_states = torch.where(count > 0, value / count, value)
+            hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c")
+

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states