import torch
from torchmetrics.classification import BinaryJaccardIndex, JaccardIndex
# Jaccard Index (also known as Intersection over Union, IoU)
# Binary
# Compare a 2x2 prediction mask against a 2x2 target mask.
target = torch.tensor([[1, 1], [1, 0]])
preds = torch.tensor([[1, 1], [0, 0]])
metric = BinaryJaccardIndex()
# Positives overlap in 2 cells; their union covers 3 cells, so IoU = 2 / 3.
binary_iou = metric(preds, target)
print(binary_iou)  # tensor(0.6667)
# Multiclass
# Random labels in {0, 1} for a 10 x 25 x 25 volume (6250 elements).
# No seed is set, so the data -- and the exact score -- differ between runs.
target = torch.randint(0, 2, (10, 25, 25))
pred = target.clone()
# Flip the labels inside a 3 x 6 x 6 sub-volume (108 elements) so the
# prediction is close to, but not exactly, the target.
pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
jaccard = JaccardIndex(task="multiclass", num_classes=2)
multiclass_iou = jaccard(pred, target)
# One recorded run printed tensor(0.9660); expect a value near that, with the
# exact digits depending on the random draw above.
print(multiclass_iou)