Metrics

Jaccard Index (aka Intersection over Union, aka IoU)

import torch
from torchmetrics.classification import BinaryJaccardIndex, JaccardIndex

Binary

# Worked binary example of IoU = |intersection| / |union| on 2x2 masks.
# target marks three positive cells and preds marks two; both predicted
# positives are also positive in target.
target = torch.tensor([[1, 1], [1, 0]])
preds = torch.tensor([[1, 1], [0, 0]])
metric = BinaryJaccardIndex()
# Predictions are passed first, ground truth second.
# Intersection = 2 cells, union = 3 cells, so the expected score is 2/3.
metric(preds, target)
tensor(0.6667)

Multiclass

# Multiclass example on synthetic labels: random class ids 0/1 with shape
# (10, 25, 25), i.e. 6250 elements. No seed is set here, so the actual
# labels differ from run to run.
target = torch.randint(0, 2, (10, 25, 25))
# Start from a perfect prediction; clone so target itself is not modified.
pred = target.clone()
# Flip the label (0 <-> 1) inside a 3 x 6 x 6 sub-block, making 108 of the
# 6250 predictions wrong while every other element still matches target.
pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
jaccard = JaccardIndex(task="multiclass", num_classes=2)
# NOTE(review): the displayed score is presumably the mean of the two
# per-class IoUs (each roughly 3071 / (3071 + 108), about 0.966, when the
# classes are balanced) -- confirm against the torchmetrics default
# `average` setting.
jaccard(pred, target)
tensor(0.9660)