Skip to content

Commit 2b298dd

Browse files
authored
FIX: detect ill-posed sampling-strategy as a float (#507)
1 parent 03a8334 commit 2b298dd

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

imblearn/utils/_validation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,23 @@ def _sampling_strategy_float(sampling_strategy, y, sampling_type):
317317
key: int(n_sample_majority * sampling_strategy - value)
318318
for (key, value) in target_stats.items() if key != class_majority
319319
}
320+
if any([n_samples <= 0 for n_samples in sampling_strategy_.values()]):
321+
raise ValueError("The specified ratio required to remove samples "
322+
"from the minority class while trying to "
323+
"generate new samples. Please increase the "
324+
"ratio.")
320325
elif (sampling_type == 'under-sampling'):
321326
n_sample_minority = min(target_stats.values())
322327
class_minority = min(target_stats, key=target_stats.get)
323328
sampling_strategy_ = {
324329
key: int(n_sample_minority / sampling_strategy)
325330
for (key, value) in target_stats.items() if key != class_minority
326331
}
332+
if any([n_samples > target_stats[target]
333+
for target, n_samples in sampling_strategy_.items()]):
334+
raise ValueError("The specified ratio required to generate new "
335+
"sample in the majority class while trying to "
336+
"remove samples. Please increase the ratio.")
327337
else:
328338
raise ValueError("'clean-sampling' methods do let the user "
329339
"specify the sampling ratio.")

imblearn/utils/tests/test_validation.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,18 @@ def test_check_sampling_strategy_warning():
7070
}, multiclass_target, 'clean-sampling')
7171

7272

73-
def test_check_sampling_strategy_float_error():
74-
msg = "'clean-sampling' methods do let the user specify the sampling ratio"
75-
with pytest.raises(ValueError, match=msg):
76-
check_sampling_strategy(0.5, binary_target, 'clean-sampling')
73+
@pytest.mark.parametrize(
74+
"ratio, y, type, err_msg",
75+
[(0.5, binary_target, 'clean-sampling',
76+
"'clean-sampling' methods do let the user specify the sampling ratio"),
77+
(0.1, np.array([0] * 10 + [1] * 20), 'over-sampling',
78+
"remove samples from the minority class while trying to generate new"),
79+
(0.1, np.array([0] * 10 + [1] * 20), 'under-sampling',
80+
"generate new sample in the majority class while trying to remove")]
81+
)
82+
def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
83+
with pytest.raises(ValueError, match=err_msg):
84+
check_sampling_strategy(ratio, y, type)
7785

7886

7987
def test_check_sampling_strategy_error():
@@ -329,9 +337,9 @@ def test_check_ratio(ratio, sampling_type, expected_ratio, target):
329337
def test_sampling_strategy_dict_over_sampling():
330338
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
331339
sampling_strategy = {1: 70, 2: 140, 3: 70}
332-
expected_msg = ("After over-sampling, the number of samples \(140\) in"
333-
" class 2 will be larger than the number of samples in the"
334-
" majority class \(class #2 -> 100\)")
340+
expected_msg = (r"After over-sampling, the number of samples \(140\) in"
341+
r" class 2 will be larger than the number of samples in"
342+
r" the majority class \(class #2 -> 100\)")
335343
with warns(UserWarning, expected_msg):
336344
check_sampling_strategy(sampling_strategy, y, 'over-sampling')
337345

0 commit comments

Comments
 (0)