Skip to content

Commit d989faa

Browse files
authored
Modify Jaccard, Dice and Tversky losses (#927)
* Modify Jaccard, Dice and Tversky losses * Modify the Tversky loss
1 parent a669734 commit d989faa

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

segmentation_models_pytorch/losses/_functional.py

+18-24
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
import torch
7+
import torch.linalg as LA
78
import torch.nn.functional as F
89

910
__all__ = [
@@ -157,15 +158,7 @@ def soft_jaccard_score(
157158
dims=None,
158159
) -> torch.Tensor:
159160
assert output.size() == target.size()
160-
if dims is not None:
161-
intersection = torch.sum(output * target, dim=dims)
162-
cardinality = torch.sum(output + target, dim=dims)
163-
else:
164-
intersection = torch.sum(output * target)
165-
cardinality = torch.sum(output + target)
166-
167-
union = cardinality - intersection
168-
jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
161+
jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims)
169162
return jaccard_score
170163

171164

@@ -177,13 +170,7 @@ def soft_dice_score(
177170
dims=None,
178171
) -> torch.Tensor:
179172
assert output.size() == target.size()
180-
if dims is not None:
181-
intersection = torch.sum(output * target, dim=dims)
182-
cardinality = torch.sum(output + target, dim=dims)
183-
else:
184-
intersection = torch.sum(output * target)
185-
cardinality = torch.sum(output + target)
186-
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
173+
dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims)
187174
return dice_score
188175

189176

@@ -196,15 +183,22 @@ def soft_tversky_score(
196183
eps: float = 1e-7,
197184
dims=None,
198185
) -> torch.Tensor:
186+
"""Tversky loss
187+
188+
References:
189+
https://arxiv.org/pdf/2302.05666
190+
https://arxiv.org/pdf/2303.16296
191+
192+
"""
199193
assert output.size() == target.size()
200-
if dims is not None:
201-
intersection = torch.sum(output * target, dim=dims) # TP
202-
fp = torch.sum(output * (1.0 - target), dim=dims)
203-
fn = torch.sum((1 - output) * target, dim=dims)
204-
else:
205-
intersection = torch.sum(output * target) # TP
206-
fp = torch.sum(output * (1.0 - target))
207-
fn = torch.sum((1 - output) * target)
194+
195+
output_sum = torch.sum(output, dim=dims)
196+
target_sum = torch.sum(target, dim=dims)
197+
difference = LA.vector_norm(output - target, ord=1, dim=dims)
198+
199+
intersection = (output_sum + target_sum - difference) / 2 # TP
200+
fp = output_sum - intersection
201+
fn = target_sum - intersection
208202

209203
tversky_score = (intersection + smooth) / (
210204
intersection + alpha * fp + beta * fn + smooth

0 commit comments

Comments
 (0)