Skip to content

Commit b6f6794

Browse files
authored
Merge pull request #169 from scikit-learn-contrib/SandroCasagrande-fit_transform
Added fit_transform method to DataFrameMapper
2 parents 7fdc39a + 2aab227 commit b6f6794

File tree

3 files changed

+103
-19
lines changed

3 files changed

+103
-19
lines changed

README.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ Example: imputing with a fixed value:
408408

409409
Changelog
410410
---------
411+
Unreleased
412+
**********
413+
* Change behaviour of DataFrameMapper's fit_transform method to invoke each underlying transformers'
414+
native fit_transform if implemented. (#150)
411415

412416
1.7.0 (2018-08-15)
413417
******************
@@ -417,7 +421,6 @@ Changelog
417421
with values other than the mode (#144).
418422
* Preserve input data types when no transform is supplied (#138).
419423

420-
421424
1.6.0 (2017-10-28)
422425
******************
423426
* Add column name to exception during fit/transform (#110).

sklearn_pandas/dataframe_mapper.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,16 @@ def __init__(self, features, default=False, sparse=False, df_out=False,
114114
if (df_out and (sparse or default)):
115115
raise ValueError("Can not use df_out with sparse or default")
116116

117+
def _build(self):
118+
"""
119+
Build attributes built_features and built_default.
120+
"""
121+
if isinstance(self.features, list):
122+
self.built_features = [_build_feature(*f) for f in self.features]
123+
else:
124+
self.built_features = self.features
125+
self.built_default = _build_transformer(self.default)
126+
117127
@property
118128
def _selected_columns(self):
119129
"""
@@ -198,12 +208,7 @@ def fit(self, X, y=None):
198208
y the target vector relative to X, optional
199209
200210
"""
201-
if isinstance(self.features, list):
202-
self.built_features = [_build_feature(*f) for f in self.features]
203-
else:
204-
self.built_features = self.features
205-
206-
self.built_default = _build_transformer(self.default)
211+
self._build()
207212

208213
for columns, transformers, options in self.built_features:
209214
input_df = options.get('input_df', self.input_df)
@@ -273,23 +278,32 @@ def get_dtype(self, ex):
273278
else:
274279
raise TypeError(type(ex))
275280

276-
def transform(self, X):
281+
def _transform(self, X, y=None, do_fit=False):
277282
"""
278-
Transform the given data. Assumes that fit has already been called.
279-
280-
X the data to transform
283+
Transform the given data with possibility to fit in advance.
284+
Avoids code duplication for implementation of transform and
285+
fit_transform.
281286
"""
287+
if do_fit:
288+
self._build()
289+
282290
extracted = []
283291
self.transformed_names_ = []
284292
for columns, transformers, options in self.built_features:
285293
input_df = options.get('input_df', self.input_df)
294+
286295
# columns could be a string or list of
287296
# strings; we don't care because pandas
288297
# will handle either.
289298
Xt = self._get_col_subset(X, columns, input_df)
290299
if transformers is not None:
291300
with add_column_names_to_exception(columns):
292-
Xt = transformers.transform(Xt)
301+
if do_fit and hasattr(transformers, 'fit_transform'):
302+
Xt = _call_fit(transformers.fit_transform, Xt, y)
303+
else:
304+
if do_fit:
305+
_call_fit(transformers.fit, Xt, y)
306+
Xt = transformers.transform(Xt)
293307
extracted.append(_handle_feature(Xt))
294308

295309
alias = options.get('alias')
@@ -302,7 +316,12 @@ def transform(self, X):
302316
Xt = self._get_col_subset(X, unsel_cols, self.input_df)
303317
if self.built_default is not None:
304318
with add_column_names_to_exception(unsel_cols):
305-
Xt = self.built_default.transform(Xt)
319+
if do_fit and hasattr(self.built_default, 'fit_transform'):
320+
Xt = _call_fit(self.built_default.fit_transform, Xt, y)
321+
else:
322+
if do_fit:
323+
_call_fit(self.built_default.fit, Xt, y)
324+
Xt = self.built_default.transform(Xt)
306325
self.transformed_names_ += self.get_names(
307326
unsel_cols, self.built_default, Xt)
308327
else:
@@ -348,3 +367,22 @@ def transform(self, X):
348367
return df_out
349368
else:
350369
return stacked
370+
371+
def transform(self, X):
372+
"""
373+
Transform the given data. Assumes that fit has already been called.
374+
375+
X the data to transform
376+
"""
377+
return self._transform(X)
378+
379+
def fit_transform(self, X, y=None):
380+
"""
381+
Fit a transformation from the pipeline and directly apply
382+
it to the given data.
383+
384+
X the data to fit
385+
386+
y the target vector relative to X, optional
387+
"""
388+
return self._transform(X, y, True)

tests/test_dataframe_mapper.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,50 @@ def test_pca(complex_dataframe):
402402
assert cols[1] == 'feat1_feat2_1'
403403

404404

405+
def test_fit_transform(simple_dataframe):
406+
"""
407+
Check that custom fit_transform methods of the transformers are invoked.
408+
"""
409+
df = simple_dataframe
410+
mock_transformer = Mock()
411+
# return something of measurable length but does nothing
412+
mock_transformer.fit_transform.return_value = np.array([1, 2, 3])
413+
mapper = DataFrameMapper([("a", mock_transformer)])
414+
mapper.fit_transform(df)
415+
assert mock_transformer.fit_transform.called
416+
417+
418+
def test_fit_transform_equiv_mock(simple_dataframe):
419+
"""
420+
Check for equivalent results for code paths fit_transform
421+
versus fit and transform in DataFrameMapper using the mock
422+
transformer which does not implement a custom fit_transform.
423+
"""
424+
df = simple_dataframe
425+
mapper = DataFrameMapper([('a', MockXTransformer())])
426+
transformed_combined = mapper.fit_transform(df)
427+
transformed_separate = mapper.fit(df).transform(df)
428+
assert np.all(transformed_combined == transformed_separate)
429+
430+
431+
def test_fit_transform_equiv_pca(complex_dataframe):
432+
"""
433+
Check for equivalent results for code paths fit_transform
434+
versus fit and transform in DataFrameMapper and transformer
435+
using PCA which implements a custom fit_transform. The
436+
equivalence of both paths in the transformer only can be
437+
asserted since this is tested in the sklearn tests
438+
scikit-learn/sklearn/decomposition/tests/test_pca.py
439+
"""
440+
df = complex_dataframe
441+
mapper = DataFrameMapper(
442+
[(['feat1', 'feat2'], sklearn.decomposition.PCA(2))],
443+
df_out=True)
444+
transformed_combined = mapper.fit_transform(df)
445+
transformed_separate = mapper.fit(df).transform(df)
446+
assert np.allclose(transformed_combined, transformed_separate)
447+
448+
405449
def test_input_df_true_first_transformer(simple_dataframe, monkeypatch):
406450
"""
407451
If input_df is True, the first transformer is passed
@@ -438,7 +482,8 @@ def test_input_df_true_next_transformers(simple_dataframe, monkeypatch):
438482
mapper = DataFrameMapper([
439483
('a', [MockXTransformer(), MockTClassifier()])
440484
], input_df=True)
441-
out = mapper.fit_transform(df)
485+
mapper.fit(df)
486+
out = mapper.transform(df)
442487

443488
args, _ = MockTClassifier().fit.call_args
444489
assert isinstance(args[0], pd.Series)
@@ -537,15 +582,14 @@ def test_get_col_subset_single_column_list(simple_dataframe):
537582

538583
def test_cols_string_array(simple_dataframe):
539584
"""
540-
If an string specified as the columns, the transformer
585+
If a string is specified as the columns, the transformer
541586
is called with a 1-d array as input.
542587
"""
543588
df = simple_dataframe
544589
mock_transformer = Mock()
545-
mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing
546590
mapper = DataFrameMapper([("a", mock_transformer)])
547591

548-
mapper.fit_transform(df)
592+
mapper.fit(df)
549593
args, kwargs = mock_transformer.fit.call_args
550594
assert args[0].shape == (3,)
551595

@@ -557,10 +601,9 @@ def test_cols_list_column_vector(simple_dataframe):
557601
"""
558602
df = simple_dataframe
559603
mock_transformer = Mock()
560-
mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing
561604
mapper = DataFrameMapper([(["a"], mock_transformer)])
562605

563-
mapper.fit_transform(df)
606+
mapper.fit(df)
564607
args, kwargs = mock_transformer.fit.call_args
565608
assert args[0].shape == (3, 1)
566609

0 commit comments

Comments
 (0)