Commit 071807c

[training] feat: enable quantization for hidream lora training. (#11494)
* feat: enable quantization for hidream lora training.
* better handle compute dtype.
* finalize.
* fix dtype.

Co-authored-by: Linoy Tsaban <[email protected]>
1 parent ee1516e commit 071807c

3 files changed (+68 -17 lines):

* examples/dreambooth/README_hidream.md
* examples/dreambooth/train_dreambooth_lora_hidream.py
* src/diffusers/quantizers/quantization_config.py

examples/dreambooth/README_hidream.md (+27)
@@ -117,3 +117,30 @@ We provide several options for optimizing memory optimization:
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
 
 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
+
+## Using quantization
+
+You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
+
+```json
+{
+    "load_in_4bit": true,
+    "bnb_4bit_quant_type": "nf4"
+}
+```
+
+Below, we provide some numbers with and without the use of NF4 quantization when training:
+
+```
+(with quantization)
+Memory (before device placement): 9.085089683532715 GB.
+Memory (after device placement): 34.59585428237915 GB.
+Memory (after backward): 36.90267467498779 GB.
+
+(without quantization)
+Memory (before device placement): 0.0 GB.
+Memory (after device placement): 57.6400408744812 GB.
+Memory (after backward): 59.932212829589844 GB.
+```
+
+The reason we see some memory before device placement in the quantized case is that, by default, bnb-quantized models are placed on the GPU first.
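For reference, the training script (diffed below) reads this JSON and passes its keys as keyword arguments to `BitsAndBytesConfig`, filling in the 4-bit compute dtype from the mixed-precision setting. A minimal sketch of that flow, assuming a hypothetical config file named `bnb_nf4.json` with the contents shown above and bf16 mixed precision:

```python
import json

import torch
from diffusers import BitsAndBytesConfig

# Hypothetical path for illustration; point this at your own JSON file
# (the same file you would pass via --bnb_quantization_config_path).
config_path = "bnb_nf4.json"

with open(config_path, "r") as f:
    config_kwargs = json.load(f)

# When 4-bit loading is requested, the script also sets the compute dtype
# to the mixed-precision dtype (bf16 is assumed here).
if config_kwargs.get("load_in_4bit"):
    config_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16

quantization_config = BitsAndBytesConfig(**config_kwargs)
```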

examples/dreambooth/train_dreambooth_lora_hidream.py (+40 -16)
@@ -16,6 +16,7 @@
 import argparse
 import copy
 import itertools
+import json
 import logging
 import math
 import os
@@ -27,14 +28,13 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
-from peft import LoraConfig, set_peft_model_state_dict
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
+    BitsAndBytesConfig,
     FlowMatchEulerDiscreteScheduler,
     HiDreamImagePipeline,
     HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
         default="meta-llama/Meta-Llama-3.1-8B-Instruct",
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--bnb_quantization_config_path",
+        type=str,
+        default=None,
+        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -1056,6 +1063,14 @@ def main(args):
         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
     )
 
+    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,31 @@
     text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
     )
-
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
         revision=args.revision,
         variant=args.variant,
     )
+    quantization_config = None
+    if args.bnb_quantization_config_path is not None:
+        with open(args.bnb_quantization_config_path, "r") as f:
+            config_kwargs = json.load(f)
+        if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+            config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+        quantization_config = BitsAndBytesConfig(**config_kwargs)
+
     transformer = HiDreamImageTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="transformer",
         revision=args.revision,
         variant=args.variant,
+        quantization_config=quantization_config,
+        torch_dtype=weight_dtype,
         force_inference_output=True,
     )
+    if args.bnb_quantization_config_path is not None:
+        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
 
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
@@ -1087,14 +1113,6 @@
     text_encoder_three.requires_grad_(False)
     text_encoder_four.requires_grad_(False)
 
-    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
-    # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
         raise ValueError(
@@ -1109,7 +1127,12 @@ def main(args):
     text_encoder_three.to(**to_kwargs)
    text_encoder_four.to(**to_kwargs)
     # we never offload the transformer to CPU, so we can just use the accelerator device
-    transformer.to(accelerator.device, dtype=weight_dtype)
+    transformer_to_kwargs = (
+        {"device": accelerator.device}
+        if args.bnb_quantization_config_path is not None
+        else {"device": accelerator.device, "dtype": weight_dtype}
+    )
+    transformer.to(**transformer_to_kwargs)
 
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
@@ -1695,10 +1718,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        if args.upcast_before_saving:
-            transformer.to(torch.float32)
-        else:
-            transformer = transformer.to(weight_dtype)
+        if args.bnb_quantization_config_path is None:
+            if args.upcast_before_saving:
+                transformer.to(torch.float32)
+            else:
+                transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         HiDreamImagePipeline.save_lora_weights(
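Put together, the hunks above amount to the following loading pattern. This is a condensed sketch, not the full training script: the checkpoint id, config path, LoRA rank, and attention projection names (`to_q`, `to_k`, `to_v`, `to_out.0`) are placeholders/assumptions, and bf16 mixed precision is assumed.

```python
import json

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel
from peft import LoraConfig, prepare_model_for_kbit_training

pretrained_model = "HiDream-ai/HiDream-I1-Dev"  # placeholder checkpoint id
bnb_config_path = "bnb_nf4.json"  # hypothetical config file like the README example
weight_dtype = torch.bfloat16  # mirrors accelerator.mixed_precision == "bf16"

with open(bnb_config_path, "r") as f:
    config_kwargs = json.load(f)
if config_kwargs.get("load_in_4bit"):
    config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
quantization_config = BitsAndBytesConfig(**config_kwargs)

# Load the DiT quantized; bnb-quantized weights land on the GPU at load time,
# which is why the script skips passing a dtype to .to() in the quantized case.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    pretrained_model,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=weight_dtype,
    force_inference_output=True,
)

# Freeze and prepare the k-bit base model for training, then attach LoRA layers.
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
transformer.requires_grad_(False)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(lora_config)
```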

src/diffusers/quantizers/quantization_config.py (+1 -1)
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
     This is a wrapper class about all possible attributes and features that you can play with a model that has been
     loaded using `bitsandbytes`.
 
-    This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
+    This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.
 
     Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
     then more arguments will be added to this class.
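As a quick illustration of the constraint that docstring describes (a sketch, not part of this commit): request one quantization mode per config rather than enabling both flags.

```python
from diffusers import BitsAndBytesConfig

# 4-bit (NF4) and 8-bit (LLM.int8()) are requested with separate configs;
# per the docstring, `load_in_8bit` and `load_in_4bit` are mutually exclusive.
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
int8_config = BitsAndBytesConfig(load_in_8bit=True)
```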
