Skip to content

Commit b218062

Browse files
committed
Update Pix2PixZero Auto-correlation Loss
1 parent b94880e commit b218062

File tree

1 file changed

+10 -15 lines changed

1 file changed

+10 -15 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -750,23 +750,18 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep
750 750               )
751 751
752 752       def auto_corr_loss(self, hidden_states, generator=None):
753     -         batch_size, channel, height, width = hidden_states.shape
754     -         if batch_size > 1:
755     -             raise ValueError("Only batch_size 1 is supported for now")
756     -
757     -         hidden_states = hidden_states.squeeze(0)
758     -         # hidden_states must be shape [C,H,W] now
759 753           reg_loss = 0.0
760 754           for i in range(hidden_states.shape[0]):
761     -             noise = hidden_states[i][None, None, :, :]
762     -             while True:
763     -                 roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
764     -                 reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
765     -                 reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
766     -
767     -                 if noise.shape[2] <= 8:
768     -                     break
769     -                 noise = F.avg_pool2d(noise, kernel_size=2)
    755 +             for j in range(hidden_states.shape[1]):
    756 +                 noise = hidden_states[i : i + 1, j : j + 1, :, :]
    757 +                 while True:
    758 +                     roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
    759 +                     reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
    760 +                     reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
    761 +
    762 +                     if noise.shape[2] <= 8:
    763 +                         break
    764 +                     noise = F.avg_pool2d(noise, kernel_size=2)
770 765           return reg_loss
771 766
772 767       def kl_divergence(self, hidden_states):

0 commit comments

Comments (0)