import argparse
import math
import os

import torch
from neural_compressor.utils.pytorch import load
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "-c",
        "--caption",
        type=str,
        default="robotic cat with wings",
        help="Text used to generate images.",
    )
    parser.add_argument(
        "-n",
        "--images_num",
        type=int,
        default=4,
        help="How many images to generate.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed for the random number generator.",
    )
    parser.add_argument(
        "-ci",
        "--cuda_id",
        type=int,
        default=0,
        help="CUDA device id to use when no quantized UNet checkpoint is found.",
    )
    args = parser.parse_args()
    return args


def image_grid(imgs, rows, cols):
    if len(imgs) != rows * cols:
        raise ValueError("The number of images does not match the specified grid of rows and columns.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    # Paste each image into its grid cell, filling row by row.
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def generate_images(
    pipeline,
    prompt="robotic cat with wings",
    guidance_scale=7.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    seed=42,
):
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    images = pipeline(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images
    # Arrange the images in the most square grid possible: take the largest
    # divisor of num_images_per_prompt that is <= its square root as the row
    # count, so rows * cols always equals the number of images (the plain
    # int(sqrt(n)) choice fails for non-square counts such as 5).
    _rows = int(math.sqrt(num_images_per_prompt))
    while num_images_per_prompt % _rows != 0:
        _rows -= 1
    grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
    return grid, images


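# Usage sketch for generate_images on its own (illustrative only; the model
# id below is an assumption, not part of this script):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   pipe = pipe.to("cuda")
#   grid, imgs = generate_images(pipe, prompt="robotic cat with wings", num_images_per_prompt=4)
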
args = parse_args()
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
)
# Disable the safety checker: return the images unchanged and report one
# "not NSFW" flag per image (a bare False breaks per-image checks downstream).
pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
# If an Intel Neural Compressor checkpoint (best_model.pt) exists, restore
# the quantized UNet from it and keep it on the CPU; otherwise move the FP32
# UNet to the requested CUDA device.
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
    unet = load(args.pretrained_model_name_or_path, model=unet)
    unet.eval()
    pipeline.unet = unet
else:
    unet = unet.to(torch.device("cuda", args.cuda_id))
pipeline = pipeline.to(unet.device)
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
# Save the grid next to the model, then each individual image in a
# caption-named subdirectory.
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
os.makedirs(dirname, exist_ok=True)
for idx, image in enumerate(images):
    image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
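
# Example invocation (the script filename and model directory are
# hypothetical; pass the directory that holds the pipeline subfolders and,
# optionally, an INC best_model.pt):
#   python text2images.py -m ./output_model_dir -c "robotic cat with wings" -n 4 -s 42 -ci 0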