Commit 7ac6e28

a-r-r-o-w authored (co-authored by yiyixuxu and sayakpaul)

Flux Fill, Canny, Depth, Redux (#9985)

* update

---------

Co-authored-by: yiyixuxu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent b5fd6f1 · commit 7ac6e28

18 files changed (+4189 -8 lines)

docs/source/en/api/pipelines/flux.md

Lines changed: 149 additions & 4 deletions
@@ -22,12 +22,20 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca
 
 </Tip>
 
-Flux comes in two variants:
+Flux comes in the following variants:
 
-* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
-* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
+| model type | model id |
+|:----------:|:--------:|
+| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |
+| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
+| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |
+| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
+| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
+| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
+| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
+| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
 
-Both checkpoints have slightly difference usage which we detail below.
+All checkpoints have different usage, which we detail below.
 
 ### Timestep-distilled
 
@@ -77,7 +85,132 @@ out = pipe(
 out.save("image.png")
 ```
 
+### Fill Inpainting/Outpainting
+
+* The Flux Fill pipeline does not require a `strength` input, unlike regular inpainting pipelines.
+* It supports both inpainting and outpainting.
+
+```python
+import torch
+from diffusers import FluxFillPipeline
+from diffusers.utils import load_image
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
+mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
+
+repo_id = "black-forest-labs/FLUX.1-Fill-dev"
+pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
+
+image = pipe(
+    prompt="a white paper cup",
+    image=image,
+    mask_image=mask,
+    height=1632,
+    width=1232,
+    max_sequence_length=512,
+    generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+image.save("output.png")
+```
+
+### Canny Control
+
+**Note:** `black-forest-labs/Flux.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by channel-wise concatenating the input control condition with the image latents and training the transformer to follow the condition as closely as possible.
+
+```python
+# !pip install -U controlnet-aux
+import torch
+from controlnet_aux import CannyDetector
+from diffusers import FluxControlPipeline
+from diffusers.utils import load_image
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = CannyDetector()
+control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
+
+image = pipe(
+    prompt=prompt,
+    control_image=control_image,
+    height=1024,
+    width=1024,
+    num_inference_steps=50,
+    guidance_scale=30.0,
+).images[0]
+image.save("output.png")
+```
+
+### Depth Control
+
+**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by channel-wise concatenating the input control condition with the image latents and training the transformer to follow the condition as closely as possible.
+
+```python
+# !pip install git+https://github.com/asomoza/image_gen_aux.git
+import torch
+from diffusers import FluxControlPipeline, FluxTransformer2DModel
+from diffusers.utils import load_image
+from image_gen_aux import DepthPreprocessor
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(control_image)[0].convert("RGB")
+
+image = pipe(
+    prompt=prompt,
+    control_image=control_image,
+    height=1024,
+    width=1024,
+    num_inference_steps=30,
+    guidance_scale=10.0,
+    generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("output.png")
+```
+
+### Redux
+
+* The Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation.
+* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.
+* When using `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline to save VRAM.
+
+```python
+import torch
+from diffusers import FluxPriorReduxPipeline, FluxPipeline
+from diffusers.utils import load_image
+
+device = "cuda"
+dtype = torch.bfloat16
+
+repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+repo_base = "black-forest-labs/FLUX.1-dev"
+pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+pipe = FluxPipeline.from_pretrained(
+    repo_base,
+    text_encoder=None,
+    text_encoder_2=None,
+    torch_dtype=torch.bfloat16
+).to(device)
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
+pipe_prior_output = pipe_prior_redux(image)
+images = pipe(
+    guidance_scale=2.5,
+    num_inference_steps=50,
+    generator=torch.Generator("cpu").manual_seed(0),
+    **pipe_prior_output,
+).images
+images[0].save("flux-redux.png")
+```
+
+
 ## Running FP16 inference
+
 Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
 
 FP16 inference code:
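The "FP16 inference code" block itself sits outside this hunk's context lines. A minimal sketch of the approach the paragraph describes (forcing both text encoders back to FP32 while the rest of the pipeline runs in FP16), assuming standard `FluxPipeline` components; the exact snippet in the docs may differ:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
# Keep both text encoders in FP32 so their activations are not clipped.
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
pipe.to("cuda")

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=4,       # timestep-distilled schnell needs few steps
    guidance_scale=0.0,          # schnell ignores guidance
    max_sequence_length=256,     # schnell caps the T5 sequence length at 256
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-fp16.png")
```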
@@ -188,3 +321,15 @@ image.save("flux-fp8-dev.png")
 [[autodoc]] FluxControlNetImg2ImgPipeline
   - all
   - __call__
+
+## FluxControlPipeline
+
+[[autodoc]] FluxControlPipeline
+  - all
+  - __call__
+
+## FluxControlImg2ImgPipeline
+
+[[autodoc]] FluxControlImg2ImgPipeline
+  - all
+  - __call__
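The variant table above also lists Canny and Depth LoRA checkpoints, which the new doc sections do not exercise. A plausible sketch, assuming the LoRA weights are applied on top of the base FLUX.1-dev checkpoint via the standard `load_lora_weights` loader (this usage is not shown in the commit, so treat the loading pattern as an assumption):

```python
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# Assumption: the Control LoRA layers onto the base FLUX.1-dev weights.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

processor = CannyDetector()
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt="A robot made of exotic candies and chocolates of different kinds.",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output-canny-lora.png")
```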

scripts/convert_flux_to_diffusers.py

Lines changed: 6 additions & 1 deletion
@@ -37,6 +37,8 @@
 parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
 parser.add_argument("--filename", default="flux.safetensors", type=str)
 parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--in_channels", type=int, default=64)
+parser.add_argument("--out_channels", type=int, default=None)
 parser.add_argument("--vae", action="store_true")
 parser.add_argument("--transformer", action="store_true")
 parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
         num_single_layers = 38
         inner_dim = 3072
         mlp_ratio = 4.0
+
         converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
             original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
         )
-        transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+        transformer = FluxTransformer2DModel(
+            in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
+        )
         transformer.load_state_dict(converted_transformer_state_dict, strict=True)
 
         print(
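The two new flags are forwarded verbatim to `FluxTransformer2DModel` in `main()`, so checkpoints whose input width differs from the base 64 latent channels (e.g. channel-concat control variants) can be converted. A minimal sketch of the flag parsing in isolation; the channel values below are illustrative assumptions, not settings taken from a real checkpoint:

```python
# Hypothetical invocation (flag names from the diff above):
#   python scripts/convert_flux_to_diffusers.py --transformer \
#       --in_channels 128 --out_channels 64 --output_path ./converted
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)

args = parser.parse_args(["--in_channels", "128", "--out_channels", "64"])
# These values are passed straight to the FluxTransformer2DModel constructor.
print(args.in_channels, args.out_channels)  # 128 64
```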

src/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -269,12 +269,16 @@
             "CogVideoXVideoToVideoPipeline",
             "CogView3PlusPipeline",
             "CycleDiffusionPipeline",
+            "FluxControlImg2ImgPipeline",
             "FluxControlNetImg2ImgPipeline",
             "FluxControlNetInpaintPipeline",
             "FluxControlNetPipeline",
+            "FluxControlPipeline",
+            "FluxFillPipeline",
             "FluxImg2ImgPipeline",
             "FluxInpaintPipeline",
             "FluxPipeline",
+            "FluxPriorReduxPipeline",
             "HunyuanDiTControlNetPipeline",
             "HunyuanDiTPAGPipeline",
             "HunyuanDiTPipeline",
@@ -321,6 +325,7 @@
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
+            "ReduxImageEncoder",
             "SemanticStableDiffusionPipeline",
             "ShapEImg2ImgPipeline",
             "ShapEPipeline",
@@ -734,12 +739,16 @@
             CogVideoXVideoToVideoPipeline,
             CogView3PlusPipeline,
             CycleDiffusionPipeline,
+            FluxControlImg2ImgPipeline,
             FluxControlNetImg2ImgPipeline,
             FluxControlNetInpaintPipeline,
             FluxControlNetPipeline,
+            FluxControlPipeline,
+            FluxFillPipeline,
             FluxImg2ImgPipeline,
             FluxInpaintPipeline,
             FluxPipeline,
+            FluxPriorReduxPipeline,
             HunyuanDiTControlNetPipeline,
             HunyuanDiTPAGPipeline,
             HunyuanDiTPipeline,
@@ -786,6 +795,7 @@
             PixArtAlphaPipeline,
             PixArtSigmaPAGPipeline,
             PixArtSigmaPipeline,
+            ReduxImageEncoder,
             SemanticStableDiffusionPipeline,
             ShapEImg2ImgPipeline,
             ShapEPipeline,
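Together with the registration changes below in `src/diffusers/pipelines/__init__.py` and `src/diffusers/pipelines/flux/__init__.py`, these entries expose the new classes at the top level. A quick import smoke test, assuming a diffusers build that includes this commit:

```python
from diffusers import (
    FluxControlImg2ImgPipeline,
    FluxControlPipeline,
    FluxFillPipeline,
    FluxPriorReduxPipeline,
    ReduxImageEncoder,
)

# The lazy-import machinery resolves names on first access, so a successful
# import confirms the wiring in all three __init__ files.
print(FluxFillPipeline.__name__, ReduxImageEncoder.__name__)
```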

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 5 additions & 2 deletions
@@ -238,6 +238,7 @@ def __init__(
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -248,7 +249,7 @@ def __init__(
         axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@ def __init__(
         )
 
         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
@@ -449,13 +450,15 @@ def forward(
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
+
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
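The `out_channels or in_channels` fallback keeps existing checkpoints working (output width equals input width) while letting new variants, such as the Fill and Control models, take wider inputs than they predict. A minimal sketch with a deliberately shrunken configuration so it instantiates quickly; every dimension below is illustrative, not a real checkpoint config:

```python
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel(
    in_channels=8,             # e.g. image latents concatenated with a control condition
    out_channels=4,            # prediction stays at the base latent width
    num_layers=1,              # real Flux: 19
    num_single_layers=1,       # real Flux: 38
    attention_head_dim=8,      # real Flux: 128
    num_attention_heads=2,
    joint_attention_dim=16,
    pooled_projection_dim=16,
    axes_dims_rope=(2, 2, 4),  # per-axis RoPE dims; must sum to attention_head_dim
)

assert transformer.x_embedder.in_features == 8  # embeds the widened input
assert transformer.out_channels == 4            # explicit override
# With out_channels=None, the fallback keeps the pre-commit behavior: out == in.
```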

src/diffusers/pipelines/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -127,12 +127,17 @@
         "AnimateDiffVideoToVideoControlNetPipeline",
     ]
     _import_structure["flux"] = [
+        "FluxControlPipeline",
+        "FluxControlImg2ImgPipeline",
         "FluxControlNetPipeline",
         "FluxControlNetImg2ImgPipeline",
         "FluxControlNetInpaintPipeline",
         "FluxImg2ImgPipeline",
         "FluxInpaintPipeline",
         "FluxPipeline",
+        "FluxFillPipeline",
+        "FluxPriorReduxPipeline",
+        "ReduxImageEncoder",
     ]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
@@ -521,12 +526,17 @@
             VQDiffusionPipeline,
         )
         from .flux import (
+            FluxControlImg2ImgPipeline,
             FluxControlNetImg2ImgPipeline,
             FluxControlNetInpaintPipeline,
             FluxControlNetPipeline,
+            FluxControlPipeline,
+            FluxFillPipeline,
             FluxImg2ImgPipeline,
             FluxInpaintPipeline,
             FluxPipeline,
+            FluxPriorReduxPipeline,
+            ReduxImageEncoder,
         )
         from .hunyuandit import HunyuanDiTPipeline
         from .i2vgen_xl import I2VGenXLPipeline

src/diffusers/pipelines/flux/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -12,7 +12,7 @@
 
 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
 
 try:
     if not (is_transformers_available() and is_torch_available()):
@@ -22,25 +22,35 @@
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
     _import_structure["pipeline_flux"] = ["FluxPipeline"]
+    _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+    _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
     _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
     _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
     _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
     _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
     _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+    _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
             raise OptionalDependencyNotAvailable()
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
+        from .modeling_flux import ReduxImageEncoder
         from .pipeline_flux import FluxPipeline
+        from .pipeline_flux_control import FluxControlPipeline
+        from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
         from .pipeline_flux_controlnet import FluxControlNetPipeline
         from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
         from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+        from .pipeline_flux_fill import FluxFillPipeline
         from .pipeline_flux_img2img import FluxImg2ImgPipeline
         from .pipeline_flux_inpaint import FluxInpaintPipeline
+        from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
 else:
     import sys
 