14
14
15
15
import inspect
16
16
from dataclasses import dataclass , field
17
- from typing import Any , Dict , List , Tuple , Union
17
+ from typing import Any , Dict , List , Tuple , Union , Type
18
+ from collections import OrderedDict
18
19
19
20
import torch
20
21
from tqdm .auto import tqdm
30
31
from .pipeline_loading_utils import _fetch_class_library_tuple , _get_pipeline_class
31
32
from .pipeline_utils import DiffusionPipeline
32
33
34
+ import warnings
35
+
33
36
34
37
if is_accelerate_available ():
35
38
import accelerate
@@ -99,6 +102,7 @@ class PipelineBlock:
99
102
optional_components = []
100
103
required_components = []
101
104
required_auxiliaries = []
105
+ optional_auxiliaries = []
102
106
103
107
@property
104
108
def inputs (self ) -> Tuple [Tuple [str , Any ], ...]:
@@ -122,7 +126,7 @@ def __init__(self, **kwargs):
122
126
for key , value in kwargs .items ():
123
127
if key in self .required_components or key in self .optional_components :
124
128
self .components [key ] = value
125
- elif key in self .required_auxiliaries :
129
+ elif key in self .required_auxiliaries or key in self . optional_auxiliaries :
126
130
self .auxiliaries [key ] = value
127
131
else :
128
132
self .configs [key ] = value
@@ -152,10 +156,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
152
156
components_to_add [component_name ] = component
153
157
154
158
# add auxiliaries
159
+ expected_auxiliaries = set (cls .required_auxiliaries + cls .optional_auxiliaries )
155
160
# - auxiliaries that are passed in kwargs
156
- auxiliaries_to_add = {k : kwargs .pop (k ) for k in cls . required_auxiliaries if k in kwargs }
161
+ auxiliaries_to_add = {k : kwargs .pop (k ) for k in expected_auxiliaries if k in kwargs }
157
162
# - auxiliaries that are in the pipeline
158
- for aux_name in cls . required_auxiliaries :
163
+ for aux_name in expected_auxiliaries :
159
164
if hasattr (pipe , aux_name ) and aux_name not in auxiliaries_to_add :
160
165
auxiliaries_to_add [aux_name ] = getattr (pipe , aux_name )
161
166
block_kwargs = {** components_to_add , ** auxiliaries_to_add }
@@ -167,7 +172,7 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
167
172
expected_configs = {
168
173
k
169
174
for k in pipe .config .keys ()
170
- if k in init_params and k not in expected_components and k not in cls . required_auxiliaries
175
+ if k in init_params and k not in expected_components and k not in expected_auxiliaries
171
176
}
172
177
173
178
for config_name in expected_configs :
@@ -210,6 +215,188 @@ def __repr__(self):
210
215
)
211
216
212
217
218
+ def combine_inputs (* input_lists : List [Tuple [str , Any ]]) -> List [Tuple [str , Any ]]:
219
+ """
220
+ Combines multiple lists of (name, default_value) tuples.
221
+ For duplicate inputs, updates only if current value is None and new value is not None.
222
+ Warns if multiple non-None default values exist for the same input.
223
+ """
224
+ combined_dict = {}
225
+ for inputs in input_lists :
226
+ for name , value in inputs :
227
+ if name in combined_dict :
228
+ current_value = combined_dict [name ]
229
+ if current_value is not None and value is not None and current_value != value :
230
+ warnings .warn (
231
+ f"Multiple different default values found for input '{ name } ': "
232
+ f"{ current_value } and { value } . Using { current_value } ."
233
+ )
234
+ if current_value is None and value is not None :
235
+ combined_dict [name ] = value
236
+ else :
237
+ combined_dict [name ] = value
238
+ return list (combined_dict .items ())
239
+
240
+
241
+
242
+ class AutoStep (PipelineBlock ):
243
+ base_blocks = [] # list of block classes
244
+ trigger_inputs = [] # list of trigger inputs (None for default block)
245
+ required_components = []
246
+ optional_components = []
247
+ required_auxiliaries = []
248
+ optional_auxiliaries = []
249
+
250
+ def __init__ (self , ** kwargs ):
251
+ self .blocks = []
252
+
253
+ for block_cls , trigger in zip (self .base_blocks , self .trigger_inputs ):
254
+ # Check components
255
+ missing_components = [
256
+ component for component in block_cls .required_components
257
+ if component not in kwargs
258
+ ]
259
+
260
+ # Check auxiliaries
261
+ missing_auxiliaries = [
262
+ auxiliary for auxiliary in block_cls .required_auxiliaries
263
+ if auxiliary not in kwargs
264
+ ]
265
+
266
+ if not missing_components and not missing_auxiliaries :
267
+ # Only get kwargs that the block's __init__ accepts
268
+ block_params = inspect .signature (block_cls .__init__ ).parameters
269
+ block_kwargs = {
270
+ k : v for k , v in kwargs .items ()
271
+ if k in block_params
272
+ }
273
+ self .blocks .append (block_cls (** block_kwargs ))
274
+
275
+ # Print message about trigger condition
276
+ if trigger is None :
277
+ print (f"Added default block: { block_cls .__name__ } " )
278
+ else :
279
+ print (f"Added block { block_cls .__name__ } - will be dispatched if '{ trigger } ' input is not None" )
280
+ else :
281
+ if trigger is None :
282
+ print (f"Cannot add default block { block_cls .__name__ } :" )
283
+ else :
284
+ print (f"Cannot add block { block_cls .__name__ } (triggered by '{ trigger } '):" )
285
+ if missing_components :
286
+ print (f" - Missing components: { missing_components } " )
287
+ if missing_auxiliaries :
288
+ print (f" - Missing auxiliaries: { missing_auxiliaries } " )
289
+
290
+ @property
291
+ def components (self ):
292
+ # Combine components from all blocks
293
+ components = {}
294
+ for block in self .blocks :
295
+ components .update (block .components )
296
+ return components
297
+
298
+ @property
299
+ def auxiliaries (self ):
300
+ # Combine auxiliaries from all blocks
301
+ auxiliaries = {}
302
+ for block in self .blocks :
303
+ auxiliaries .update (block .auxiliaries )
304
+ return auxiliaries
305
+
306
+ @property
307
+ def configs (self ):
308
+ # Combine configs from all blocks
309
+ configs = {}
310
+ for block in self .blocks :
311
+ configs .update (block .configs )
312
+ return configs
313
+
314
+ @property
315
+ def inputs (self ) -> List [Tuple [str , Any ]]:
316
+ return combine_inputs (* (block .inputs for block in self .blocks ))
317
+
318
+ @property
319
+ def intermediates_inputs (self ) -> List [str ]:
320
+ return list (set ().union (* (
321
+ block .intermediates_inputs for block in self .blocks
322
+ )))
323
+
324
+ @property
325
+ def intermediates_outputs (self ) -> List [str ]:
326
+ return list (set ().union (* (
327
+ block .intermediates_outputs for block in self .blocks
328
+ )))
329
+
330
+ def __call__ (self , pipeline , state ):
331
+ # Check triggers in priority order
332
+ for idx , trigger in enumerate (self .trigger_inputs [:- 1 ]): # Skip last (None) trigger
333
+ if state .get_input (trigger ) is not None :
334
+ return self .blocks [idx ](pipeline , state )
335
+ # If no triggers match, use the default block (last one)
336
+ return self .blocks [- 1 ](pipeline , state )
337
+
338
+
339
+ def make_auto_step (pipeline_block_map : OrderedDict ) -> Type [AutoStep ]:
340
+ """
341
+ Creates a new AutoStep subclass with updated class attributes based on the pipeline block map.
342
+
343
+ Args:
344
+ pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes.
345
+ Order determines priority (earlier entries take precedence).
346
+ Must include None key for the default block.
347
+ """
348
+ blocks = list (pipeline_block_map .values ())
349
+ triggers = list (pipeline_block_map .keys ())
350
+
351
+ # Get all expected components (either required or optional by any block)
352
+ expected_components = []
353
+ for block in blocks :
354
+ for component in (block .required_components + block .optional_components ):
355
+ if component not in expected_components :
356
+ expected_components .append (component )
357
+
358
+ # A component is required if it's in required_components of all blocks
359
+ required_components = [
360
+ component for component in expected_components
361
+ if all (component in block .required_components for block in blocks )
362
+ ]
363
+
364
+ # All other expected components are optional
365
+ optional_components = [
366
+ component for component in expected_components
367
+ if component not in required_components
368
+ ]
369
+
370
+ # Get all expected auxiliaries (either required or optional by any block)
371
+ expected_auxiliaries = []
372
+ for block in blocks :
373
+ for auxiliary in (block .required_auxiliaries + getattr (block , 'optional_auxiliaries' , [])):
374
+ if auxiliary not in expected_auxiliaries :
375
+ expected_auxiliaries .append (auxiliary )
376
+
377
+ # An auxiliary is required if it's in required_auxiliaries of all blocks
378
+ required_auxiliaries = [
379
+ auxiliary for auxiliary in expected_auxiliaries
380
+ if all (auxiliary in block .required_auxiliaries for block in blocks )
381
+ ]
382
+
383
+ # All other expected auxiliaries are optional
384
+ optional_auxiliaries = [
385
+ auxiliary for auxiliary in expected_auxiliaries
386
+ if auxiliary not in required_auxiliaries
387
+ ]
388
+
389
+ # Create new class with updated attributes
390
+ return type ('AutoStep' , (AutoStep ,), {
391
+ 'base_blocks' : blocks ,
392
+ 'trigger_inputs' : triggers ,
393
+ 'required_components' : required_components ,
394
+ 'optional_components' : optional_components ,
395
+ 'required_auxiliaries' : required_auxiliaries ,
396
+ 'optional_auxiliaries' : optional_auxiliaries ,
397
+ })
398
+
399
+
213
400
class ModularPipelineBuilder (ConfigMixin ):
214
401
"""
215
402
Base class for all Modular pipelines.
@@ -585,7 +772,7 @@ def from_pipe(cls, pipeline, **kwargs):
585
772
# Create each block, passing only unused items that the block expects
586
773
for block_class in modular_pipeline_class .default_pipeline_blocks :
587
774
expected_components = set (block_class .required_components + block_class .optional_components )
588
- expected_auxiliaries = set (block_class .required_auxiliaries )
775
+ expected_auxiliaries = set (block_class .required_auxiliaries + block_class . optional_auxiliaries )
589
776
590
777
# Get init parameters to check for expected configs
591
778
init_params = inspect .signature (block_class .__init__ ).parameters
0 commit comments