4
4
from typing import Optional
5
5
6
6
import torch
7
+ import torch .linalg as LA
7
8
import torch .nn .functional as F
8
9
9
10
__all__ = [
@@ -157,15 +158,7 @@ def soft_jaccard_score(
157
158
dims = None ,
158
159
) -> torch .Tensor :
159
160
assert output .size () == target .size ()
160
- if dims is not None :
161
- intersection = torch .sum (output * target , dim = dims )
162
- cardinality = torch .sum (output + target , dim = dims )
163
- else :
164
- intersection = torch .sum (output * target )
165
- cardinality = torch .sum (output + target )
166
-
167
- union = cardinality - intersection
168
- jaccard_score = (intersection + smooth ) / (union + smooth ).clamp_min (eps )
161
+ jaccard_score = soft_tversky_score (output , target , 1.0 , 1.0 , smooth , eps , dims )
169
162
return jaccard_score
170
163
171
164
@@ -177,13 +170,7 @@ def soft_dice_score(
177
170
dims = None ,
178
171
) -> torch .Tensor :
179
172
assert output .size () == target .size ()
180
- if dims is not None :
181
- intersection = torch .sum (output * target , dim = dims )
182
- cardinality = torch .sum (output + target , dim = dims )
183
- else :
184
- intersection = torch .sum (output * target )
185
- cardinality = torch .sum (output + target )
186
- dice_score = (2.0 * intersection + smooth ) / (cardinality + smooth ).clamp_min (eps )
173
+ dice_score = soft_tversky_score (output , target , 0.5 , 0.5 , smooth , eps , dims )
187
174
return dice_score
188
175
189
176
@@ -196,15 +183,22 @@ def soft_tversky_score(
196
183
eps : float = 1e-7 ,
197
184
dims = None ,
198
185
) -> torch .Tensor :
186
+ """Tversky loss
187
+
188
+ References:
189
+ https://arxiv.org/pdf/2302.05666
190
+ https://arxiv.org/pdf/2303.16296
191
+
192
+ """
199
193
assert output .size () == target .size ()
200
- if dims is not None :
201
- intersection = torch .sum (output * target , dim = dims ) # TP
202
- fp = torch .sum (output * ( 1.0 - target ) , dim = dims )
203
- fn = torch . sum (( 1 - output ) * target , dim = dims )
204
- else :
205
- intersection = torch . sum ( output * target ) # TP
206
- fp = torch . sum ( output * ( 1.0 - target ))
207
- fn = torch . sum (( 1 - output ) * target )
194
+
195
+ output_sum = torch .sum (output , dim = dims )
196
+ target_sum = torch .sum (target , dim = dims )
197
+ difference = LA . vector_norm ( output - target , ord = 1 , dim = dims )
198
+
199
+ intersection = ( output_sum + target_sum - difference ) / 2 # TP
200
+ fp = output_sum - intersection
201
+ fn = target_sum - intersection
208
202
209
203
tversky_score = (intersection + smooth ) / (
210
204
intersection + alpha * fp + beta * fn + smooth
0 commit comments