|
29 | 29 | from imblearn.metrics import geometric_mean_score
|
30 | 30 | from imblearn.metrics import make_index_balanced_accuracy
|
31 | 31 | from imblearn.metrics import classification_report_imbalanced
|
| 32 | +from imblearn.metrics import macro_averaged_mean_absolute_error |
32 | 33 |
|
33 | 34 | from imblearn.utils.testing import warns
|
34 | 35 |
|
@@ -498,3 +499,32 @@ def test_classification_report_imbalanced_dict():
|
498 | 499 |
|
499 | 500 | assert outer_keys == expected_outer_keys
|
500 | 501 | assert inner_keys == expected_inner_keys
|
| 502 | + |
| 503 | + |
@pytest.mark.parametrize(
    "y_true, y_pred, expected_ma_mae",
    [
        # Binary labels: per-class MAE computed by hand, then macro-averaged.
        ([1, 1, 1, 2, 2, 2], [1, 2, 1, 2, 1, 2], 0.333),
        ([1, 1, 1, 1, 1, 2], [1, 2, 1, 2, 1, 2], 0.2),
        # Three-class labels: macro averaging weights each class equally,
        # regardless of class imbalance.
        ([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.555),
        ([1, 1, 1, 1, 1, 1, 2, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.166),
    ],
)
def test_macro_averaged_mean_absolute_error(y_true, y_pred, expected_ma_mae):
    """Check macro-averaged MAE against hand-computed expected values."""
    error = macro_averaged_mean_absolute_error(y_true, y_pred)
    # R_TOL is the file-level relative tolerance shared by these tests;
    # the expected values above are truncated decimals (e.g. 0.333 for 1/3).
    assert error == pytest.approx(expected_ma_mae, rel=R_TOL)
| 517 | + |
| 518 | + |
def test_macro_averaged_mean_absolute_error_sample_weight():
    """Uniform sample weights must leave the macro-averaged MAE unchanged."""
    y_true = [1, 1, 1, 2, 2, 2]
    y_pred = [1, 2, 1, 2, 1, 2]

    # Baseline: no weighting at all.
    unweighted = macro_averaged_mean_absolute_error(y_true, y_pred)

    # All-ones weights are equivalent to passing no weights.
    sample_weight = [1, 1, 1, 1, 1, 1]
    weighted = macro_averaged_mean_absolute_error(
        y_true, y_pred, sample_weight=sample_weight,
    )

    assert weighted == pytest.approx(unweighted)
0 commit comments