67
67
from diffusers .utils .torch_utils import is_compiled_module
68
68
69
69
70
+ if is_wandb_available ():
71
+ import wandb
72
+
70
73
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
71
74
check_min_version ("0.27.0.dev0" )
72
75
@@ -140,6 +143,61 @@ def save_model_card(
140
143
model_card .save (os .path .join (repo_folder , "README.md" ))
141
144
142
145
146
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation=False,
):
    """Run validation inference with ``pipeline`` and log the images to trackers.

    Args:
        pipeline: Diffusers pipeline used to generate the validation images.
        args: Parsed CLI namespace; this function reads
            ``num_validation_images``, ``validation_prompt`` and ``seed``.
        accelerator: ``accelerate.Accelerator``; supplies the target device and
            the configured trackers (tensorboard / wandb).
        pipeline_args: Keyword arguments forwarded to each ``pipeline(...)`` call.
        epoch: Step index used when logging images to tensorboard.
        is_final_validation: When True, images are logged under the ``"test"``
            phase instead of ``"validation"``.

    Returns:
        list: The generated PIL images.
    """
    logger.info(
        f"Running validation... \nGenerating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    # We train on the simplified learning objective. If we were previously
    # predicting a variance, we need the scheduler to ignore it.
    scheduler_args = {}

    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Run inference.
    # Bug fix: the original used `if args.seed`, which treats seed 0 as falsy
    # and silently ignored it; compare against None so seed=0 is honored.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

    # NOTE(review): torch.cuda.amp.autocast assumes a CUDA device and will fail
    # on CPU-only runs — confirm whether accelerator.autocast() should be used.
    with torch.cuda.amp.autocast():
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    # Release pipeline memory between validation rounds.
    del pipeline
    torch.cuda.empty_cache()

    return images
199
+
200
+
143
201
def import_model_class_from_model_name_or_path (
144
202
pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
145
203
):
@@ -862,7 +920,6 @@ def main(args):
862
920
if args .report_to == "wandb" :
863
921
if not is_wandb_available ():
864
922
raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
865
- import wandb
866
923
867
924
# Make one log on every process with the configuration for debugging.
868
925
logging .basicConfig (
@@ -1615,10 +1672,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1615
1672
1616
1673
if accelerator .is_main_process :
1617
1674
if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
1618
- logger .info (
1619
- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1620
- f" { args .validation_prompt } ."
1621
- )
1622
1675
# create pipeline
1623
1676
if not args .train_text_encoder :
1624
1677
text_encoder_one = text_encoder_cls_one .from_pretrained (
@@ -1644,50 +1697,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1644
1697
torch_dtype = weight_dtype ,
1645
1698
)
1646
1699
1647
- # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1648
- scheduler_args = {}
1649
-
1650
- if "variance_type" in pipeline .scheduler .config :
1651
- variance_type = pipeline .scheduler .config .variance_type
1652
-
1653
- if variance_type in ["learned" , "learned_range" ]:
1654
- variance_type = "fixed_small"
1655
-
1656
- scheduler_args ["variance_type" ] = variance_type
1657
-
1658
- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
1659
- pipeline .scheduler .config , ** scheduler_args
1660
- )
1661
-
1662
- pipeline = pipeline .to (accelerator .device )
1663
- pipeline .set_progress_bar_config (disable = True )
1664
-
1665
- # run inference
1666
- generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
1667
1700
pipeline_args = {"prompt" : args .validation_prompt }
1668
1701
1669
- with torch .cuda .amp .autocast ():
1670
- images = [
1671
- pipeline (** pipeline_args , generator = generator ).images [0 ]
1672
- for _ in range (args .num_validation_images )
1673
- ]
1674
-
1675
- for tracker in accelerator .trackers :
1676
- if tracker .name == "tensorboard" :
1677
- np_images = np .stack ([np .asarray (img ) for img in images ])
1678
- tracker .writer .add_images ("validation" , np_images , epoch , dataformats = "NHWC" )
1679
- if tracker .name == "wandb" :
1680
- tracker .log (
1681
- {
1682
- "validation" : [
1683
- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
1684
- for i , image in enumerate (images )
1685
- ]
1686
- }
1687
- )
1688
-
1689
- del pipeline
1690
- torch .cuda .empty_cache ()
1702
+ images = log_validation (
1703
+ pipeline ,
1704
+ args ,
1705
+ accelerator ,
1706
+ pipeline_args ,
1707
+ epoch ,
1708
+ )
1691
1709
1692
1710
# Save the lora layers
1693
1711
accelerator .wait_for_everyone ()
@@ -1733,45 +1751,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1733
1751
torch_dtype = weight_dtype ,
1734
1752
)
1735
1753
1736
- # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1737
- scheduler_args = {}
1738
-
1739
- if "variance_type" in pipeline .scheduler .config :
1740
- variance_type = pipeline .scheduler .config .variance_type
1741
-
1742
- if variance_type in ["learned" , "learned_range" ]:
1743
- variance_type = "fixed_small"
1744
-
1745
- scheduler_args ["variance_type" ] = variance_type
1746
-
1747
- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config , ** scheduler_args )
1748
-
1749
1754
# load attention processors
1750
1755
pipeline .load_lora_weights (args .output_dir )
1751
1756
1752
1757
# run inference
1753
1758
images = []
1754
1759
if args .validation_prompt and args .num_validation_images > 0 :
1755
- pipeline = pipeline .to (accelerator .device )
1756
- generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
1757
- images = [
1758
- pipeline (args .validation_prompt , num_inference_steps = 25 , generator = generator ).images [0 ]
1759
- for _ in range (args .num_validation_images )
1760
- ]
1761
-
1762
- for tracker in accelerator .trackers :
1763
- if tracker .name == "tensorboard" :
1764
- np_images = np .stack ([np .asarray (img ) for img in images ])
1765
- tracker .writer .add_images ("test" , np_images , epoch , dataformats = "NHWC" )
1766
- if tracker .name == "wandb" :
1767
- tracker .log (
1768
- {
1769
- "test" : [
1770
- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
1771
- for i , image in enumerate (images )
1772
- ]
1773
- }
1774
- )
1760
+ pipeline_args = {"prompt" : args .validation_prompt , "num_inference_steps" : 25 }
1761
+ images = log_validation (
1762
+ pipeline ,
1763
+ args ,
1764
+ accelerator ,
1765
+ pipeline_args ,
1766
+ epoch ,
1767
+ is_final_validation = True ,
1768
+ )
1775
1769
1776
1770
if args .push_to_hub :
1777
1771
save_model_card (
0 commit comments