@@ -70,10 +70,18 @@ def test_check_sampling_strategy_warning():
70
70
}, multiclass_target , 'clean-sampling' )
71
71
72
72
73
- def test_check_sampling_strategy_float_error ():
74
- msg = "'clean-sampling' methods do let the user specify the sampling ratio"
75
- with pytest .raises (ValueError , match = msg ):
76
- check_sampling_strategy (0.5 , binary_target , 'clean-sampling' )
73
+ @pytest .mark .parametrize (
74
+ "ratio, y, type, err_msg" ,
75
+ [(0.5 , binary_target , 'clean-sampling' ,
76
+ "'clean-sampling' methods do let the user specify the sampling ratio" ),
77
+ (0.1 , np .array ([0 ] * 10 + [1 ] * 20 ), 'over-sampling' ,
78
+ "remove samples from the minority class while trying to generate new" ),
79
+ (0.1 , np .array ([0 ] * 10 + [1 ] * 20 ), 'under-sampling' ,
80
+ "generate new sample in the majority class while trying to remove" )]
81
+ )
82
+ def test_check_sampling_strategy_float_error (ratio , y , type , err_msg ):
83
+ with pytest .raises (ValueError , match = err_msg ):
84
+ check_sampling_strategy (ratio , y , type )
77
85
78
86
79
87
def test_check_sampling_strategy_error ():
@@ -329,9 +337,9 @@ def test_check_ratio(ratio, sampling_type, expected_ratio, target):
329
337
def test_sampling_strategy_dict_over_sampling ():
330
338
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
331
339
sampling_strategy = {1 : 70 , 2 : 140 , 3 : 70 }
332
- expected_msg = ("After over-sampling, the number of samples \(140\) in"
333
- " class 2 will be larger than the number of samples in the "
334
- " majority class \(class #2 -> 100\)" )
340
+ expected_msg = (r "After over-sampling, the number of samples \(140\) in"
341
+ r " class 2 will be larger than the number of samples in"
342
+ r" the majority class \(class #2 -> 100\)" )
335
343
with warns (UserWarning , expected_msg ):
336
344
check_sampling_strategy (sampling_strategy , y , 'over-sampling' )
337
345
0 commit comments