FEA Add macro-averaged mean absolute error #780


Merged
1 change: 1 addition & 0 deletions doc/api.rst
@@ -215,6 +215,7 @@ Imbalance-learn provides some fast-prototyping tools.
metrics.sensitivity_score
metrics.specificity_score
metrics.geometric_mean_score
metrics.macro_averaged_mean_absolute_error
metrics.make_index_balanced_accuracy

.. _datasets_ref:
16 changes: 16 additions & 0 deletions doc/bibtex/refs.bib
@@ -207,4 +207,20 @@ @article{torelli2014rose
issn = {1573-756X},
url = {https://doi.org/10.1007/s10618-012-0295-5},
doi = {10.1007/s10618-012-0295-5}
}

@article{esuli2009ordinal,
author = {A. Esuli and S. Baccianella and F. Sebastiani},
title = {Evaluation Measures for Ordinal Regression},
journal = {Intelligent Systems Design and Applications, International Conference on},
year = {2009},
volume = {1},
issn = {},
pages = {283-287},
keywords = {ordinal regression;ordinal classification;evaluation measures;class imbalance;product reviews},
doi = {10.1109/ISDA.2009.230},
url = {https://doi.ieeecomputersociety.org/10.1109/ISDA.2009.230},
publisher = {IEEE Computer Society},
address = {Los Alamitos, CA, USA},
month = {dec}
}
12 changes: 12 additions & 0 deletions doc/metrics.rst
@@ -45,6 +45,18 @@ The :func:`make_index_balanced_accuracy` :cite:`garcia2012effectiveness` can
wrap any metric and give more importance to a specific class using the
parameter ``alpha``.
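
For instance, the wrapped score can be built as below (an illustrative sketch,
not part of this diff; it assumes the factory returned by
:func:`make_index_balanced_accuracy` is applied to an existing metric such as
:func:`geometric_mean_score`)::

    from imblearn.metrics import geometric_mean_score
    from imblearn.metrics import make_index_balanced_accuracy

    y_true = [0, 0, 0, 0, 1, 1]
    y_pred = [0, 0, 1, 0, 1, 0]

    # Wrap the geometric mean score; ``alpha`` controls how strongly the
    # dominating class is weighted in the index balanced accuracy correction.
    iba_gmean = make_index_balanced_accuracy(alpha=0.1)(geometric_mean_score)
    score = iba_gmean(y_true, y_pred)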

.. _macro_averaged_mean_absolute_error:

Macro-Averaged Mean Absolute Error (MA-MAE)
-------------------------------------------

Ordinal classification is used when there is an ordering among the classes, for
example levels of functionality or movie ratings.

The :func:`macro_averaged_mean_absolute_error` :cite:`esuli2009ordinal` is used
for imbalanced ordinal classification. The mean absolute error is computed for
each class and averaged over classes, giving an equal weight to each class.
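
As a small illustrative sketch (example values chosen here, not part of this
diff), the pooled MAE and the macro-averaged MAE can disagree when one class is
rare::

    from sklearn.metrics import mean_absolute_error
    from imblearn.metrics import macro_averaged_mean_absolute_error

    y_true = [1, 1, 1, 1, 2, 2, 2, 2, 3]   # class 3 is the minority
    y_pred = [1, 1, 1, 2, 2, 2, 2, 2, 2]

    # Pooled MAE: 2 unit errors over 9 samples, about 0.22; the error on the
    # rare class is almost invisible.
    mean_absolute_error(y_true, y_pred)

    # Per-class MAEs are 0.25, 0.0 and 1.0; their unweighted mean is about 0.42.
    macro_averaged_mean_absolute_error(y_true, y_pred)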

.. _classification_report:

Summary of important metrics
6 changes: 6 additions & 0 deletions doc/whats_new/v0.7.rst
@@ -76,6 +76,12 @@ Enhancements
dictionary instead of a string.
:pr:`770` by :user:`Guillaume Lemaitre <glemaitre>`.

- Add the function
  :func:`imblearn.metrics.macro_averaged_mean_absolute_error` returning the
  average of the MAE across classes. This metric is used in ordinal
  classification.
  :pr:`780` by :user:`Aurélien Massiot <AurelienMassiot>`.

Deprecation
...........

2 changes: 2 additions & 0 deletions imblearn/metrics/__init__.py
@@ -9,6 +9,7 @@
from ._classification import geometric_mean_score
from ._classification import make_index_balanced_accuracy
from ._classification import classification_report_imbalanced
from ._classification import macro_averaged_mean_absolute_error

__all__ = [
    "sensitivity_specificity_support",
@@ -17,4 +18,5 @@
    "geometric_mean_score",
    "make_index_balanced_accuracy",
    "classification_report_imbalanced",
    "macro_averaged_mean_absolute_error",
]
71 changes: 70 additions & 1 deletion imblearn/metrics/_classification.py
@@ -18,12 +18,16 @@
import numpy as np
import scipy as sp

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics._classification import _check_targets
from sklearn.metrics._classification import _prf_divide

from sklearn.preprocessing import LabelEncoder
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import (
    check_consistent_length,
    column_or_1d,
)

try:
from inspect import signature
@@ -997,3 +1001,68 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
    if output_dict:
        return report_dict
    return report


def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
    """Compute Macro-Averaged Mean Absolute Error (MA-MAE)
    for imbalanced ordinal classification.

    This function computes the MAE for each class and averages them,
    giving an equal weight to each class.

    Read more in the :ref:`User Guide <macro_averaged_mean_absolute_error>`.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.

    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated targets as returned by a classifier.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    loss : float or ndarray of floats
        Macro-Averaged MAE output is non-negative floating point.
        The best value is 0.0.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import mean_absolute_error
    >>> from imblearn.metrics import macro_averaged_mean_absolute_error
    >>> y_true_balanced = [1, 1, 2, 2]
    >>> y_true_imbalanced = [1, 2, 2, 2]
    >>> y_pred = [1, 2, 1, 2]
    >>> mean_absolute_error(y_true_balanced, y_pred)
    0.5
    >>> mean_absolute_error(y_true_imbalanced, y_pred)
    0.25
    >>> macro_averaged_mean_absolute_error(y_true_balanced, y_pred)
    0.5
    >>> macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred)
    0.16666666666666666
    """
    _, y_true, y_pred = _check_targets(y_true, y_pred)
    if sample_weight is not None:
        sample_weight = column_or_1d(sample_weight)
    else:
        sample_weight = np.ones(y_true.shape)
    check_consistent_length(y_true, y_pred, sample_weight)
    labels = unique_labels(y_true, y_pred)
    # Compute the MAE restricted to the samples of each class, then average the
    # per-class errors so that every class contributes equally to the score.
    mae = []
    for possible_class in labels:
        indices = np.flatnonzero(y_true == possible_class)

        mae.append(
            mean_absolute_error(
                y_true[indices],
                y_pred[indices],
                sample_weight=sample_weight[indices],
            )
        )

    return np.sum(mae) / len(mae)
30 changes: 30 additions & 0 deletions imblearn/metrics/tests/test_classification.py
@@ -29,6 +29,7 @@
from imblearn.metrics import geometric_mean_score
from imblearn.metrics import make_index_balanced_accuracy
from imblearn.metrics import classification_report_imbalanced
from imblearn.metrics import macro_averaged_mean_absolute_error

from imblearn.utils.testing import warns

@@ -498,3 +499,32 @@ def test_classification_report_imbalanced_dict():

    assert outer_keys == expected_outer_keys
    assert inner_keys == expected_inner_keys


@pytest.mark.parametrize(
"y_true, y_pred, expected_ma_mae",
[
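        # For instance, in the first case below class 1 has MAE (0 + 1 + 0) / 3
        # and class 2 has MAE (0 + 1 + 0) / 3, so their macro average is ~0.333.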
        ([1, 1, 1, 2, 2, 2], [1, 2, 1, 2, 1, 2], 0.333),
        ([1, 1, 1, 1, 1, 2], [1, 2, 1, 2, 1, 2], 0.2),
        ([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.555),
        ([1, 1, 1, 1, 1, 1, 2, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.166),
    ],
)
def test_macro_averaged_mean_absolute_error(y_true, y_pred, expected_ma_mae):
Member: If we introduce labels, we will need another test with a bit more corner cases. Otherwise, I think this is good.

Contributor Author: See comment above for labels.

    ma_mae = macro_averaged_mean_absolute_error(y_true, y_pred)
    assert ma_mae == pytest.approx(expected_ma_mae, rel=R_TOL)


def test_macro_averaged_mean_absolute_error_sample_weight():
    y_true = [1, 1, 1, 2, 2, 2]
    y_pred = [1, 2, 1, 2, 1, 2]

    ma_mae_no_weights = macro_averaged_mean_absolute_error(y_true, y_pred)

    sample_weight = [1, 1, 1, 1, 1, 1]
    ma_mae_unit_weights = macro_averaged_mean_absolute_error(
        y_true, y_pred, sample_weight=sample_weight,
    )

    assert ma_mae_unit_weights == pytest.approx(ma_mae_no_weights)
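
A further check one might add (hypothetical, not part of this PR): with
non-uniform weights, up-weighting a misclassified sample of one class should
raise that class's MAE and therefore the macro average.

def test_macro_averaged_mean_absolute_error_nonuniform_weights():
    # Hypothetical extra test, not in the original diff: doubling the weight of
    # the single misclassified sample of class 1 raises its per-class MAE from
    # 1/3 to 0.5, so the macro average must increase.
    y_true = [1, 1, 1, 2, 2, 2]
    y_pred = [1, 2, 1, 2, 1, 2]
    sample_weight = [1, 2, 1, 1, 1, 1]

    ma_mae_weighted = macro_averaged_mean_absolute_error(
        y_true, y_pred, sample_weight=sample_weight,
    )
    ma_mae_unweighted = macro_averaged_mean_absolute_error(y_true, y_pred)

    assert ma_mae_weighted > ma_mae_unweighted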