If the logits have shape [N, C, d1, d2] (where N is the number of images in the batch and C is the number of classes to predict), then the target (i.e. label) must have shape [N, d1, d2] and hold integer class indices (dtype torch.long) in the range [0, C-1].

The loss is then computed as: nn.CrossEntropyLoss(weight)(logit, target). With the default reduction='mean', the result is a single scalar averaged over all N*d1*d2 pixels.
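A minimal sketch of the shapes involved, assuming small illustrative sizes (N=2, C=3, d1=4, d2=5) and random data. It also recomputes the same value by hand (log-softmax over the class dimension, pick the true class, average) to show what the loss actually measures per pixel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

N, C, d1, d2 = 2, 3, 4, 5                   # illustrative sizes only
logit = torch.randn(N, C, d1, d2)           # raw, unnormalized class scores per pixel
target = torch.randint(0, C, (N, d1, d2))   # int64 class index per pixel, in [0, C-1]

loss = nn.CrossEntropyLoss()(logit, target)  # scalar: mean over all N*d1*d2 pixels

# Same value by hand: log-softmax over the class dim (dim=1), pick the true class, average
log_probs = F.log_softmax(logit, dim=1)                       # [N, C, d1, d2]
picked = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # [N, d1, d2]
manual = -picked.mean()

print(loss.shape, torch.allclose(loss, manual))  # torch.Size([]) True
```

Note that the class dimension must be dim 1 of the logits; a channels-last tensor of shape [N, d1, d2, C] would need a permute first.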

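weight is an optional 1-D tensor of length C that assigns a rescaling weight to each class, which helps with unbalanced datasets. It must be a floating-point tensor (e.g. created with .float()) with the same dtype as the logits.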
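A sketch of passing per-class weights, assuming three classes and hypothetical weight values (1, 2, 5) chosen only for illustration. The manual check shows how the weights enter the default 'mean' reduction: each pixel's loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights rather than by the pixel count.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

N, C, d1, d2 = 2, 3, 4, 5
logit = torch.randn(N, C, d1, d2)
target = torch.randint(0, C, (N, d1, d2))

# One weight per class; hypothetical values where a larger weight up-weights a rarer class.
# .float() turns the integer tensor into float32, matching the logits.
weight = torch.tensor([1, 2, 5]).float()

loss = nn.CrossEntropyLoss(weight=weight)(logit, target)

# Weighted mean: sum_i w[y_i] * (-log p_i[y_i]) / sum_i w[y_i]
nll = -F.log_softmax(logit, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)  # [N, d1, d2]
w = weight[target]                                                            # [N, d1, d2]
manual = (w * nll).sum() / w.sum()

print(torch.allclose(loss, manual))  # True
```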

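When some pixels have to be excluded from the loss (e.g. out-of-bounds pixels marked in the target with a sentinel value), a masked version can be used: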

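import torch.nn as nn

def masked_cross_entropy_loss_fn(y_pred, y_true, weights, out_of_bounds_value):
    # True where the label marks an out-of-bounds pixel, shape [N, d1, d2]
    out_of_bounds = y_true == out_of_bounds_value
    # Zero the logits and relabel as class 0 so the sentinel is never used as a class index
    return nn.CrossEntropyLoss(weight=weights)(
        y_pred.masked_fill(out_of_bounds.unsqueeze(dim=1), 0),
        y_true.masked_fill(out_of_bounds, 0),
    )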
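The masked_fill trick keeps the out-of-bounds pixels in the loss: their logits are all zero, so each contributes a constant log(C) (scaled by the weight of class 0), receives no gradient, but still counts in the mean's denominator. nn.CrossEntropyLoss also accepts an ignore_index argument that drops those pixels entirely. The sketch below compares both, assuming a hypothetical sentinel OUT_OF_BOUNDS = -1 in place of out_of_bounds_value, no class weights, and random data.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

OUT_OF_BOUNDS = -1  # hypothetical sentinel label for pixels outside the valid region
N, C, d1, d2 = 2, 3, 4, 5

logit = torch.randn(N, C, d1, d2)
target = torch.randint(0, C, (N, d1, d2))
target[:, :, 0] = OUT_OF_BOUNDS  # mark the first column of every image as out of bounds

# ignore_index drops sentinel pixels from both the sum and the mean's denominator
loss_ignore = nn.CrossEntropyLoss(ignore_index=OUT_OF_BOUNDS)(logit, target)

# Reference: plain loss computed over the valid pixels only
valid = target != OUT_OF_BOUNDS                  # [N, d1, d2] bool
valid_logits = logit.permute(0, 2, 3, 1)[valid]  # [num_valid, C]
loss_valid_only = nn.CrossEntropyLoss()(valid_logits, target[valid])
print(torch.allclose(loss_ignore, loss_valid_only))  # True

# masked_fill approach: masked pixels get all-zero logits and label 0,
# so each adds log(C) to the numerator and 1 to the denominator of the mean
out_of_bounds = ~valid
loss_masked = nn.CrossEntropyLoss()(
    logit.masked_fill(out_of_bounds.unsqueeze(1), 0),
    target.masked_fill(out_of_bounds, 0),
)
n_valid, n_total = valid.sum().item(), valid.numel()
expected = (loss_valid_only * n_valid + math.log(C) * (n_total - n_valid)) / n_total
print(torch.allclose(loss_masked, expected))  # True
```

Both variants give zero gradient to the masked pixels; they differ only in the loss value, which matters when comparing losses across batches with different amounts of masking.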