Commit 73b59f5

[refactor] enhance readability of flux related pipelines (#9711)

* flux pipeline: readability enhancement.

1 parent 52d4449 commit 73b59f5

10 files changed: +110 -96 lines changed
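At a glance, the refactor makes the Flux pixel-to-latent bookkeeping self-consistent: `vae_scale_factor` drops from 16 to 8 (the VAE's actual downsampling factor), and the compensating `* 2` / `// 2` terms move to the call sites that handle 2x2 latent packing. A minimal sketch with illustrative numbers (not part of the diff) showing the two conventions agree:

```python
# Both conventions yield the same latent geometry; only the bookkeeping moves.
height, width = 1024, 1024                  # illustrative pixel-space resolution

latent_h_old = 2 * (height // 16)           # old prepare_latents, vae_scale_factor = 16
latent_h_new = height // 8                  # new prepare_latents, vae_scale_factor = 8
assert latent_h_old == latent_h_new == 128  # identical latent height either way
```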

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 4 additions & 4 deletions

@@ -2198,8 +2198,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
@@ -2253,8 +2253,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
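This call-site change mirrors the pipeline refactor further down: `_prepare_latent_image_ids` no longer halves its inputs internally, so the training script passes the packed (already-halved) latent dims itself. A simplified, self-contained sketch of the new convention (batch size, device, and dtype handling omitted; shapes are illustrative):

```python
import torch

def prepare_latent_image_ids(height, width):
    # Simplified: the real helper also takes batch_size, device, and dtype.
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(height)[:, None]  # row index per patch
    ids[..., 2] = ids[..., 2] + torch.arange(width)[None, :]   # column index per patch
    return ids.reshape(height * width, 3)

model_input = torch.randn(1, 16, 64, 64)  # hypothetical VAE latents for a 512x512 image
ids = prepare_latent_image_ids(model_input.shape[2] // 2, model_input.shape[3] // 2)
print(ids.shape)  # torch.Size([1024, 3]) -- one positional id per packed 2x2 patch
```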

examples/controlnet/train_controlnet_flux.py

Lines changed: 2 additions & 2 deletions

@@ -1256,8 +1256,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
                     batch_size=pixel_latents_tmp.shape[0],
-                    height=pixel_latents_tmp.shape[2],
-                    width=pixel_latents_tmp.shape[3],
+                    height=pixel_latents_tmp.shape[2] // 2,
+                    width=pixel_latents_tmp.shape[3] // 2,
                     device=pixel_values.device,
                     dtype=pixel_values.dtype,
                 )

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 5 additions & 5 deletions

@@ -1540,12 +1540,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+                vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
@@ -1601,8 +1601,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
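The `- 1` in the `vae_scale_factor` computation reflects that only the transitions between consecutive `block_out_channels` entries downsample, so a 4-entry config gives 3 halvings, i.e. a factor of 8. A hedged sketch with typical (assumed) Flux VAE values:

```python
block_out_channels = [128, 256, 512, 512]  # assumed/illustrative Flux VAE config

old = 2 ** len(block_out_channels)         # 16 -- needed compensating "* 2" / "// 2" elsewhere
new = 2 ** (len(block_out_channels) - 1)   # 8  -- the VAE's actual downsampling factor
assert (old, new) == (16, 8)
```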

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 5 additions & 5 deletions

@@ -1645,12 +1645,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
@@ -1704,8 +1704,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 14 additions & 12 deletions

@@ -195,13 +195,13 @@ def __init__(
             scheduler=scheduler,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     def _get_t5_prompt_embeds(
         self,
@@ -386,8 +386,10 @@ def check_inputs(
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -425,9 +427,9 @@ def check_inputs(
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

@@ -452,10 +454,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)

-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents

@@ -499,8 +501,8 @@ def prepare_latents(
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)

@@ -517,7 +519,7 @@ def prepare_latents(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         return latents, latent_image_ids
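To see why `_unpack_latents` can now take pixel-space `height`/`width` directly, here is a minimal, self-contained reimplementation of the pack/unpack pair under the new convention (simplified signatures, not the library methods; in the real code `_pack_latents` keeps its explicit `batch_size`/`num_channels_latents` arguments):

```python
import torch

def pack(latents):
    # (B, C, h, w) -> (B, (h // 2) * (w // 2), C * 4): 2x2 patches become tokens.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(latents, height, width, vae_scale_factor=8):
    # height/width are pixel-space; a single division maps them to latent dims.
    b, _, channels = latents.shape
    h, w = height // vae_scale_factor, width // vae_scale_factor
    latents = latents.view(b, h // 2, w // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, channels // 4, h, w)

x = torch.randn(1, 16, 128, 128)                    # latents for a 1024x1024 image
assert torch.equal(unpack(pack(x), 1024, 1024), x)  # the pair round-trips exactly
```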

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 14 additions & 12 deletions

@@ -216,13 +216,13 @@ def __init__(
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     def _get_t5_prompt_embeds(
         self,
@@ -410,8 +410,10 @@ def check_inputs(
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -450,9 +452,9 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

@@ -479,10 +481,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)

-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents

@@ -498,8 +500,8 @@ def prepare_latents(
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)

@@ -516,7 +518,7 @@ def prepare_latents(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         return latents, latent_image_ids
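The input check is now derived from the pipeline's own `vae_scale_factor` rather than a hard-coded `8` sitting next to a scale factor that claimed `16`. A standalone sketch of the equivalent validation (hypothetical helper name):

```python
def check_dims(height, width, vae_scale_factor=8):
    # Mirrors the pipeline's check; raises on resolutions the VAE cannot map cleanly.
    if height % vae_scale_factor != 0 or width % vae_scale_factor != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {vae_scale_factor} "
            f"but are {height} and {width}."
        )

check_dims(1024, 768)    # passes silently
# check_dims(1000, 768)  # would raise: 1000 is not divisible by 8
```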

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 15 additions & 13 deletions

@@ -228,13 +228,13 @@ def __init__(
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -453,8 +453,10 @@ def check_inputs(
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -493,9 +495,9 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

@@ -522,10 +524,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)

-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents

@@ -549,11 +551,11 @@ def prepare_latents(
             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         )

-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -852,7 +854,7 @@ def __call__(
         control_mode = control_mode.reshape([-1, 1])

         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
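With `vae_scale_factor = 8`, the packed token grid is `(H / 16) x (W / 16)`, hence the extra `// 2` per axis when computing the sequence length passed to `calculate_shift`. For example:

```python
height, width, vae_scale_factor = 1024, 1024, 8  # illustrative values
image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)
assert image_seq_len == 64 * 64  # 4096 packed tokens for a 1024x1024 image
```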
