@@ -750,23 +750,18 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep
         )
 
     def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
         reg_loss = 0.0
         for i in range(hidden_states.shape[0]):
-            noise = hidden_states[i][None, None, :, :]
-            while True:
-                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
-
-                if noise.shape[2] <= 8:
-                    break
-                noise = F.avg_pool2d(noise, kernel_size=2)
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
         return reg_loss
 
     def kl_divergence(self, hidden_states):
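Below is a minimal, self-contained sketch (not part of the diff) of what the change enables: by looping over both the batch and channel dimensions instead of squeezing the batch dimension, `auto_corr_loss` can now accept inputs with `batch_size > 1`, which the removed `ValueError` previously forbade. The free-standing function and the example tensor shape are illustrative assumptions; in the pipeline the method is bound to `self`.

```python
# Standalone sketch of the revised auto_corr_loss (assumption: reproduced here
# as a free function so it can run outside the pipeline class).
import torch
import torch.nn.functional as F


def auto_corr_loss(hidden_states, generator=None):
    reg_loss = 0.0
    # Iterate over batch (dim 0) and channel (dim 1) so batch_size > 1 works.
    for i in range(hidden_states.shape[0]):
        for j in range(hidden_states.shape[1]):
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                # Random shift of up to half the current spatial size.
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                # Halve the spatial resolution and repeat at the coarser scale.
                noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss


generator = torch.manual_seed(0)
latents = torch.randn(2, 4, 64, 64, generator=generator)  # batch of 2 is now accepted
print(auto_corr_loss(latents, generator=generator))
```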