Flux Fill, Canny, Depth, Redux #9985

Merged: 23 commits, Nov 23, 2024
Changes from 5 commits
7 changes: 6 additions & 1 deletion scripts/convert_flux_to_diffusers.py
@@ -37,6 +37,8 @@
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0

converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
)
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
transformer = FluxTransformer2DModel(
in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)

print(
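The new `--in_channels`/`--out_channels` arguments let the conversion script build transformer variants whose input width differs from their output width (Fill/Canny/Depth checkpoints concatenate extra conditioning channels into the input). Below is a minimal sketch of the constructor call the script now makes; the channel counts and the shrunk layer/head sizes are illustrative assumptions, not values taken from this diff.

```python
# Sketch of how the new flags reach the model constructor. The channel counts
# here are illustrative assumptions (a control-style checkpoint widens only the
# input side), and the layer/head sizes are shrunk so the example is cheap to build.
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel(
    in_channels=128,         # e.g. packed image latents + packed control latents
    out_channels=64,         # only the image latents are predicted
    num_layers=1,            # real Flux dev/schnell use 19
    num_single_layers=1,     # real Flux dev/schnell use 38
    attention_head_dim=32,   # shrunk for the example
    num_attention_heads=4,   # shrunk for the example
    guidance_embeds=True,
)
print(transformer.config.in_channels, transformer.out_channels)  # 128 64
```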
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -272,6 +272,7 @@
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
@@ -737,6 +738,7 @@
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
7 changes: 5 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -238,6 +238,7 @@ def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -248,7 +249,7 @@
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@
)

self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

self.transformer_blocks = nn.ModuleList(
[
@@ -449,13 +450,15 @@ def forward(
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.x_embedder(hidden_states)

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None

temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
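The `out_channels or in_channels` default keeps every existing Flux checkpoint loading exactly as before (symmetric 64-in/64-out), while the new variants widen only the input side. A one-line sketch of the fallback, mirroring the constructor line above; `resolve_out_channels` is a hypothetical helper used purely for illustration.

```python
# Mirrors `self.out_channels = out_channels or in_channels` above.
from typing import Optional

def resolve_out_channels(in_channels: int, out_channels: Optional[int] = None) -> int:
    return out_channels or in_channels

assert resolve_out_channels(64) == 64       # existing Flux checkpoints: unchanged
assert resolve_out_channels(128, 64) == 64  # control-style variant: wide in, normal out
```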
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -133,6 +133,7 @@
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
"FluxFillPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -524,6 +525,7 @@
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/flux/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,6 +40,7 @@
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
else:
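With the registrations above, `FluxFillPipeline` becomes importable from the top-level `diffusers` namespace. A hedged usage sketch follows; the repo id and the `image`/`mask_image` arguments are assumptions about an inpainting-style interface, since the pipeline file itself is not shown in this excerpt.

```python
# Hedged sketch: assumes FluxFillPipeline exposes the usual inpainting-style
# interface (prompt + image + mask_image) and that a Fill checkpoint exists at
# the repo id below — neither is confirmed by the files shown in this excerpt.
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",  # assumed repo id
    torch_dtype=torch.bfloat16,
).to("cuda")

init_image = load_image("input.png")  # image to edit
mask = load_image("mask.png")         # white regions are regenerated
result = pipe(prompt="a red leather sofa", image=init_image, mask_image=mask).images[0]
result.save("filled.png")
```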
86 changes: 82 additions & 4 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -19,7 +19,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
@@ -529,6 +529,41 @@ def prepare_latents(

return latents, latent_image_ids

# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
Review comment (Collaborator) on `prepare_image`:
The easiest way is probably to have 3 separate FluxControlPipeline? but maybe @DN6 can come up with better ideas!

def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)

image_batch_size = image.shape[0]

if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt

image = image.repeat_interleave(repeat_by, dim=0)

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)

return image

@property
def guidance_scale(self):
return self._guidance_scale
@@ -556,9 +591,11 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
control_image: PipelineImageInput = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
control_latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
@@ -595,6 +632,14 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -667,6 +712,7 @@ def __call__(

device = self._execution_device

# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
@@ -686,7 +732,35 @@
)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
num_channels_latents = (
self.transformer.config.in_channels // 4
if control_image is None
else self.transformer.config.in_channels // 8
)

if control_image is not None and control_latents is None:
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)

control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

height_control_image, width_control_image = control_latents.shape[2:]
control_latents = self._pack_latents(
control_latents,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -732,11 +806,16 @@ def __call__(
if self.interrupt:
continue

if control_latents is not None:
latent_model_input = torch.cat([latents, control_latents], dim=2)
else:
latent_model_input = latents

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

noise_pred = self.transformer(
hidden_states=latents,
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
@@ -774,7 +853,6 @@ def __call__(

if output_type == "latent":
image = latents

else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
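In this revision the control path lives on `FluxPipeline` itself: the control image is encoded with the VAE, packed like the image latents, and concatenated along the packed channel dimension before each transformer step (hence `in_channels // 8`). Below is a hedged sketch of how that entry point would be called; the checkpoint path is hypothetical, and later commits may move this into dedicated pipelines, as the reviewer's `FluxControlPipeline` suggestion above hints.

```python
# Hedged sketch of the control_image entry point added to FluxPipeline in this
# revision. The checkpoint path is hypothetical; a Canny/Depth-style transformer
# with in_channels=128 (64 packed image + 64 packed control channels) is implied
# by the `in_channels // 8` branch above.
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "path/to/flux-canny-checkpoint",   # hypothetical control checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

edges = load_image("edges.png")        # pre-computed Canny edge map
image = pipe(
    prompt="a futuristic city at dusk",
    control_image=edges,               # VAE-encoded, packed, concatenated with latents
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("controlled.png")
```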
1 change: 1 addition & 0 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -750,6 +750,7 @@ def __call__(
device = self._execution_device
dtype = self.transformer.dtype

# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)