Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit d79cb54

Browse files
authored
Fix compatibility with upcoming scikit-learn release and remove unused imports (#164)
1 parent 55a8003 commit d79cb54

File tree

9 files changed

+8
-16
lines changed

9 files changed

+8
-16
lines changed

lightning/impl/adagrad.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55

66
from sklearn.utils import check_random_state
7-
from sklearn.preprocessing import LabelBinarizer
87
from six.moves import xrange
98

109
from .base import BaseClassifier, BaseRegressor

lightning/impl/datasets/samples_generator.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from itertools import product
2+
13
import numpy as np
24
import scipy.sparse as sp
35
from six.moves import xrange
@@ -195,9 +197,6 @@ def make_classification(n_samples=100, n_features=20, n_informative=2,
195197
.. [1] I. Guyon, "Design of experiments for the NIPS 2003 variable
196198
selection benchmark", 2003.
197199
"""
198-
from itertools import product
199-
from sklearn.utils import shuffle as util_shuffle
200-
201200
generator = check_random_state(random_state)
202201

203202
# Count features, clusters and samples
@@ -308,7 +307,7 @@ def make_classification(n_samples=100, n_features=20, n_informative=2,
308307

309308
# Randomly permute samples and features
310309
if shuffle:
311-
X, y = util_shuffle(X, y, random_state=generator)
310+
X, y = shuffle_func(X, y, random_state=generator)
312311

313312
indices = np.arange(n_features)
314313
generator.shuffle(indices)

lightning/impl/dual_cd.py

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414

15-
from sklearn.preprocessing import LabelBinarizer
1615
from sklearn.preprocessing import add_dummy_feature
1716
from six.moves import xrange
1817

lightning/impl/sdca.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55

66
from sklearn.utils import check_random_state
7-
from sklearn.preprocessing import LabelBinarizer
87
from six.moves import xrange
98

109
from .base import BaseClassifier, BaseRegressor

lightning/impl/svrg.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55

6-
from sklearn.preprocessing import LabelBinarizer
76
from six.moves import xrange
87

98
from .base import BaseClassifier, BaseRegressor

lightning/impl/tests/test_fista.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
n_classes=3, random_state=0)
2121
bin_csr = sp.csr_matrix(bin_dense)
2222
mult_csr = sp.csr_matrix(mult_dense)
23-
digit = load_digits(2)
23+
digit = load_digits(n_class=2)
2424

2525

2626
def test_fista_multiclass_l1l2():

lightning/impl/tests/test_primal_cd.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import scipy.sparse as sp
33

44
from sklearn.datasets import load_digits
5-
from sklearn.metrics.pairwise import pairwise_kernels
65
from sklearn.preprocessing import LabelBinarizer
76
from six.moves import xrange
87

@@ -20,7 +19,7 @@
2019
n_classes=3, random_state=0)
2120
mult_csc = sp.csc_matrix(mult_dense)
2221

23-
digit = load_digits(2)
22+
digit = load_digits(n_class=2)
2423

2524

2625
def test_fit_linear_binary_l1r():

lightning/impl/tests/test_sag.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from scipy import sparse
33

44
from sklearn.datasets import load_iris, make_classification
5-
from sklearn.preprocessing import LabelBinarizer
65

76
from lightning.impl.base import BaseClassifier
87
from lightning.impl.dataset_fast import get_dataset
@@ -224,7 +223,7 @@ def test_sag_dataset():
224223

225224

226225
def test_sag_score():
227-
X, y = make_classification(1000, random_state=0)
226+
X, y = make_classification(n_samples=1000, random_state=0)
228227

229228
pysag = PySAGClassifier(eta=1e-3, alpha=0.0, beta=0.0, max_iter=10,
230229
random_state=0)
@@ -238,7 +237,7 @@ def test_sag_score():
238237

239238
def test_sag_proba():
240239
n_samples = 10
241-
X, y = make_classification(n_samples, random_state=0)
240+
X, y = make_classification(n_samples=n_samples, random_state=0)
242241
sag = SAGClassifier(eta=1e-3, alpha=0.0, beta=0.0, max_iter=10,
243242
loss='log', random_state=0)
244243
sag.fit(X, y)
@@ -274,7 +273,7 @@ def test_l2_regularized_sag():
274273

275274

276275
def test_saga_score():
277-
X, y = make_classification(1000, random_state=0)
276+
X, y = make_classification(n_samples=1000, random_state=0)
278277

279278
pysaga = PySAGAClassifier(eta=1e-3, alpha=0.0, beta=0.0, max_iter=1,
280279
penalty=None, random_state=0)

lightning/impl/tests/test_svrg.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from sklearn.datasets import load_iris
4-
from sklearn.base import BaseEstimator, ClassifierMixin
54

65
from lightning.classification import SVRGClassifier
76
from lightning.regression import SVRGRegressor

0 commit comments

Comments
 (0)