Skip to content

Commit ad8467f

Browse files
SauravMaheshkar authored and patrickvonplaten committed
feat: add act_fn param to OutValueFunctionBlock (huggingface#3994)
* feat: add act_fn param to OutValueFunctionBlock

* feat: update unet1d tests to not use mish

* feat: add `mish` as the default activation function

Co-authored-by: Patrick von Platen <[email protected]>

* feat: drop mish tests from unet1d

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent e7644e2 commit ad8467f

File tree

2 files changed

+4
-20
lines changed

2 files changed

+4
-20
lines changed

src/diffusers/models/unet_1d_blocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,12 @@ def forward(self, hidden_states, temb=None):
235235

236236

237237
class OutValueFunctionBlock(nn.Module):
238-
def __init__(self, fc_dim, embed_dim):
238+
def __init__(self, fc_dim, embed_dim, act_fn="mish"):
239239
super().__init__()
240240
self.final_block = nn.ModuleList(
241241
[
242242
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
243-
nn.Mish(),
243+
get_activation(act_fn),
244244
nn.Linear(fc_dim // 2, 1),
245245
]
246246
)
@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac
652652
if out_block_type == "OutConv1DBlock":
653653
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
654654
elif out_block_type == "ValueFunction":
655-
return OutValueFunctionBlock(fc_dim, embed_dim)
655+
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
656656
return None

tests/models/test_models_unet_1d.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,21 @@ def test_ema_training(self):
5252
def test_training(self):
5353
pass
5454

55-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
5655
def test_determinism(self):
5756
super().test_determinism()
5857

59-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
6058
def test_outputs_equivalence(self):
6159
super().test_outputs_equivalence()
6260

63-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
6461
def test_from_save_pretrained(self):
6562
super().test_from_save_pretrained()
6663

67-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
6864
def test_from_save_pretrained_variant(self):
6965
super().test_from_save_pretrained_variant()
7066

71-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
7267
def test_model_from_pretrained(self):
7368
super().test_model_from_pretrained()
7469

75-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
7670
def test_output(self):
7771
super().test_output()
7872

@@ -89,12 +83,11 @@ def prepare_init_args_and_inputs_for_common(self):
8983
"mid_block_type": "MidResTemporalBlock1D",
9084
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
9185
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
92-
"act_fn": "mish",
86+
"act_fn": "swish",
9387
}
9488
inputs_dict = self.dummy_input
9589
return init_dict, inputs_dict
9690

97-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
9891
def test_from_pretrained_hub(self):
9992
model, loading_info = UNet1DModel.from_pretrained(
10093
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
@@ -107,7 +100,6 @@ def test_from_pretrained_hub(self):
107100

108101
assert image is not None, "Make sure output is not None"
109102

110-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
111103
def test_output_pretrained(self):
112104
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
113105
torch.manual_seed(0)
@@ -177,27 +169,21 @@ def input_shape(self):
177169
def output_shape(self):
178170
return (4, 14, 1)
179171

180-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
181172
def test_determinism(self):
182173
super().test_determinism()
183174

184-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
185175
def test_outputs_equivalence(self):
186176
super().test_outputs_equivalence()
187177

188-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
189178
def test_from_save_pretrained(self):
190179
super().test_from_save_pretrained()
191180

192-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
193181
def test_from_save_pretrained_variant(self):
194182
super().test_from_save_pretrained_variant()
195183

196-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
197184
def test_model_from_pretrained(self):
198185
super().test_model_from_pretrained()
199186

200-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
201187
def test_output(self):
202188
# UNetRL is a value-function is different output shape
203189
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -241,7 +227,6 @@ def prepare_init_args_and_inputs_for_common(self):
241227
inputs_dict = self.dummy_input
242228
return init_dict, inputs_dict
243229

244-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
245230
def test_from_pretrained_hub(self):
246231
value_function, vf_loading_info = UNet1DModel.from_pretrained(
247232
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
@@ -254,7 +239,6 @@ def test_from_pretrained_hub(self):
254239

255240
assert image is not None, "Make sure output is not None"
256241

257-
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
258242
def test_output_pretrained(self):
259243
value_function, vf_loading_info = UNet1DModel.from_pretrained(
260244
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"

0 commit comments

Comments (0)