Skip to content

Commit 77bfb56

Browse files
init-22sayakpaul
andauthored
adding required parameters while calling the get_up_block and get_down_block (#3210)
* removed unnecessary parameters from get_up_block and get_down_block functions * adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 70ef774 commit 77bfb56

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def get_down_block(
4242
only_cross_attention=False,
4343
upcast_attention=False,
4444
resnet_time_scale_shift="default",
45+
resnet_skip_time_act=False,
46+
resnet_out_scale_factor=1.0,
47+
cross_attention_norm=None,
4548
):
4649
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
4750
if down_block_type == "DownBlockFlat":
@@ -98,6 +101,9 @@ def get_up_block(
98101
only_cross_attention=False,
99102
upcast_attention=False,
100103
resnet_time_scale_shift="default",
104+
resnet_skip_time_act=False,
105+
resnet_out_scale_factor=1.0,
106+
cross_attention_norm=None,
101107
):
102108
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
103109
if up_block_type == "UpBlockFlat":

0 commit comments

Comments
 (0)