@@ -249,6 +249,81 @@ def get_down_block(
249
249
raise ValueError (f"{ down_block_type } does not exist." )
250
250
251
251
252
+ def get_mid_block (
253
+ mid_block_type : str ,
254
+ temb_channels : int ,
255
+ in_channels : int ,
256
+ resnet_eps : float ,
257
+ resnet_act_fn : str ,
258
+ resnet_groups : int ,
259
+ output_scale_factor : float = 1.0 ,
260
+ transformer_layers_per_block : int = 1 ,
261
+ num_attention_heads : Optional [int ] = None ,
262
+ cross_attention_dim : Optional [int ] = None ,
263
+ dual_cross_attention : bool = False ,
264
+ use_linear_projection : bool = False ,
265
+ mid_block_only_cross_attention : bool = False ,
266
+ upcast_attention : bool = False ,
267
+ resnet_time_scale_shift : str = "default" ,
268
+ attention_type : str = "default" ,
269
+ resnet_skip_time_act : bool = False ,
270
+ cross_attention_norm : Optional [str ] = None ,
271
+ attention_head_dim : Optional [int ] = 1 ,
272
+ dropout : float = 0.0 ,
273
+ ):
274
+ if mid_block_type == "UNetMidBlock2DCrossAttn" :
275
+ return UNetMidBlock2DCrossAttn (
276
+ transformer_layers_per_block = transformer_layers_per_block ,
277
+ in_channels = in_channels ,
278
+ temb_channels = temb_channels ,
279
+ dropout = dropout ,
280
+ resnet_eps = resnet_eps ,
281
+ resnet_act_fn = resnet_act_fn ,
282
+ output_scale_factor = output_scale_factor ,
283
+ resnet_time_scale_shift = resnet_time_scale_shift ,
284
+ cross_attention_dim = cross_attention_dim ,
285
+ num_attention_heads = num_attention_heads ,
286
+ resnet_groups = resnet_groups ,
287
+ dual_cross_attention = dual_cross_attention ,
288
+ use_linear_projection = use_linear_projection ,
289
+ upcast_attention = upcast_attention ,
290
+ attention_type = attention_type ,
291
+ )
292
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn" :
293
+ return UNetMidBlock2DSimpleCrossAttn (
294
+ in_channels = in_channels ,
295
+ temb_channels = temb_channels ,
296
+ dropout = dropout ,
297
+ resnet_eps = resnet_eps ,
298
+ resnet_act_fn = resnet_act_fn ,
299
+ output_scale_factor = output_scale_factor ,
300
+ cross_attention_dim = cross_attention_dim ,
301
+ attention_head_dim = attention_head_dim ,
302
+ resnet_groups = resnet_groups ,
303
+ resnet_time_scale_shift = resnet_time_scale_shift ,
304
+ skip_time_act = resnet_skip_time_act ,
305
+ only_cross_attention = mid_block_only_cross_attention ,
306
+ cross_attention_norm = cross_attention_norm ,
307
+ )
308
+ elif mid_block_type == "UNetMidBlock2D" :
309
+ return UNetMidBlock2D (
310
+ in_channels = in_channels ,
311
+ temb_channels = temb_channels ,
312
+ dropout = dropout ,
313
+ num_layers = 0 ,
314
+ resnet_eps = resnet_eps ,
315
+ resnet_act_fn = resnet_act_fn ,
316
+ output_scale_factor = output_scale_factor ,
317
+ resnet_groups = resnet_groups ,
318
+ resnet_time_scale_shift = resnet_time_scale_shift ,
319
+ add_attention = False ,
320
+ )
321
+ elif mid_block_type is None :
322
+ return None
323
+ else :
324
+ raise ValueError (f"unknown mid_block_type : { mid_block_type } " )
325
+
326
+
252
327
def get_up_block (
253
328
up_block_type : str ,
254
329
num_layers : int ,
0 commit comments