[hybrid inference 🍯🐝] Add VAE encode #11017

Merged · 12 commits · Mar 12, 2025

Changes from 9 commits
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -81,6 +81,8 @@
    title: Overview
  - local: hybrid_inference/vae_decode
    title: VAE Decode
  - local: hybrid_inference/vae_encode
    title: VAE Encode
  - local: hybrid_inference/api_reference
    title: API Reference
  title: Hybrid Inference
4 changes: 4 additions & 0 deletions docs/source/en/hybrid_inference/api_reference.md
@@ -3,3 +3,7 @@
## Remote Decode

[[autodoc]] utils.remote_utils.remote_decode

## Remote Encode

[[autodoc]] utils.remote_utils.remote_encode
10 changes: 8 additions & 2 deletions docs/source/en/hybrid_inference/overview.md
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models

* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.

---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

## Changelog

- March 10 2025: Added VAE encode.
- March 2 2025: Initial release.

## Contents

The documentation is organized into two sections:
The documentation is organized into three sections:

* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
183 changes: 183 additions & 0 deletions docs/source/en/hybrid_inference/vae_encode.md
@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference

VAE encode is used for training, image-to-image, and image-to-video: it turns images or videos into latent representations.

## Memory

These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.

For the majority of these GPUs, the memory usage % dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding must be used, which increases the time taken and impacts quality.
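
For reference, the local measurements above can be approximated with a minimal sketch like the one below (assuming a Stable Diffusion v1.5 checkpoint, a CUDA GPU, and a random input tensor; exact numbers will vary with hardware and drivers):

```python
import torch
from diffusers import AutoencoderKL

device = "cuda"
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
).to(device)
# vae.enable_tiling()  # uncomment to reproduce the "Tiled" columns

# stand-in for a preprocessed image in [-1, 1], shape (1, 3, H, W)
image = torch.randn(1, 3, 1024, 1024, dtype=torch.float16, device=device)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    latent = vae.encode(image).latent_dist.sample()
peak = torch.cuda.max_memory_allocated()
total = torch.cuda.get_device_properties(device).total_memory
print(f"peak VRAM: {peak / total:.2%}")
```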

<details><summary>SD v1.5</summary>

| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |


</details>

<details><summary>SDXL</summary>

| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |

</details>

## Available VAEs

| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
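
For convenience, the endpoints above can be paired with the scaling values documented in `remote_encode`. A sketch; the dictionary below is only a suggested way to organize them, not part of the library:

```python
# Hypothetical helper mapping; endpoints are taken from the table above,
# scaling/shift factors are the ones listed in the `remote_encode` docstring.
ENCODE_CONFIGS = {
    "sd-v1": {
        "endpoint": "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
        "scaling_factor": 0.18215,
        "shift_factor": None,
    },
    "sdxl": {
        "endpoint": "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/",
        "scaling_factor": 0.13025,
        "shift_factor": None,
    },
    "flux": {
        "endpoint": "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
        "scaling_factor": 0.3611,
        "shift_factor": 0.1159,
    },
}
```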


> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).


## Code

> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`


A helper method simplifies interacting with Hybrid Inference.

```python
from diffusers.utils.remote_utils import remote_encode
```

### Basic example

Let's encode an image, then decode it to demonstrate the full round trip.

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>

<details><summary>Code</summary>

```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

decoded = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```

**Review thread on lines +116 to +117 (`scaling_factor` / `shift_factor`):**

**Member:** Not a merge blocker, but we could probably make a note for the users about how to know these values (i.e., by seeing the config values here).

**Contributor (author):** The values are in the docstrings, and I think it's unlikely for an end user to be using this themselves; most usage will come from integrations.

**Contributor (author):** After reviewing more models I'm considering keeping `do_scaling` anyway. For example, Wan doesn't have a `scaling_factor`, it has `latents_mean`/`latents_std`.

**Member:**

> I think it's unlikely for an end user to be using this themselves, most usage will come from integrations.

No strong opinions, but having it documented would be better than not having it documented, to cover a broader user base.

> After reviewing more models I'm considering keeping `do_scaling` anyway. For example, Wan doesn't have a `scaling_factor`, it has `latents_mean`/`latents_std`.

Would introducing something like `scaling_kwargs` make sense? We could define partial/full scaling functions on a per-model basis to mitigate any confusion.

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
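
`remote_encode` also accepts a `torch.Tensor` instead of a PIL image (see `prepare_encode` in the diff further below). A minimal sketch, assuming the endpoint expects the usual `[-1, 1]`-normalized tensor produced by `VaeImageProcessor`:

```python
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

# preprocess to a float tensor in [-1, 1] with shape (1, 3, H, W), assumed input format
processor = VaeImageProcessor()
image_tensor = processor.preprocess(image)

latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image_tensor,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
print(latent.shape, latent.dtype)
```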


### Generation

Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!

<details><summary>Code</summary>

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

# load the pipeline without a local VAE; encoding and decoding are offloaded to remote endpoints
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=None,
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

init_latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
    image=init_image,
    scaling_factor=0.18215,
)

prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
    prompt=prompt,
    image=init_latent,
    strength=0.75,
    output_type="latent",  # keep latents so they can be decoded remotely
).images

image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
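
This PR also adds endpoint constants to `diffusers.utils.remote_utils` (see the diff below), so the URLs don't have to be hard-coded. A sketch of the generation example above using them, assuming the constants are available in your installed version:

```python
from diffusers.utils.remote_utils import (
    DECODE_ENDPOINT_SD_V1,
    ENCODE_ENDPOINT_SD_V1,
    remote_decode,
    remote_encode,
)

# reuses `init_image` and `latent` from the generation example above
init_latent = remote_encode(
    endpoint=ENCODE_ENDPOINT_SD_V1,
    image=init_image,
    scaling_factor=0.18215,
)

image = remote_decode(
    endpoint=DECODE_ENDPOINT_SD_V1,
    tensor=latent,
    scaling_factor=0.18215,
)
```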

## Integrations

* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
114 changes: 108 additions & 6 deletions src/diffusers/utils/remote_utils.py
@@ -43,6 +43,17 @@
from PIL import Image


DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"


ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"


def detect_image_type(data: bytes) -> str:
    if data.startswith(b"\xff\xd8"):
        return "jpeg"
@@ -55,7 +66,7 @@ def detect_image_type(data: bytes) -> str:
    return "unknown"


def check_inputs(
def check_inputs_decode(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +100,7 @@ def check_inputs(
        )


def postprocess(
def postprocess_decode(
    response: requests.Response,
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +153,7 @@ def postprocess(
    return output


def prepare(
def prepare_decode(
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
@@ -293,7 +304,7 @@ def remote_decode(
            standard_warn=False,
        )
        output_tensor_type = "binary"
    check_inputs(
    check_inputs_decode(
        endpoint,
        tensor,
        processor,
@@ -309,7 +320,7 @@ def remote_decode(
        height,
        width,
    )
    kwargs = prepare(
    kwargs = prepare_decode(
        tensor=tensor,
        processor=processor,
        do_scaling=do_scaling,
@@ -324,11 +335,102 @@
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        raise RuntimeError(response.json())
    output = postprocess(
    output = postprocess_decode(
        response=response,
        processor=processor,
        output_type=output_type,
        return_type=return_type,
        partial_postprocess=partial_postprocess,
    )
    return output


def check_inputs_encode(
    endpoint: str,
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    pass


def postprocess_encode(
    response: requests.Response,
):
    # the endpoint returns raw tensor bytes; shape and dtype come back in the response headers
    output_tensor = response.content
    parameters = response.headers
    shape = json.loads(parameters["shape"])
    dtype = parameters["dtype"]
    torch_dtype = DTYPE_MAP[dtype]
    output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
    return output_tensor


def prepare_encode(
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    headers = {}
    parameters = {}
    if scaling_factor is not None:
        parameters["scaling_factor"] = scaling_factor
    if shift_factor is not None:
        parameters["shift_factor"] = shift_factor
    if isinstance(image, torch.Tensor):
        # tensors are sent as raw bytes with shape/dtype passed as request parameters
        data = safetensors.torch._tobytes(image, "tensor")
        parameters["shape"] = list(image.shape)
        parameters["dtype"] = str(image.dtype).split(".")[-1]
    else:
        # PIL images are serialized to PNG
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        data = buffer.getvalue()
    return {"data": data, "params": parameters, "headers": headers}


def remote_encode(
    endpoint: str,
    image: Union["torch.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
) -> "torch.Tensor":
    """
    Hugging Face Hybrid Inference that allows running VAE encode remotely.

    Args:
        endpoint (`str`):
            Endpoint for Remote Encode.
        image (`torch.Tensor` or `PIL.Image.Image`):
            Image to be encoded.
        scaling_factor (`float`, *optional*):
            Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
            - SD v1: 0.18215
            - SD XL: 0.13025
            - Flux: 0.3611
            If `None`, input must be passed with scaling applied.
        shift_factor (`float`, *optional*):
            Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
            - Flux: 0.1159
            If `None`, input must be passed with shift applied.

    Returns:
        output (`torch.Tensor`).
    """
    check_inputs_encode(
        endpoint,
        image,
        scaling_factor,
        shift_factor,
    )
    kwargs = prepare_encode(
        image=image,
        scaling_factor=scaling_factor,
        shift_factor=shift_factor,
    )
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        raise RuntimeError(response.json())
    output = postprocess_encode(
        response=response,
    )
    return output
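
To make the request/response contract implemented by `prepare_encode` and `postprocess_encode` above concrete, here is a hypothetical server-side sketch. This is not the actual Hugging Face endpoint code; the model id, web framework, and the shift-then-scale order are assumptions:

```python
# Hypothetical endpoint sketch matching the client protocol above:
# request body is either raw tensor bytes (with `shape`/`dtype` query params) or a PNG,
# response body is raw latent bytes with `shape`/`dtype` headers.
import io
import json

import torch
from fastapi import FastAPI, Request, Response
from PIL import Image

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

app = FastAPI()
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
processor = VaeImageProcessor()


@app.post("/")
async def encode(request: Request) -> Response:
    params = request.query_params
    body = await request.body()
    if "shape" in params:  # raw tensor path, mirrors `prepare_encode` for torch.Tensor inputs
        shape = json.loads(params["shape"])
        dtype = getattr(torch, params["dtype"])
        tensor = torch.frombuffer(bytearray(body), dtype=dtype).reshape(shape)
    else:  # PNG path, mirrors `prepare_encode` for PIL inputs
        tensor = processor.preprocess(Image.open(io.BytesIO(body)))
    tensor = tensor.to("cuda", dtype=torch.float16)
    with torch.no_grad():
        latent = vae.encode(tensor).latent_dist.sample()
    if "shift_factor" in params:  # assumed order: shift, then scale (Flux convention)
        latent = latent - float(params["shift_factor"])
    if "scaling_factor" in params:
        latent = latent * float(params["scaling_factor"])
    data = latent.cpu().numpy().tobytes()
    headers = {"shape": json.dumps(list(latent.shape)), "dtype": str(latent.dtype).split(".")[-1]}
    return Response(content=data, headers=headers)
```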