Skip to content

Commit ffc2992

Browse files
committed
add autostep (not complete)
1 parent c70a285 commit ffc2992

File tree

2 files changed

+209
-12
lines changed

2 files changed

+209
-12
lines changed

src/diffusers/pipelines/modular_pipeline_builder.py

Lines changed: 193 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import inspect
1616
from dataclasses import dataclass, field
17-
from typing import Any, Dict, List, Tuple, Union
17+
from typing import Any, Dict, List, Tuple, Union, Type
18+
from collections import OrderedDict
1819

1920
import torch
2021
from tqdm.auto import tqdm
@@ -30,6 +31,8 @@
3031
from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class
3132
from .pipeline_utils import DiffusionPipeline
3233

34+
import warnings
35+
3336

3437
if is_accelerate_available():
3538
import accelerate
@@ -99,6 +102,7 @@ class PipelineBlock:
99102
optional_components = []
100103
required_components = []
101104
required_auxiliaries = []
105+
optional_auxiliaries = []
102106

103107
@property
104108
def inputs(self) -> Tuple[Tuple[str, Any], ...]:
@@ -122,7 +126,7 @@ def __init__(self, **kwargs):
122126
for key, value in kwargs.items():
123127
if key in self.required_components or key in self.optional_components:
124128
self.components[key] = value
125-
elif key in self.required_auxiliaries:
129+
elif key in self.required_auxiliaries or key in self.optional_auxiliaries:
126130
self.auxiliaries[key] = value
127131
else:
128132
self.configs[key] = value
@@ -152,10 +156,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
152156
components_to_add[component_name] = component
153157

154158
# add auxiliaries
159+
expected_auxiliaries = set(cls.required_auxiliaries + cls.optional_auxiliaries)
155160
# - auxiliaries that are passed in kwargs
156-
auxiliaries_to_add = {k: kwargs.pop(k) for k in cls.required_auxiliaries if k in kwargs}
161+
auxiliaries_to_add = {k: kwargs.pop(k) for k in expected_auxiliaries if k in kwargs}
157162
# - auxiliaries that are in the pipeline
158-
for aux_name in cls.required_auxiliaries:
163+
for aux_name in expected_auxiliaries:
159164
if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add:
160165
auxiliaries_to_add[aux_name] = getattr(pipe, aux_name)
161166
block_kwargs = {**components_to_add, **auxiliaries_to_add}
@@ -167,7 +172,7 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
167172
expected_configs = {
168173
k
169174
for k in pipe.config.keys()
170-
if k in init_params and k not in expected_components and k not in cls.required_auxiliaries
175+
if k in init_params and k not in expected_components and k not in expected_auxiliaries
171176
}
172177

173178
for config_name in expected_configs:
@@ -210,6 +215,188 @@ def __repr__(self):
210215
)
211216

212217

218+
def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
219+
"""
220+
Combines multiple lists of (name, default_value) tuples.
221+
For duplicate inputs, updates only if current value is None and new value is not None.
222+
Warns if multiple non-None default values exist for the same input.
223+
"""
224+
combined_dict = {}
225+
for inputs in input_lists:
226+
for name, value in inputs:
227+
if name in combined_dict:
228+
current_value = combined_dict[name]
229+
if current_value is not None and value is not None and current_value != value:
230+
warnings.warn(
231+
f"Multiple different default values found for input '{name}': "
232+
f"{current_value} and {value}. Using {current_value}."
233+
)
234+
if current_value is None and value is not None:
235+
combined_dict[name] = value
236+
else:
237+
combined_dict[name] = value
238+
return list(combined_dict.items())
239+
240+
241+
242+
class AutoStep(PipelineBlock):
243+
base_blocks = [] # list of block classes
244+
trigger_inputs = [] # list of trigger inputs (None for default block)
245+
required_components = []
246+
optional_components = []
247+
required_auxiliaries = []
248+
optional_auxiliaries = []
249+
250+
def __init__(self, **kwargs):
251+
self.blocks = []
252+
253+
for block_cls, trigger in zip(self.base_blocks, self.trigger_inputs):
254+
# Check components
255+
missing_components = [
256+
component for component in block_cls.required_components
257+
if component not in kwargs
258+
]
259+
260+
# Check auxiliaries
261+
missing_auxiliaries = [
262+
auxiliary for auxiliary in block_cls.required_auxiliaries
263+
if auxiliary not in kwargs
264+
]
265+
266+
if not missing_components and not missing_auxiliaries:
267+
# Only get kwargs that the block's __init__ accepts
268+
block_params = inspect.signature(block_cls.__init__).parameters
269+
block_kwargs = {
270+
k: v for k, v in kwargs.items()
271+
if k in block_params
272+
}
273+
self.blocks.append(block_cls(**block_kwargs))
274+
275+
# Print message about trigger condition
276+
if trigger is None:
277+
print(f"Added default block: {block_cls.__name__}")
278+
else:
279+
print(f"Added block {block_cls.__name__} - will be dispatched if '{trigger}' input is not None")
280+
else:
281+
if trigger is None:
282+
print(f"Cannot add default block {block_cls.__name__}:")
283+
else:
284+
print(f"Cannot add block {block_cls.__name__} (triggered by '{trigger}'):")
285+
if missing_components:
286+
print(f" - Missing components: {missing_components}")
287+
if missing_auxiliaries:
288+
print(f" - Missing auxiliaries: {missing_auxiliaries}")
289+
290+
@property
291+
def components(self):
292+
# Combine components from all blocks
293+
components = {}
294+
for block in self.blocks:
295+
components.update(block.components)
296+
return components
297+
298+
@property
299+
def auxiliaries(self):
300+
# Combine auxiliaries from all blocks
301+
auxiliaries = {}
302+
for block in self.blocks:
303+
auxiliaries.update(block.auxiliaries)
304+
return auxiliaries
305+
306+
@property
307+
def configs(self):
308+
# Combine configs from all blocks
309+
configs = {}
310+
for block in self.blocks:
311+
configs.update(block.configs)
312+
return configs
313+
314+
@property
315+
def inputs(self) -> List[Tuple[str, Any]]:
316+
return combine_inputs(*(block.inputs for block in self.blocks))
317+
318+
@property
319+
def intermediates_inputs(self) -> List[str]:
320+
return list(set().union(*(
321+
block.intermediates_inputs for block in self.blocks
322+
)))
323+
324+
@property
325+
def intermediates_outputs(self) -> List[str]:
326+
return list(set().union(*(
327+
block.intermediates_outputs for block in self.blocks
328+
)))
329+
330+
def __call__(self, pipeline, state):
331+
# Check triggers in priority order
332+
for idx, trigger in enumerate(self.trigger_inputs[:-1]): # Skip last (None) trigger
333+
if state.get_input(trigger) is not None:
334+
return self.blocks[idx](pipeline, state)
335+
# If no triggers match, use the default block (last one)
336+
return self.blocks[-1](pipeline, state)
337+
338+
339+
def make_auto_step(pipeline_block_map: OrderedDict) -> Type[AutoStep]:
340+
"""
341+
Creates a new AutoStep subclass with updated class attributes based on the pipeline block map.
342+
343+
Args:
344+
pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes.
345+
Order determines priority (earlier entries take precedence).
346+
Must include None key for the default block.
347+
"""
348+
blocks = list(pipeline_block_map.values())
349+
triggers = list(pipeline_block_map.keys())
350+
351+
# Get all expected components (either required or optional by any block)
352+
expected_components = []
353+
for block in blocks:
354+
for component in (block.required_components + block.optional_components):
355+
if component not in expected_components:
356+
expected_components.append(component)
357+
358+
# A component is required if it's in required_components of all blocks
359+
required_components = [
360+
component for component in expected_components
361+
if all(component in block.required_components for block in blocks)
362+
]
363+
364+
# All other expected components are optional
365+
optional_components = [
366+
component for component in expected_components
367+
if component not in required_components
368+
]
369+
370+
# Get all expected auxiliaries (either required or optional by any block)
371+
expected_auxiliaries = []
372+
for block in blocks:
373+
for auxiliary in (block.required_auxiliaries + getattr(block, 'optional_auxiliaries', [])):
374+
if auxiliary not in expected_auxiliaries:
375+
expected_auxiliaries.append(auxiliary)
376+
377+
# An auxiliary is required if it's in required_auxiliaries of all blocks
378+
required_auxiliaries = [
379+
auxiliary for auxiliary in expected_auxiliaries
380+
if all(auxiliary in block.required_auxiliaries for block in blocks)
381+
]
382+
383+
# All other expected auxiliaries are optional
384+
optional_auxiliaries = [
385+
auxiliary for auxiliary in expected_auxiliaries
386+
if auxiliary not in required_auxiliaries
387+
]
388+
389+
# Create new class with updated attributes
390+
return type('AutoStep', (AutoStep,), {
391+
'base_blocks': blocks,
392+
'trigger_inputs': triggers,
393+
'required_components': required_components,
394+
'optional_components': optional_components,
395+
'required_auxiliaries': required_auxiliaries,
396+
'optional_auxiliaries': optional_auxiliaries,
397+
})
398+
399+
213400
class ModularPipelineBuilder(ConfigMixin):
214401
"""
215402
Base class for all Modular pipelines.
@@ -585,7 +772,7 @@ def from_pipe(cls, pipeline, **kwargs):
585772
# Create each block, passing only unused items that the block expects
586773
for block_class in modular_pipeline_class.default_pipeline_blocks:
587774
expected_components = set(block_class.required_components + block_class.optional_components)
588-
expected_auxiliaries = set(block_class.required_auxiliaries)
775+
expected_auxiliaries = set(block_class.required_auxiliaries + block_class.optional_auxiliaries)
589776

590777
# Get init parameters to check for expected configs
591778
init_params = inspect.signature(block_class.__init__).parameters

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import inspect
1616
from typing import Any, List, Optional, Tuple, Union
17+
from collections import OrderedDict
1718

1819
import PIL
1920
import torch
@@ -33,7 +34,7 @@
3334
)
3435
from ...utils.torch_utils import is_compiled_module, randn_tensor
3536
from ..controlnet.multicontrolnet import MultiControlNetModel
36-
from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState
37+
from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState, make_auto_step
3738
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3839
from .pipeline_output import (
3940
StableDiffusionXLPipelineOutput,
@@ -401,7 +402,7 @@ def denoising_value_valid(dnv):
401402

402403
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
403404
optional_components = ["vae", "scheduler"]
404-
required_auxiliaries = ["image_processor"]
405+
optional_auxiliaries = ["image_processor"]
405406

406407
@property
407408
def inputs(self) -> List[Tuple[str, Any]]:
@@ -645,7 +646,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin
645646

646647
class StableDiffusionXLDenoiseStep(PipelineBlock):
647648
required_components = ["unet", "scheduler"]
648-
required_auxiliaries = ["guider"]
649+
optional_auxiliaries = ["guider"]
649650

650651
@property
651652
def inputs(self) -> List[Tuple[str, Any]]:
@@ -780,7 +781,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
780781

781782
class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
782783
required_components = ["unet", "controlnet", "scheduler"]
783-
required_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"]
784+
optional_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"]
784785

785786
@property
786787
def inputs(self) -> List[Tuple[str, Any]]:
@@ -1069,7 +1070,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
10691070

10701071
class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
10711072
optional_components = ["vae"]
1072-
required_auxiliaries = ["image_processor"]
1073+
optional_auxiliaries = ["image_processor"]
10731074

10741075
@property
10751076
def inputs(self) -> List[Tuple[str, Any]]:
@@ -1154,6 +1155,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
11541155
return pipeline, state
11551156

11561157

1158+
AUTO_DENOISE_BLOCK_MAP = OrderedDict([
1159+
# Higher priority blocks first
1160+
("control_image", StableDiffusionXLControlNetDenoiseStep),
1161+
# Default block
1162+
(None, StableDiffusionXLDenoiseStep),
1163+
])
1164+
1165+
StableDiffusionXLAutoDenoiseStep = make_auto_step(AUTO_DENOISE_BLOCK_MAP)
1166+
11571167
class StableDiffusionXLModularPipeline(
11581168
ModularPipelineBuilder,
11591169
StableDiffusionMixin,
@@ -1166,7 +1176,7 @@ class StableDiffusionXLModularPipeline(
11661176
StableDiffusionXLSetTimestepsStep,
11671177
StableDiffusionXLPrepareLatentsStep,
11681178
StableDiffusionXLPrepareAdditionalConditioningStep,
1169-
StableDiffusionXLDenoiseStep,
1179+
StableDiffusionXLAutoDenoiseStep,
11701180
StableDiffusionXLDecodeLatentsStep,
11711181
]
11721182

0 commit comments

Comments
 (0)