@@ -52,27 +52,21 @@ def test_ema_training(self):
52
52
def test_training (self ):
53
53
pass
54
54
55
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
56
55
def test_determinism (self ):
57
56
super ().test_determinism ()
58
57
59
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
60
58
def test_outputs_equivalence (self ):
61
59
super ().test_outputs_equivalence ()
62
60
63
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
64
61
def test_from_save_pretrained (self ):
65
62
super ().test_from_save_pretrained ()
66
63
67
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
68
64
def test_from_save_pretrained_variant (self ):
69
65
super ().test_from_save_pretrained_variant ()
70
66
71
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
72
67
def test_model_from_pretrained (self ):
73
68
super ().test_model_from_pretrained ()
74
69
75
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
76
70
def test_output (self ):
77
71
super ().test_output ()
78
72
@@ -89,12 +83,11 @@ def prepare_init_args_and_inputs_for_common(self):
89
83
"mid_block_type" : "MidResTemporalBlock1D" ,
90
84
"down_block_types" : ("DownResnetBlock1D" , "DownResnetBlock1D" , "DownResnetBlock1D" , "DownResnetBlock1D" ),
91
85
"up_block_types" : ("UpResnetBlock1D" , "UpResnetBlock1D" , "UpResnetBlock1D" ),
92
- "act_fn" : "mish " ,
86
+ "act_fn" : "swish " ,
93
87
}
94
88
inputs_dict = self .dummy_input
95
89
return init_dict , inputs_dict
96
90
97
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
98
91
def test_from_pretrained_hub (self ):
99
92
model , loading_info = UNet1DModel .from_pretrained (
100
93
"bglick13/hopper-medium-v2-value-function-hor32" , output_loading_info = True , subfolder = "unet"
@@ -107,7 +100,6 @@ def test_from_pretrained_hub(self):
107
100
108
101
assert image is not None , "Make sure output is not None"
109
102
110
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
111
103
def test_output_pretrained (self ):
112
104
model = UNet1DModel .from_pretrained ("bglick13/hopper-medium-v2-value-function-hor32" , subfolder = "unet" )
113
105
torch .manual_seed (0 )
@@ -177,27 +169,21 @@ def input_shape(self):
177
169
def output_shape (self ):
178
170
return (4 , 14 , 1 )
179
171
180
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
181
172
def test_determinism (self ):
182
173
super ().test_determinism ()
183
174
184
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
185
175
def test_outputs_equivalence (self ):
186
176
super ().test_outputs_equivalence ()
187
177
188
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
189
178
def test_from_save_pretrained (self ):
190
179
super ().test_from_save_pretrained ()
191
180
192
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
193
181
def test_from_save_pretrained_variant (self ):
194
182
super ().test_from_save_pretrained_variant ()
195
183
196
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
197
184
def test_model_from_pretrained (self ):
198
185
super ().test_model_from_pretrained ()
199
186
200
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
201
187
def test_output (self ):
202
188
# UNetRL is a value-function is different output shape
203
189
init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
@@ -241,7 +227,6 @@ def prepare_init_args_and_inputs_for_common(self):
241
227
inputs_dict = self .dummy_input
242
228
return init_dict , inputs_dict
243
229
244
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
245
230
def test_from_pretrained_hub (self ):
246
231
value_function , vf_loading_info = UNet1DModel .from_pretrained (
247
232
"bglick13/hopper-medium-v2-value-function-hor32" , output_loading_info = True , subfolder = "value_function"
@@ -254,7 +239,6 @@ def test_from_pretrained_hub(self):
254
239
255
240
assert image is not None , "Make sure output is not None"
256
241
257
- @unittest .skipIf (torch_device == "mps" , "mish op not supported in MPS" )
258
242
def test_output_pretrained (self ):
259
243
value_function , vf_loading_info = UNet1DModel .from_pretrained (
260
244
"bglick13/hopper-medium-v2-value-function-hor32" , output_loading_info = True , subfolder = "value_function"
0 commit comments