Skip to content

Commit cdea032

Browse files
superlabs-devJimmy
authored and
Jimmy
committed
fix tiled vae blend extent range (huggingface#3384)
fix tiled vae bleand extent range
1 parent 2ce5583 commit cdea032

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/diffusers/models/autoencoder_kl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,14 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
196196
return DecoderOutput(sample=decoded)
197197

198198
def blend_v(self, a, b, blend_extent):
199-
for y in range(min(a.shape[2], b.shape[2], blend_extent)):
199+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
200+
for y in range(blend_extent):
200201
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
201202
return b
202203

203204
def blend_h(self, a, b, blend_extent):
204-
for x in range(min(a.shape[3], b.shape[3], blend_extent)):
205+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
206+
for x in range(blend_extent):
205207
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
206208
return b
207209

0 commit comments

Comments
 (0)