Commit 0b48def

FEA Add macro-averaged mean absolute error (#780)

Authored by AurelienMassiot
Co-authored-by: GitName <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent a9e1121 commit 0b48def

File tree

  doc/api.rst
  doc/bibtex/refs.bib
  doc/metrics.rst
  doc/whats_new/v0.7.rst
  imblearn/metrics/__init__.py
  imblearn/metrics/_classification.py
  imblearn/metrics/tests/test_classification.py

7 files changed: +137 -1 lines changed

doc/api.rst (+1)

@@ -215,6 +215,7 @@ Imbalance-learn provides some fast-prototyping tools.
    metrics.sensitivity_score
    metrics.specificity_score
    metrics.geometric_mean_score
+   metrics.macro_averaged_mean_absolute_error
    metrics.make_index_balanced_accuracy
 
 .. _datasets_ref:

doc/bibtex/refs.bib (+16)

@@ -207,4 +207,20 @@ @article{torelli2014rose
   issn = {1573-756X},
   url = {https://doi.org/10.1007/s10618-012-0295-5},
   doi = {10.1007/s10618-012-0295-5}
+}
+
+@article{esuli2009ordinal,
+  author = {A. Esuli and S. Baccianella and F. Sebastiani},
+  title = {Evaluation Measures for Ordinal Regression},
+  journal = {Intelligent Systems Design and Applications, International Conference on},
+  year = {2009},
+  volume = {1},
+  issn = {},
+  pages = {283-287},
+  keywords = {ordinal regression;ordinal classification;evaluation measures;class imbalance;product reviews},
+  doi = {10.1109/ISDA.2009.230},
+  url = {https://doi.ieeecomputersociety.org/10.1109/ISDA.2009.230},
+  publisher = {IEEE Computer Society},
+  address = {Los Alamitos, CA, USA},
+  month = {dec}
 }

doc/metrics.rst (+12)

@@ -45,6 +45,18 @@ The :func:`make_index_balanced_accuracy` :cite:`garcia2012effectiveness` can
 wrap any metric and give more importance to a specific class using the
 parameter ``alpha``.
 
+.. _macro_averaged_mean_absolute_error:
+
+Macro-Averaged Mean Absolute Error (MA-MAE)
+-------------------------------------------
+
+Ordinal classification is used when there is a rank among the classes, for
+example levels of functionality or movie ratings.
+
+The :func:`macro_averaged_mean_absolute_error` :cite:`esuli2009ordinal` is used
+for imbalanced ordinal classification. The mean absolute error is computed for
+each class and averaged over classes, giving an equal weight to each class.
+
 .. _classification_report:
 
 Summary of important metrics
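To make the macro-averaging concrete, here is a small sketch contrasting plain MAE with the new metric on an imbalanced ordinal problem; the data is illustrative and not taken from this commit:

    # Plain MAE dilutes errors on rare classes; MA-MAE weights classes equally.
    from sklearn.metrics import mean_absolute_error
    from imblearn.metrics import macro_averaged_mean_absolute_error

    # Ratings 1-3: the rare class 1 is always off by two,
    # the dominant class 3 is always predicted exactly.
    y_true = [1, 3, 3, 3, 3, 3]
    y_pred = [3, 3, 3, 3, 3, 3]

    print(mean_absolute_error(y_true, y_pred))
    # 0.333... -- the rare-class error is averaged over all six samples
    print(macro_averaged_mean_absolute_error(y_true, y_pred))
    # 1.0 -- (2 + 0) / 2: one MAE per class, averaged with equal weight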

doc/whats_new/v0.7.rst (+6)

@@ -76,6 +76,12 @@ Enhancements
   dictionary instead of a string.
   :pr:`770` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- Add the function
+  :func:`imblearn.metrics.macro_averaged_mean_absolute_error` returning the
+  average of the MAE across classes. This metric is used in ordinal
+  classification.
+  :pr:`780` by :user:`Aurélien Massiot <AurelienMassiot>`.
+
 Deprecation
 ...........

imblearn/metrics/__init__.py (+2)

@@ -9,6 +9,7 @@
 from ._classification import geometric_mean_score
 from ._classification import make_index_balanced_accuracy
 from ._classification import classification_report_imbalanced
+from ._classification import macro_averaged_mean_absolute_error
 
 __all__ = [
     "sensitivity_specificity_support",
@@ -17,4 +18,5 @@
     "geometric_mean_score",
     "make_index_balanced_accuracy",
     "classification_report_imbalanced",
+    "macro_averaged_mean_absolute_error",
 ]
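As a quick sanity check of the new export (a sketch assuming imbalanced-learn at this commit is installed):

    # The metric is re-exported from the public imblearn.metrics namespace.
    from imblearn import metrics

    assert "macro_averaged_mean_absolute_error" in metrics.__all__
    print(metrics.macro_averaged_mean_absolute_error([1, 2], [1, 2]))  # 0.0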

imblearn/metrics/_classification.py (+70 -1)

@@ -18,12 +18,16 @@
 import numpy as np
 import scipy as sp
 
+from sklearn.metrics import mean_absolute_error
 from sklearn.metrics import precision_recall_fscore_support
 from sklearn.metrics._classification import _check_targets
 from sklearn.metrics._classification import _prf_divide
-
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils.multiclass import unique_labels
+from sklearn.utils.validation import (
+    check_consistent_length,
+    column_or_1d,
+)
 
 try:
     from inspect import signature
@@ -997,3 +1001,68 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
     if output_dict:
         return report_dict
     return report
+
+
+def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
+    """Compute Macro-Averaged Mean Absolute Error (MA-MAE)
+    for imbalanced ordinal classification.
+
+    This function computes the MAE for each class and averages them,
+    giving an equal weight to each class.
+
+    Read more in the :ref:`User Guide <macro_averaged_mean_absolute_error>`.
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
+        Ground truth (correct) target values.
+
+    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
+        Estimated targets as returned by a classifier.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    Returns
+    -------
+    loss : float or ndarray of floats
+        The macro-averaged MAE is a non-negative floating point value.
+        The best value is 0.0.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import mean_absolute_error
+    >>> from imblearn.metrics import macro_averaged_mean_absolute_error
+    >>> y_true_balanced = [1, 1, 2, 2]
+    >>> y_true_imbalanced = [1, 2, 2, 2]
+    >>> y_pred = [1, 2, 1, 2]
+    >>> mean_absolute_error(y_true_balanced, y_pred)
+    0.5
+    >>> mean_absolute_error(y_true_imbalanced, y_pred)
+    0.25
+    >>> macro_averaged_mean_absolute_error(y_true_balanced, y_pred)
+    0.5
+    >>> macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred)
+    0.16666666666666666
+    """
+    _, y_true, y_pred = _check_targets(y_true, y_pred)
+    if sample_weight is not None:
+        sample_weight = column_or_1d(sample_weight)
+    else:
+        sample_weight = np.ones(y_true.shape)
+    check_consistent_length(y_true, y_pred, sample_weight)
+    labels = unique_labels(y_true, y_pred)
+    mae = []
+    for possible_class in labels:
+        indices = np.flatnonzero(y_true == possible_class)
+
+        mae.append(
+            mean_absolute_error(
+                y_true[indices],
+                y_pred[indices],
+                sample_weight=sample_weight[indices],
+            )
+        )
+
+    return np.sum(mae) / len(mae)
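To see where the doctest value 0.16666666666666666 comes from, the computation can be replayed by hand with NumPy. This sketch mirrors the loop above with illustrative variable names; it assumes every predicted label also occurs in y_true, whereas the real function takes the label set from unique_labels(y_true, y_pred):

    import numpy as np

    # Replay the imbalanced doctest case by hand.
    y_true = np.array([1, 2, 2, 2])
    y_pred = np.array([1, 2, 1, 2])

    per_class_mae = []
    for label in np.unique(y_true):
        mask = y_true == label
        per_class_mae.append(np.abs(y_true[mask] - y_pred[mask]).mean())

    # Class 1: single sample, error 0          -> MAE 0.0
    # Class 2: errors (0, 1, 0) over 3 samples -> MAE 1/3
    print(per_class_mae)           # [0.0, 0.3333333333333333]
    print(np.mean(per_class_mae))  # 0.16666666666666666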

imblearn/metrics/tests/test_classification.py (+30)

@@ -29,6 +29,7 @@
 from imblearn.metrics import geometric_mean_score
 from imblearn.metrics import make_index_balanced_accuracy
 from imblearn.metrics import classification_report_imbalanced
+from imblearn.metrics import macro_averaged_mean_absolute_error
 
 from imblearn.utils.testing import warns
 
@@ -498,3 +499,32 @@ def test_classification_report_imbalanced_dict():
 
     assert outer_keys == expected_outer_keys
     assert inner_keys == expected_inner_keys
+
+
+@pytest.mark.parametrize(
+    "y_true, y_pred, expected_ma_mae",
+    [
+        ([1, 1, 1, 2, 2, 2], [1, 2, 1, 2, 1, 2], 0.333),
+        ([1, 1, 1, 1, 1, 2], [1, 2, 1, 2, 1, 2], 0.2),
+        ([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.555),
+        ([1, 1, 1, 1, 1, 1, 2, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.166),
+    ],
+)
+def test_macro_averaged_mean_absolute_error(y_true, y_pred, expected_ma_mae):
+    ma_mae = macro_averaged_mean_absolute_error(y_true, y_pred)
+    assert ma_mae == pytest.approx(expected_ma_mae, rel=R_TOL)
+
+
+def test_macro_averaged_mean_absolute_error_sample_weight():
+    y_true = [1, 1, 1, 2, 2, 2]
+    y_pred = [1, 2, 1, 2, 1, 2]
+
+    ma_mae_no_weights = macro_averaged_mean_absolute_error(y_true, y_pred)
+
+    sample_weight = [1, 1, 1, 1, 1, 1]
+    ma_mae_unit_weights = macro_averaged_mean_absolute_error(
+        y_true, y_pred, sample_weight=sample_weight,
+    )
+
+    assert ma_mae_unit_weights == pytest.approx(ma_mae_no_weights)
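The weighted test above only verifies that unit weights reproduce the unweighted result. A hypothetical follow-up test, not part of this commit, could pin down non-uniform weights, which reweight the errors inside each class before the per-class averaging:

    def test_macro_averaged_mean_absolute_error_nonuniform_sample_weight():
        # Hypothetical additional test (not in this commit).
        y_true = [1, 1, 2, 2]
        y_pred = [1, 2, 1, 2]
        sample_weight = [1, 3, 1, 1]

        # Class 1: errors (0, 1), weights (1, 3) -> weighted MAE 0.75
        # Class 2: errors (1, 0), weights (1, 1) -> weighted MAE 0.50
        # Macro average: (0.75 + 0.50) / 2 = 0.625
        ma_mae = macro_averaged_mean_absolute_error(
            y_true, y_pred, sample_weight=sample_weight,
        )
        assert ma_mae == pytest.approx(0.625)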
