Skip to content

Commit 5ab92db

Browse files
authored
Correct functional.py _compute_metric (#659)
Correct functional.py _compute_metric when reduction == "weighted"
1 parent b3cbb75 commit 5ab92db

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

segmentation_models_pytorch/metrics/functional.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,23 @@ def _compute_metric(
260260
tn = tn.sum()
261261
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
262262

263-
elif reduction == "macro" or reduction == "weighted":
263+
elif reduction == "macro" :
264264
tp = tp.sum(0)
265265
fp = fp.sum(0)
266266
fn = fn.sum(0)
267267
tn = tn.sum(0)
268268
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
269269
score = _handle_zero_division(score, zero_division)
270270
score = (score * class_weights).mean()
271+
272+
elif reduction == "weighted":
273+
tp = tp.sum(0)
274+
fp = fp.sum(0)
275+
fn = fn.sum(0)
276+
tn = tn.sum(0)
277+
score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
278+
score = _handle_zero_division(score, zero_division)
279+
score = (score * class_weights).sum()
271280

272281
elif reduction == "micro-imagewise":
273282
tp = tp.sum(1)

0 commit comments

Comments
 (0)