Skip to content

Commit 4188f30

Browse files
btlorchsayakpaul
andcommitted
Convert RGB to BGR for the SDXL watermark encoder (#7013)
* Convert channel order to BGR for the watermark encoder. Convert the watermarked BGR images back to RGB. Fixes #6292 * Revert channel order before stacking images to overcome limitations that negative strides are currently not supported --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent de41461 commit 4188f30

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/pipelines/stable_diffusion_xl/watermark.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@ def apply_watermark(self, images: torch.FloatTensor):
2828

2929
images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()
3030

31-
images = [self.encoder.encode(image, "dwtDct") for image in images]
31+
# Convert RGB to BGR, which is the channel order expected by the watermark encoder.
32+
images = images[:, :, :, ::-1]
3233

33-
images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
34+
# Add watermark and convert BGR back to RGB
35+
images = [self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images]
36+
37+
images = np.array(images)
38+
39+
images = torch.from_numpy(images).permute(0, 3, 1, 2)
3440

3541
images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
3642
return images

0 commit comments

Comments
 (0)