Commit b440f09

XinyuYe-Intel authored and Jimmy committed
Added distillation for quantization example on textual inversion. (huggingface#2760)
* Added distillation for quantization example on textual inversion.
  Signed-off-by: Ye, Xinyu <[email protected]>
* refined readme and code style.
  Signed-off-by: Ye, Xinyu <[email protected]>
* Update text2images.py
* refined code of model load and added compatibility check.
  Signed-off-by: Ye, Xinyu <[email protected]>
* fixed code style.
  Signed-off-by: Ye, Xinyu <[email protected]>
* fix C403 [*] Unnecessary `list` comprehension (rewrite as a `set` comprehension)
  Signed-off-by: Ye, Xinyu <[email protected]>

---------

Signed-off-by: Ye, Xinyu <[email protected]>
1 parent e98cd25 commit b440f09

File tree

4 files changed: +1230 -0 lines changed
Lines changed: 93 additions & 0 deletions (README)
@@ -0,0 +1,93 @@
# Distillation for quantization on Textual Inversion models to personalize text2image

[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images. _By using just 3-5 images, new concepts can be taught to Stable Diffusion and the model personalized on your own images._
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
We have enabled distillation for quantization in `textual_inversion.py` to perform quantization-aware training as well as distillation on the model generated by the Textual Inversion method.

## Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install -r requirements.txt
```

## Prepare Datasets

One picture from the Hugging Face dataset [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed; save it to the `./dicoo` directory. The picture is shown below:

<a href="https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg">
<img src="https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg" width="300" height="300">
</a>
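If you prefer to fetch the picture from a script, the following is a minimal sketch that downloads it from the URL linked above into the `./dicoo` directory (the filename `1.jpeg` simply mirrors the name used in the repository):

```python
import os
import urllib.request

# Same picture as linked above, from the sd-concepts-library/dicoo2 repository.
IMAGE_URL = "https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg"

# Create the training data directory expected by the commands below and save the picture into it.
os.makedirs("dicoo", exist_ok=True)
urllib.request.urlretrieve(IMAGE_URL, os.path.join("dicoo", "1.jpeg"))
```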
## Get an FP32 Textual Inversion model

Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="./dicoo"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<dicoo>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="dicoo_model"
```
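As a quick sanity check after training, you can confirm that the placeholder token was added to the tokenizer saved with the model. This is a minimal sketch that assumes the output directory follows the usual diffusers layout with a `tokenizer` subfolder, as the `text2images.py` script below also expects:

```python
from transformers import CLIPTokenizer

# Load the tokenizer saved alongside the fine-tuned model and check the placeholder token exists.
tokenizer = CLIPTokenizer.from_pretrained("dicoo_model", subfolder="tokenizer")
print("<dicoo>" in tokenizer.get_vocab())  # expected: True
```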
## Do distillation for quantization

Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Given an FP32 model, the distillation for quantization approach takes this model itself as the teacher and transfers the knowledge of the specified layers to the student model, i.e. the quantized version of the FP32 model, during the quantization-aware training process.
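The following is a minimal, self-contained PyTorch sketch of that idea. The toy two-layer network, eager-mode QAT utilities, and plain MSE distillation loss are illustrative assumptions only; the actual example drives this process through Intel Neural Compressor inside `textual_inversion.py` and distills intermediate UNet layers on top of the regular training objective.

```python
import copy

import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat


class SmallNet(nn.Module):
    """Toy stand-in for the UNet; the real example distills a Stable Diffusion UNet."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # fake-quantizes the float input during QAT
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 16)
        self.dequant = DeQuantStub()  # converts back to float at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)


teacher = SmallNet().eval()               # frozen FP32 teacher (the original model)
student = copy.deepcopy(teacher).train()  # student that will become the quantized model

# Insert fake-quantization observers so training is quantization aware.
student.qconfig = get_default_qat_qconfig("fbgemm")
student = prepare_qat(student)

optimizer = torch.optim.SGD(student.parameters(), lr=1e-3)
mse = nn.MSELoss()

for _ in range(10):
    x = torch.randn(8, 16)
    with torch.no_grad():
        teacher_out = teacher(x)  # FP32 teacher output is the distillation target
    student_out = student(x)      # student runs with fake-quantized weights/activations
    # Distillation loss on the chosen layer's output; the real example combines such losses
    # with the regular denoising objective during quantization-aware training.
    loss = mse(student_out, teacher_out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Convert the fake-quantized student into a real INT8 model.
int8_student = convert(student.eval())
```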
Once you have the FP32 Textual Inversion model, the following command takes it as input, performs distillation for quantization, and generates the INT8 Textual Inversion model.
```bash
export FP32_MODEL_NAME="./dicoo_model"
export DATA_DIR="./dicoo"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$FP32_MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --use_ema --learnable_property="object" \
  --placeholder_token="<dicoo>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=300 \
  --learning_rate=5.0e-04 --max_grad_norm=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="int8_model" \
  --do_quantization --do_distillation --verify_loading
```
After the distillation for quantization process, the quantized UNet is about 4 times smaller (3279 MB -> 827 MB).
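To verify the size reduction on your own run, a small helper like the one below can be used. The directory names are assumptions based on the commands above: the FP32 UNet lives in `dicoo_model/unet`, and the quantized UNet is saved as `best_model.pt` inside the INT8 output directory, which is also where `text2images.py` looks for it.

```python
import os


def dir_size_mb(path):
    """Total size of all files under `path`, in megabytes."""
    total = 0
    for root, _, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / (1024 * 1024)


print("FP32 UNet:      %.0f MB" % dir_size_mb("dicoo_model/unet"))
print("INT8 UNet file: %.0f MB" % (os.path.getsize("int8_model/best_model.pt") / (1024 * 1024)))
```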
## Inference

Once you have trained an INT8 model with the above command, inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.

```bash
export INT8_MODEL_NAME="./int8_model"

python text2images.py \
  --pretrained_model_name_or_path=$INT8_MODEL_NAME \
  --caption "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings." \
  --images_num 4
```
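For the FP32 model you can also run inference directly with the `diffusers` pipeline. The sketch below is a minimal example for the `./dicoo_model` directory produced above; the INT8 model should instead go through `text2images.py`, which reloads the quantized UNet with `neural_compressor`.

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned FP32 Textual Inversion model saved by the first training command.
pipe = StableDiffusionPipeline.from_pretrained("./dicoo_model")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# The placeholder token "<dicoo>" must appear in the prompt for the learned concept to be used.
image = pipe("a lovely <dicoo> in red dress and hat", num_inference_steps=50).images[0]
image.save("dicoo_fp32.png")
```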
Here is a comparison of images generated by the FP32 model (left) and the INT8 model (right):

<p float="left">
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png" width="300" height="300" alt="FP32" align=center />
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png" width="300" height="300" alt="INT8" align=center />
</p>
Lines changed: 7 additions & 0 deletions (requirements.txt)
@@ -0,0 +1,7 @@
accelerate
torchvision
transformers>=4.25.0
ftfy
tensorboard
modelcards
neural-compressor
Lines changed: 112 additions & 0 deletions (text2images.py)
@@ -0,0 +1,112 @@
import argparse
import math
import os

import torch
from neural_compressor.utils.pytorch import load
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "-c",
        "--caption",
        type=str,
        default="robotic cat with wings",
        help="Text used to generate images.",
    )
    parser.add_argument(
        "-n",
        "--images_num",
        type=int,
        default=4,
        help="How many images to generate.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed for random process.",
    )
    parser.add_argument(
        "-ci",
        "--cuda_id",
        type=int,
        default=0,
        help="CUDA device id to use.",
    )
    args = parser.parse_args()
    return args


def image_grid(imgs, rows, cols):
    if len(imgs) != rows * cols:
        raise ValueError("The specified number of rows and columns are not correct.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def generate_images(
    pipeline,
    prompt="robotic cat with wings",
    guidance_scale=7.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    seed=42,
):
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    images = pipeline(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images
    _rows = int(math.sqrt(num_images_per_prompt))
    grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
    return grid, images


args = parse_args()
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
)
# Disable the safety checker so generated images are returned unfiltered.
pipeline.safety_checker = lambda images, clip_input: (images, False)
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
    # A quantized checkpoint produced by neural_compressor exists: reload the INT8 UNet into the pipeline.
    unet = load(args.pretrained_model_name_or_path, model=unet)
    unet.eval()
    setattr(pipeline, "unet", unet)
else:
    # No quantized checkpoint found: run the FP32 UNet on the requested CUDA device.
    unet = unet.to(torch.device("cuda", args.cuda_id))
pipeline = pipeline.to(unet.device)

# Generate the requested number of images, saving them individually and as a grid named after the caption.
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
os.makedirs(dirname, exist_ok=True)
for idx, image in enumerate(images):
    image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
