Skip to content

Commit c9e266e

Browse files
hlkysayakpaul
authored andcommitted
Add sigmas to Flux pipelines (#10081)
1 parent 812e688 commit c9e266e

9 files changed

+63
-68
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def __call__(
554554
height: Optional[int] = None,
555555
width: Optional[int] = None,
556556
num_inference_steps: int = 28,
557-
timesteps: List[int] = None,
557+
sigmas: Optional[List[float]] = None,
558558
guidance_scale: float = 3.5,
559559
num_images_per_prompt: Optional[int] = 1,
560560
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -585,10 +585,10 @@ def __call__(
585585
num_inference_steps (`int`, *optional*, defaults to 50):
586586
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
587587
expense of slower inference.
588-
timesteps (`List[int]`, *optional*):
589-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
590-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
591-
passed will be used. Must be in descending order.
588+
sigmas (`List[float]`, *optional*):
589+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
590+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
591+
will be used.
592592
guidance_scale (`float`, *optional*, defaults to 7.0):
593593
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
594594
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -699,7 +699,7 @@ def __call__(
699699
)
700700

701701
# 5. Prepare timesteps
702-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
702+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
703703
image_seq_len = latents.shape[1]
704704
mu = calculate_shift(
705705
image_seq_len,
@@ -712,8 +712,7 @@ def __call__(
712712
self.scheduler,
713713
num_inference_steps,
714714
device,
715-
timesteps,
716-
sigmas,
715+
sigmas=sigmas,
717716
mu=mu,
718717
)
719718
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

src/diffusers/pipelines/flux/pipeline_flux_control.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ def __call__(
621621
height: Optional[int] = None,
622622
width: Optional[int] = None,
623623
num_inference_steps: int = 28,
624-
timesteps: List[int] = None,
624+
sigmas: Optional[List[float]] = None,
625625
guidance_scale: float = 3.5,
626626
num_images_per_prompt: Optional[int] = 1,
627627
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -660,10 +660,10 @@ def __call__(
660660
num_inference_steps (`int`, *optional*, defaults to 50):
661661
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
662662
expense of slower inference.
663-
timesteps (`List[int]`, *optional*):
664-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
665-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
666-
passed will be used. Must be in descending order.
663+
sigmas (`List[float]`, *optional*):
664+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
665+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
666+
will be used.
667667
guidance_scale (`float`, *optional*, defaults to 7.0):
668668
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
669669
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -799,7 +799,7 @@ def __call__(
799799
)
800800

801801
# 5. Prepare timesteps
802-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
802+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
803803
image_seq_len = latents.shape[1]
804804
mu = calculate_shift(
805805
image_seq_len,
@@ -812,8 +812,7 @@ def __call__(
812812
self.scheduler,
813813
num_inference_steps,
814814
device,
815-
timesteps,
816-
sigmas,
815+
sigmas=sigmas,
817816
mu=mu,
818817
)
819818
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def __call__(
647647
width: Optional[int] = None,
648648
strength: float = 0.6,
649649
num_inference_steps: int = 28,
650-
timesteps: List[int] = None,
650+
sigmas: Optional[List[float]] = None,
651651
guidance_scale: float = 7.0,
652652
num_images_per_prompt: Optional[int] = 1,
653653
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -698,10 +698,10 @@ def __call__(
698698
num_inference_steps (`int`, *optional*, defaults to 50):
699699
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
700700
expense of slower inference.
701-
timesteps (`List[int]`, *optional*):
702-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
703-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
704-
passed will be used. Must be in descending order.
701+
sigmas (`List[float]`, *optional*):
702+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
703+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
704+
will be used.
705705
guidance_scale (`float`, *optional*, defaults to 7.0):
706706
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
707707
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -805,7 +805,7 @@ def __call__(
805805
)
806806

807807
# 4.Prepare timesteps
808-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
808+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
809809
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
810810
mu = calculate_shift(
811811
image_seq_len,
@@ -818,8 +818,7 @@ def __call__(
818818
self.scheduler,
819819
num_inference_steps,
820820
device,
821-
timesteps,
822-
sigmas,
821+
sigmas=sigmas,
823822
mu=mu,
824823
)
825824
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def __call__(
602602
height: Optional[int] = None,
603603
width: Optional[int] = None,
604604
num_inference_steps: int = 28,
605-
timesteps: List[int] = None,
605+
sigmas: Optional[List[float]] = None,
606606
guidance_scale: float = 7.0,
607607
control_guidance_start: Union[float, List[float]] = 0.0,
608608
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -638,10 +638,10 @@ def __call__(
638638
num_inference_steps (`int`, *optional*, defaults to 50):
639639
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
640640
expense of slower inference.
641-
timesteps (`List[int]`, *optional*):
642-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
643-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
644-
passed will be used. Must be in descending order.
641+
sigmas (`List[float]`, *optional*):
642+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
643+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
644+
will be used.
645645
guidance_scale (`float`, *optional*, defaults to 7.0):
646646
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
647647
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -872,7 +872,7 @@ def __call__(
872872
)
873873

874874
# 5. Prepare timesteps
875-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
875+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
876876
image_seq_len = latents.shape[1]
877877
mu = calculate_shift(
878878
image_seq_len,
@@ -885,8 +885,7 @@ def __call__(
885885
self.scheduler,
886886
num_inference_steps,
887887
device,
888-
timesteps,
889-
sigmas,
888+
sigmas=sigmas,
890889
mu=mu,
891890
)
892891

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def __call__(
646646
width: Optional[int] = None,
647647
strength: float = 0.6,
648648
num_inference_steps: int = 28,
649-
timesteps: List[int] = None,
649+
sigmas: Optional[List[float]] = None,
650650
guidance_scale: float = 7.0,
651651
control_guidance_start: Union[float, List[float]] = 0.0,
652652
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -685,8 +685,10 @@ def __call__(
685685
num_inference_steps (`int`, *optional*, defaults to 28):
686686
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
687687
expense of slower inference.
688-
timesteps (`List[int]`, *optional*):
689-
Custom timesteps to use for the denoising process.
688+
sigmas (`List[float]`, *optional*):
689+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
690+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
691+
will be used.
690692
guidance_scale (`float`, *optional*, defaults to 7.0):
691693
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
692694
control_mode (`int` or `List[int]`, *optional*):
@@ -858,7 +860,7 @@ def __call__(
858860
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
859861
control_mode = control_mode.reshape([-1, 1])
860862

861-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
863+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
862864
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
863865
mu = calculate_shift(
864866
image_seq_len,
@@ -871,8 +873,7 @@ def __call__(
871873
self.scheduler,
872874
num_inference_steps,
873875
device,
874-
timesteps,
875-
sigmas,
876+
sigmas=sigmas,
876877
mu=mu,
877878
)
878879
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def __call__(
752752
width: Optional[int] = None,
753753
strength: float = 0.6,
754754
padding_mask_crop: Optional[int] = None,
755-
timesteps: List[int] = None,
755+
sigmas: Optional[List[float]] = None,
756756
num_inference_steps: int = 28,
757757
guidance_scale: float = 7.0,
758758
control_guidance_start: Union[float, List[float]] = 0.0,
@@ -799,8 +799,10 @@ def __call__(
799799
num_inference_steps (`int`, *optional*, defaults to 28):
800800
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
801801
expense of slower inference.
802-
timesteps (`List[int]`, *optional*):
803-
Custom timesteps to use for the denoising process.
802+
sigmas (`List[float]`, *optional*):
803+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
804+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
805+
will be used.
804806
guidance_scale (`float`, *optional*, defaults to 7.0):
805807
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
806808
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
@@ -1009,7 +1011,7 @@ def __call__(
10091011

10101012
# 6. Prepare timesteps
10111013

1012-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1014+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
10131015
image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
10141016
int(global_width) // self.vae_scale_factor // 2
10151017
)
@@ -1024,8 +1026,7 @@ def __call__(
10241026
self.scheduler,
10251027
num_inference_steps,
10261028
device,
1027-
timesteps,
1028-
sigmas,
1029+
sigmas=sigmas,
10291030
mu=mu,
10301031
)
10311032
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def __call__(
689689
height: Optional[int] = None,
690690
width: Optional[int] = None,
691691
num_inference_steps: int = 50,
692-
timesteps: List[int] = None,
692+
sigmas: Optional[List[float]] = None,
693693
guidance_scale: float = 30.0,
694694
num_images_per_prompt: Optional[int] = 1,
695695
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -735,10 +735,10 @@ def __call__(
735735
num_inference_steps (`int`, *optional*, defaults to 50):
736736
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
737737
expense of slower inference.
738-
timesteps (`List[int]`, *optional*):
739-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
740-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
741-
passed will be used. Must be in descending order.
738+
sigmas (`List[float]`, *optional*):
739+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
740+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
741+
will be used.
742742
guidance_scale (`float`, *optional*, defaults to 7.0):
743743
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
744744
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -878,7 +878,7 @@ def __call__(
878878
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
879879

880880
# 6. Prepare timesteps
881-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
881+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
882882
image_seq_len = latents.shape[1]
883883
mu = calculate_shift(
884884
image_seq_len,
@@ -891,8 +891,7 @@ def __call__(
891891
self.scheduler,
892892
num_inference_steps,
893893
device,
894-
timesteps,
895-
sigmas,
894+
sigmas=sigmas,
896895
mu=mu,
897896
)
898897
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def __call__(
593593
width: Optional[int] = None,
594594
strength: float = 0.6,
595595
num_inference_steps: int = 28,
596-
timesteps: List[int] = None,
596+
sigmas: Optional[List[float]] = None,
597597
guidance_scale: float = 7.0,
598598
num_images_per_prompt: Optional[int] = 1,
599599
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -636,10 +636,10 @@ def __call__(
636636
num_inference_steps (`int`, *optional*, defaults to 50):
637637
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
638638
expense of slower inference.
639-
timesteps (`List[int]`, *optional*):
640-
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
641-
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
642-
passed will be used. Must be in descending order.
639+
sigmas (`List[float]`, *optional*):
640+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
641+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
642+
will be used.
643643
guidance_scale (`float`, *optional*, defaults to 7.0):
644644
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
645645
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -742,7 +742,7 @@ def __call__(
742742
)
743743

744744
# 4.Prepare timesteps
745-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
745+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
746746
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
747747
mu = calculate_shift(
748748
image_seq_len,
@@ -755,8 +755,7 @@ def __call__(
755755
self.scheduler,
756756
num_inference_steps,
757757
device,
758-
timesteps,
759-
sigmas,
758+
sigmas=sigmas,
760759
mu=mu,
761760
)
762761
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

0 commit comments

Comments
 (0)