Skip to content

Commit eed8d95

Browse files
authored
TST: refactor and pytest style (#470)
1 parent 2f0a4b2 commit eed8d95

28 files changed

+313
-543
lines changed

doc/whats_new/v0.0.4.rst

+3
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ Maintenance
104104
- Catch deprecation warning in testing.
105105
:issue:`441` by :user:`Guillaume Lemaitre <glemaitre>`.
106106

107+
- Refactor and impose `pytest` style tests.
108+
:issue:`470` by :user:`Guillaume Lemaitre <glemaitre>`.
109+
107110
Documentation
108111
.............
109112

imblearn/combine/tests/test_smote_enn.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
# Christos Aridas
44
# License: MIT
55

6-
from __future__ import print_function
7-
6+
import pytest
87
import numpy as np
9-
from pytest import raises
108

119
from sklearn.utils.testing import assert_allclose, assert_array_equal
1210

@@ -100,12 +98,12 @@ def test_validate_estimator_default():
10098
assert_array_equal(y_resampled, y_gt)
10199

102100

103-
def test_error_wrong_object():
104-
smote = 'rnd'
105-
enn = 'rnd'
106-
smt = SMOTEENN(smote=smote, random_state=RND_SEED)
107-
with raises(ValueError, match="smote needs to be a SMOTE"):
108-
smt.fit_resample(X, Y)
109-
smt = SMOTEENN(enn=enn, random_state=RND_SEED)
110-
with raises(ValueError, match="enn needs to be an "):
101+
@pytest.mark.parametrize(
102+
"smote_params, err_msg",
103+
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),
104+
({'enn': 'rnd'}, "enn needs to be an ")]
105+
)
106+
def test_error_wrong_object(smote_params, err_msg):
107+
smt = SMOTEENN(**smote_params)
108+
with pytest.raises(ValueError, match=err_msg):
111109
smt.fit_resample(X, Y)

imblearn/combine/tests/test_smote_tomek.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
# Christos Aridas
44
# License: MIT
55

6-
from __future__ import print_function
7-
6+
import pytest
87
import numpy as np
9-
from pytest import raises
108

119
from sklearn.utils.testing import assert_allclose, assert_array_equal
1210

@@ -106,12 +104,12 @@ def test_validate_estimator_default():
106104
assert_array_equal(y_resampled, y_gt)
107105

108106

109-
def test_error_wrong_object():
110-
smote = 'rnd'
111-
tomek = 'rnd'
112-
smt = SMOTETomek(smote=smote, random_state=RND_SEED)
113-
with raises(ValueError, match="smote needs to be a SMOTE"):
114-
smt.fit_resample(X, Y)
115-
smt = SMOTETomek(tomek=tomek, random_state=RND_SEED)
116-
with raises(ValueError, match="tomek needs to be a TomekLinks"):
107+
@pytest.mark.parametrize(
108+
"smote_params, err_msg",
109+
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),
110+
({'tomek': 'rnd'}, "tomek needs to be a TomekLinks")]
111+
)
112+
def test_error_wrong_object(smote_params, err_msg):
113+
smt = SMOTETomek(**smote_params)
114+
with pytest.raises(ValueError, match=err_msg):
117115
smt.fit_resample(X, Y)

imblearn/datasets/tests/test_imbalance.py

+44-42
Original file line numberDiff line numberDiff line change
@@ -3,64 +3,66 @@
33
# Christos Aridas
44
# License: MIT
55

6-
from __future__ import print_function
7-
86
from collections import Counter
97

108
import pytest
119
import numpy as np
1210

13-
from pytest import raises
14-
1511
from sklearn.datasets import load_iris
1612

1713
from imblearn.datasets import make_imbalance
1814

19-
data = load_iris()
20-
X, Y = data.data, data.target
2115

16+
@pytest.fixture
17+
def iris():
18+
return load_iris(return_X_y=True)
2219

23-
def test_make_imbalanced_backcompat():
20+
21+
def test_make_imbalanced_backcompat(iris):
2422
# check an error is raised when we don't pass sampling_strategy and ratio
25-
with raises(TypeError, match="missing 1 required positional argument"):
26-
make_imbalance(X, Y)
23+
with pytest.raises(TypeError, match="missing 1 required positional argument"):
24+
make_imbalance(*iris)
2725

2826

29-
def test_make_imbalance_error():
27+
@pytest.mark.parametrize(
28+
"sampling_strategy, err_msg",
29+
[({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"),
30+
({0: 10, 1: 70}, "should be less or equal to the original"),
31+
('random-string', "has to be a dictionary or a function")]
32+
)
33+
def test_make_imbalance_error(iris, sampling_strategy, err_msg):
3034
# we are reusing part of utils.check_sampling_strategy, however this is not
3135
# covered in the common tests so we will repeat it here
32-
sampling_strategy = {0: -100, 1: 50, 2: 50}
33-
with raises(ValueError, match="in a class cannot be negative"):
34-
make_imbalance(X, Y, sampling_strategy)
35-
sampling_strategy = {0: 10, 1: 70}
36-
with raises(ValueError, match="should be less or equal to the original"):
37-
make_imbalance(X, Y, sampling_strategy)
38-
y_ = np.zeros((X.shape[0], ))
39-
sampling_strategy = {0: 10}
40-
with raises(ValueError, match="needs to have more than 1 class."):
41-
make_imbalance(X, y_, sampling_strategy)
42-
sampling_strategy = 'random-string'
43-
with raises(ValueError, match="has to be a dictionary or a function"):
44-
make_imbalance(X, Y, sampling_strategy)
45-
46-
47-
def test_make_imbalance_dict():
48-
sampling_strategy = {0: 10, 1: 20, 2: 30}
49-
X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy)
50-
assert Counter(y_) == sampling_strategy
51-
52-
sampling_strategy = {0: 10, 1: 20}
53-
X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy)
54-
assert Counter(y_) == {0: 10, 1: 20, 2: 50}
36+
X, y = iris
37+
with pytest.raises(ValueError, match=err_msg):
38+
make_imbalance(X, y, sampling_strategy)
39+
40+
41+
def test_make_imbalance_error_single_class(iris):
42+
X, y = iris
43+
y = np.zeros_like(y)
44+
with pytest.raises(ValueError, match="needs to have more than 1 class."):
45+
make_imbalance(X, y, {0: 10})
46+
47+
48+
@pytest.mark.parametrize(
49+
"sampling_strategy, expected_counts",
50+
[({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}),
51+
({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})]
52+
)
53+
def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
54+
X, y = iris
55+
_, y_ = make_imbalance(X, y, sampling_strategy=sampling_strategy)
56+
assert Counter(y_) == expected_counts
5557

5658

5759
@pytest.mark.filterwarnings("ignore:'ratio' has been deprecated in 0.4")
58-
def test_make_imbalance_ratio():
59-
# check that using 'ratio' is working
60-
sampling_strategy = {0: 10, 1: 20, 2: 30}
61-
X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy)
62-
assert Counter(y_) == sampling_strategy
63-
64-
sampling_strategy = {0: 10, 1: 20}
65-
X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy)
66-
assert Counter(y_) == {0: 10, 1: 20, 2: 50}
60+
@pytest.mark.parametrize(
61+
"sampling_strategy, expected_counts",
62+
[({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}),
63+
({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})]
64+
)
65+
def test_make_imbalance_dict_ratio(iris, sampling_strategy, expected_counts):
66+
X, y = iris
67+
_, y_ = make_imbalance(X, y, ratio=sampling_strategy)
68+
assert Counter(y_) == expected_counts

imblearn/datasets/tests/test_zenodo.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
# Christos Aridas
77
# License: MIT
88

9-
from imblearn.datasets import fetch_datasets
10-
from sklearn.utils.testing import SkipTest, assert_allclose
9+
import pytest
1110

12-
from pytest import raises
11+
from imblearn.datasets import fetch_datasets
12+
from sklearn.utils.testing import SkipTest
1313

1414
DATASET_SHAPE = {
1515
'ecoli': (336, 7),
@@ -79,19 +79,20 @@ def test_fetch_filter():
7979
assert DATASET_SHAPE['ecoli'] == X1.shape
8080
assert X1.shape == X2.shape
8181

82-
assert_allclose(X1.sum(), X2.sum())
82+
assert X1.sum() == pytest.approx(X2.sum())
8383

8484
y1, y2 = datasets1['ecoli'].target, datasets2['ecoli'].target
8585
assert (X1.shape[0], ) == y1.shape
8686
assert (X1.shape[0], ) == y2.shape
8787

8888

89-
def test_fetch_error():
90-
with raises(ValueError, match='is not a dataset available.'):
91-
fetch_datasets(filter_data=tuple(['rnd']))
92-
with raises(ValueError, match='dataset with the ID='):
93-
fetch_datasets(filter_data=tuple([-1]))
94-
with raises(ValueError, match='dataset with the ID='):
95-
fetch_datasets(filter_data=tuple([100]))
96-
with raises(ValueError, match='value in the tuple'):
97-
fetch_datasets(filter_data=tuple([1.00]))
89+
@pytest.mark.parametrize(
90+
"filter_data, err_msg",
91+
[(('rnf',), "is not a dataset available"),
92+
((-1,), "dataset with the ID="),
93+
((100,), "dataset with the ID="),
94+
((1.00,), "value in the tuple")]
95+
)
96+
def test_fetch_error(filter_data, err_msg):
97+
with pytest.raises(ValueError, match=err_msg):
98+
fetch_datasets(filter_data=filter_data)

imblearn/ensemble/tests/test_balance_cascade.py

-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# Christos Aridas
44
# License: MIT
55

6-
from __future__ import print_function
7-
86
import numpy as np
97

108
from pytest import raises

imblearn/keras/tests/test_generator.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818
from imblearn.keras import BalancedBatchGenerator
1919
from imblearn.keras import balanced_batch_generator
2020

21-
iris = load_iris()
22-
X, y = make_imbalance(iris.data, iris.target, {0: 30, 1: 50, 2: 40})
23-
y = to_categorical(y, 3)
21+
22+
@pytest.fixture
23+
def data():
24+
iris = load_iris()
25+
X, y = make_imbalance(iris.data, iris.target, {0: 30, 1: 50, 2: 40})
26+
y = to_categorical(y, 3)
27+
return X, y
2428

2529

2630
def _build_keras_model(n_classes, n_features):
@@ -31,19 +35,20 @@ def _build_keras_model(n_classes, n_features):
3135
return model
3236

3337

34-
def test_balanced_batch_generator_class_no_return_indices():
38+
def test_balanced_batch_generator_class_no_return_indices(data):
3539
with pytest.raises(ValueError, match='needs to return the indices'):
36-
BalancedBatchGenerator(X, y, sampler=ClusterCentroids(), batch_size=10)
40+
BalancedBatchGenerator(*data, sampler=ClusterCentroids(), batch_size=10)
3741

3842

3943
@pytest.mark.parametrize(
4044
"sampler, sample_weight",
4145
[(None, None),
4246
(RandomOverSampler(), None),
4347
(NearMiss(), None),
44-
(None, np.random.uniform(size=(y.shape[0])))]
48+
(None, np.random.uniform(size=120))]
4549
)
46-
def test_balanced_batch_generator_class(sampler, sample_weight):
50+
def test_balanced_batch_generator_class(data, sampler, sample_weight):
51+
X, y = data
4752
model = _build_keras_model(y.shape[1], X.shape[1])
4853
training_generator = BalancedBatchGenerator(X, y,
4954
sample_weight=sample_weight,
@@ -55,33 +60,35 @@ def test_balanced_batch_generator_class(sampler, sample_weight):
5560

5661

5762
@pytest.mark.parametrize("keep_sparse", [True, False])
58-
def test_balanced_batch_generator_class_sparse(keep_sparse):
63+
def test_balanced_batch_generator_class_sparse(data, keep_sparse):
64+
X, y = data
5965
training_generator = BalancedBatchGenerator(sparse.csr_matrix(X), y,
6066
batch_size=10,
6167
keep_sparse=keep_sparse,
6268
random_state=42)
6369
for idx in range(len(training_generator)):
64-
X_batch, y_batch = training_generator.__getitem__(idx)
70+
X_batch, _ = training_generator.__getitem__(idx)
6571
if keep_sparse:
6672
assert sparse.issparse(X_batch)
6773
else:
6874
assert not sparse.issparse(X_batch)
6975

7076

71-
def test_balanced_batch_generator_function_no_return_indices():
77+
def test_balanced_batch_generator_function_no_return_indices(data):
7278
with pytest.raises(ValueError, match='needs to return the indices'):
7379
balanced_batch_generator(
74-
X, y, sampler=ClusterCentroids(), batch_size=10, random_state=42)
80+
*data, sampler=ClusterCentroids(), batch_size=10, random_state=42)
7581

7682

7783
@pytest.mark.parametrize(
7884
"sampler, sample_weight",
7985
[(None, None),
8086
(RandomOverSampler(), None),
8187
(NearMiss(), None),
82-
(None, np.random.uniform(size=(y.shape[0])))]
88+
(None, np.random.uniform(size=120))]
8389
)
84-
def test_balanced_batch_generator_function(sampler, sample_weight):
90+
def test_balanced_batch_generator_function(data, sampler, sample_weight):
91+
X, y = data
8592
model = _build_keras_model(y.shape[1], X.shape[1])
8693
training_generator, steps_per_epoch = balanced_batch_generator(
8794
X, y, sample_weight=sample_weight, sampler=sampler, batch_size=10,
@@ -92,12 +99,13 @@ def test_balanced_batch_generator_function(sampler, sample_weight):
9299

93100

94101
@pytest.mark.parametrize("keep_sparse", [True, False])
95-
def test_balanced_batch_generator_function_sparse(keep_sparse):
102+
def test_balanced_batch_generator_function_sparse(data, keep_sparse):
103+
X, y = data
96104
training_generator, steps_per_epoch = balanced_batch_generator(
97105
sparse.csr_matrix(X), y, keep_sparse=keep_sparse, batch_size=10,
98106
random_state=42)
99-
for idx in range(steps_per_epoch):
100-
X_batch, y_batch = next(training_generator)
107+
for _ in range(steps_per_epoch):
108+
X_batch, _ = next(training_generator)
101109
if keep_sparse:
102110
assert sparse.issparse(X_batch)
103111
else:

0 commit comments

Comments
 (0)