|
3 | 3 | # Christos Aridas
|
4 | 4 | # License: MIT
|
5 | 5 |
|
6 |
| -from __future__ import print_function |
7 |
| - |
8 | 6 | from collections import Counter
|
9 | 7 |
|
10 | 8 | import pytest
|
11 | 9 | import numpy as np
|
12 | 10 |
|
13 |
| -from pytest import raises |
14 |
| - |
15 | 11 | from sklearn.datasets import load_iris
|
16 | 12 |
|
17 | 13 | from imblearn.datasets import make_imbalance
|
18 | 14 |
|
19 |
| -data = load_iris() |
20 |
| -X, Y = data.data, data.target |
21 | 15 |
|
| 16 | +@pytest.fixture |
| 17 | +def iris(): |
| 18 | + return load_iris(return_X_y=True) |
22 | 19 |
|
23 |
| -def test_make_imbalanced_backcompat(): |
| 20 | + |
| 21 | +def test_make_imbalanced_backcompat(iris): |
24 | 22 | # check an error is raised with we don't pass sampling_strategy and ratio
|
25 |
| - with raises(TypeError, match="missing 1 required positional argument"): |
26 |
| - make_imbalance(X, Y) |
| 23 | + with pytest.raises(TypeError, match="missing 1 required positional argument"): |
| 24 | + make_imbalance(*iris) |
27 | 25 |
|
28 | 26 |
|
29 |
| -def test_make_imbalance_error(): |
| 27 | +@pytest.mark.parametrize( |
| 28 | + "sampling_strategy, err_msg", |
| 29 | + [({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"), |
| 30 | + ({0: 10, 1: 70}, "should be less or equal to the original"), |
| 31 | + ('random-string', "has to be a dictionary or a function")] |
| 32 | +) |
| 33 | +def test_make_imbalance_error(iris, sampling_strategy, err_msg): |
30 | 34 | # we are reusing part of utils.check_sampling_strategy, however this is not
|
31 | 35 | # cover in the common tests so we will repeat it here
|
32 |
| - sampling_strategy = {0: -100, 1: 50, 2: 50} |
33 |
| - with raises(ValueError, match="in a class cannot be negative"): |
34 |
| - make_imbalance(X, Y, sampling_strategy) |
35 |
| - sampling_strategy = {0: 10, 1: 70} |
36 |
| - with raises(ValueError, match="should be less or equal to the original"): |
37 |
| - make_imbalance(X, Y, sampling_strategy) |
38 |
| - y_ = np.zeros((X.shape[0], )) |
39 |
| - sampling_strategy = {0: 10} |
40 |
| - with raises(ValueError, match="needs to have more than 1 class."): |
41 |
| - make_imbalance(X, y_, sampling_strategy) |
42 |
| - sampling_strategy = 'random-string' |
43 |
| - with raises(ValueError, match="has to be a dictionary or a function"): |
44 |
| - make_imbalance(X, Y, sampling_strategy) |
45 |
| - |
46 |
| - |
47 |
| -def test_make_imbalance_dict(): |
48 |
| - sampling_strategy = {0: 10, 1: 20, 2: 30} |
49 |
| - X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy) |
50 |
| - assert Counter(y_) == sampling_strategy |
51 |
| - |
52 |
| - sampling_strategy = {0: 10, 1: 20} |
53 |
| - X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy) |
54 |
| - assert Counter(y_) == {0: 10, 1: 20, 2: 50} |
| 36 | + X, y = iris |
| 37 | + with pytest.raises(ValueError, match=err_msg): |
| 38 | + make_imbalance(X, y, sampling_strategy) |
| 39 | + |
| 40 | + |
| 41 | +def test_make_imbalance_error_single_class(iris): |
| 42 | + X, y = iris |
| 43 | + y = np.zeros_like(y) |
| 44 | + with pytest.raises(ValueError, match="needs to have more than 1 class."): |
| 45 | + make_imbalance(X, y, {0: 10}) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.parametrize( |
| 49 | + "sampling_strategy, expected_counts", |
| 50 | + [({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}), |
| 51 | + ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})] |
| 52 | +) |
| 53 | +def test_make_imbalance_dict(iris, sampling_strategy, expected_counts): |
| 54 | + X, y = iris |
| 55 | + _, y_ = make_imbalance(X, y, sampling_strategy=sampling_strategy) |
| 56 | + assert Counter(y_) == expected_counts |
55 | 57 |
|
56 | 58 |
|
57 | 59 | @pytest.mark.filterwarnings("ignore:'ratio' has been deprecated in 0.4")
|
58 |
| -def test_make_imbalance_ratio(): |
59 |
| - # check that using 'ratio' is working |
60 |
| - sampling_strategy = {0: 10, 1: 20, 2: 30} |
61 |
| - X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy) |
62 |
| - assert Counter(y_) == sampling_strategy |
63 |
| - |
64 |
| - sampling_strategy = {0: 10, 1: 20} |
65 |
| - X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy) |
66 |
| - assert Counter(y_) == {0: 10, 1: 20, 2: 50} |
| 60 | +@pytest.mark.parametrize( |
| 61 | + "sampling_strategy, expected_counts", |
| 62 | + [({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}), |
| 63 | + ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})] |
| 64 | +) |
| 65 | +def test_make_imbalance_dict_ratio(iris, sampling_strategy, expected_counts): |
| 66 | + X, y = iris |
| 67 | + _, y_ = make_imbalance(X, y, ratio=sampling_strategy) |
| 68 | + assert Counter(y_) == expected_counts |
0 commit comments