Skip to content

Commit c544196

Browse files
authored
refactor: extract init/forward function in UNet2DConditionModel (#6478)
* Extract functions for each stage of UNet2DConditionModel init & forward; add new function get_mid_block() to unet_2d_blocks.py * Add type hints to get_mid_block, aligned with get_up_block and get_down_block; rename the _set_xxx functions * Add type hints and use keyword arguments * Remove `copy from` in versatile diffusion
1 parent 6382663 commit c544196

File tree

3 files changed

+483
-315
lines changed

3 files changed

+483
-315
lines changed

src/diffusers/models/unet_2d_blocks.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,81 @@ def get_down_block(
249249
raise ValueError(f"{down_block_type} does not exist.")
250250

251251

252+
def get_mid_block(
    mid_block_type: Optional[str],
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """Build the UNet mid block selected by ``mid_block_type``.

    Supported values of ``mid_block_type`` and the class that is instantiated:

    * ``"UNetMidBlock2DCrossAttn"`` -> ``UNetMidBlock2DCrossAttn``
    * ``"UNetMidBlock2DSimpleCrossAttn"`` -> ``UNetMidBlock2DSimpleCrossAttn``
    * ``"UNetMidBlock2D"`` -> ``UNetMidBlock2D`` (always built with
      ``num_layers=0`` and ``add_attention=False``)
    * ``None`` -> no block is built and ``None`` is returned

    Every argument is forwarded by keyword. Arguments that the selected block
    type does not take are silently ignored.

    Args:
        mid_block_type: Name of the mid block class to build, or ``None`` to
            build nothing.
        temb_channels, in_channels, resnet_eps, resnet_act_fn, resnet_groups,
        output_scale_factor, resnet_time_scale_shift, dropout:
            Forwarded under the same names to all three block types.
        cross_attention_dim: Forwarded to both cross-attention block types.
        transformer_layers_per_block, num_attention_heads,
        dual_cross_attention, use_linear_projection, upcast_attention,
        attention_type:
            Forwarded only to ``UNetMidBlock2DCrossAttn``.
        attention_head_dim, cross_attention_norm:
            Forwarded only to ``UNetMidBlock2DSimpleCrossAttn``.
        resnet_skip_time_act: Forwarded only to
            ``UNetMidBlock2DSimpleCrossAttn``, as ``skip_time_act``.
        mid_block_only_cross_attention: Forwarded only to
            ``UNetMidBlock2DSimpleCrossAttn``, as ``only_cross_attention``.

    Returns:
        The constructed mid block instance, or ``None`` when
        ``mid_block_type`` is ``None``.

    Raises:
        ValueError: If ``mid_block_type`` is neither ``None`` nor one of the
            supported names.
    """
    # A model may legitimately have no mid block at all.
    if mid_block_type is None:
        return None

    # Keyword arguments that every supported mid block receives unchanged.
    shared_kwargs = {
        "in_channels": in_channels,
        "temb_channels": temb_channels,
        "dropout": dropout,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "output_scale_factor": output_scale_factor,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
            **shared_kwargs,
        )

    if mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        return UNetMidBlock2DSimpleCrossAttn(
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            **shared_kwargs,
        )

    if mid_block_type == "UNetMidBlock2D":
        # The plain mid block is requested with no extra layers and no attention.
        return UNetMidBlock2D(
            num_layers=0,
            add_attention=False,
            **shared_kwargs,
        )

    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
325+
326+
252327
def get_up_block(
253328
up_block_type: str,
254329
num_layers: int,

0 commit comments

Comments
 (0)