# Source code for towbintools.deep_learning.utils.loss

from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


class FocalTverskyLoss(nn.Module):
    """
    Focal Tversky loss for binary segmentation with class-imbalance handling.

    Combines the Tversky index (a generalization of Dice that independently
    weights false positives and false negatives) with a focal exponent to
    down-weight easy examples and focus training on hard ones.

    Reference: Abraham & Khan (2019), arXiv:1810.07842.
    """

    def __init__(
        self,
        ignore_index=-1,
        smooth=100,
        alpha=0.3,
        beta=0.7,
        gamma=4 / 3,
        activation=True,
    ):
        """
        Parameters:
            ignore_index (int, optional): Target value to ignore in the loss
                computation. (default: -1)
            smooth (float, optional): Smoothing constant added to numerator
                and denominator of the Tversky index to avoid division by
                zero. (default: 100)
            alpha (float, optional): Weight for false positives in the
                Tversky index. (default: 0.3)
            beta (float, optional): Weight for false negatives in the
                Tversky index. (default: 0.7)
            gamma (float, optional): Focal exponent applied to
                ``(1 - Tversky)``. (default: 4/3)
            activation (bool, optional): If ``True``, apply sigmoid to inputs
                before computing the loss. (default: True)
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.smooth = smooth
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.activation = activation

    def forward(self, inputs, targets):
        """
        Compute the focal Tversky loss over all non-ignored elements.

        Parameters:
            inputs (Tensor): Predictions. Passed through a sigmoid first when
                ``activation`` is ``True``. Must hold the same number of
                elements as ``targets``.
            targets (Tensor): Ground-truth values. Elements equal to
                ``ignore_index`` are removed, together with the matching
                predictions, before the loss is computed.

        Returns:
            Tensor: Scalar loss ``(1 - Tversky + 1e-10) ** gamma``.
        """
        if self.activation:
            inputs = torch.sigmoid(inputs)

        # Flatten with reshape rather than view: view raises on
        # non-contiguous tensors (e.g. transposed or permuted inputs).
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Remove elements with ignore_index from the loss calculation.
        mask = targets != self.ignore_index
        inputs = inputs[mask]
        targets = targets[mask]

        # Soft counts of true positives, false positives and false negatives.
        tp = (inputs * targets).sum()
        fp = ((1 - targets) * inputs).sum()
        fn = (targets * (1 - inputs)).sum()

        tversky = (tp + self.smooth) / (
            tp + self.alpha * fp + self.beta * fn + self.smooth
        )
        # The small offset keeps the base of the power strictly positive when
        # the Tversky index equals 1.
        return (1 - tversky + 1e-10) ** self.gamma
# adapted from https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
class MultiClassFocalLoss(nn.Module):
    """
    Focal loss for multi-class classification.

    Extends cross-entropy with a modulating factor ``(1 - p_t)^gamma`` that
    down-weights well-classified examples, focusing training on hard ones.
    Supports per-class weights and an ignore index.

    Reference: Lin et al. (2017), arXiv:1708.02002.
    """

    def __init__(
        self,
        alpha: Optional[Tensor] = None,
        gamma: float = 2.0,
        reduction: str = "mean",
        ignore_index: int = -1,
    ):
        """
        Parameters:
            alpha (Tensor, optional): Per-class weight tensor of shape
                ``(n_classes,)``. (default: None)
            gamma (float, optional): Focusing parameter; 0 recovers standard
                cross-entropy. (default: 2.0)
            reduction (str, optional): Reduction to apply: ``"mean"``,
                ``"sum"``, or ``"none"``. (default: ``"mean"``)
            ignore_index (int, optional): Target value to ignore.
                (default: -1)

        Raises:
            ValueError: If ``reduction`` is not one of the supported values.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # Per-element NLL; the class weights (alpha) are applied here.
        self.ce_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_str = ", ".join(f"{k}={self.__dict__[k]!r}" for k in arg_keys)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """
        Compute the focal loss.

        Parameters:
            x (Tensor): Unnormalized class scores of shape ``(N, C)`` or
                ``(N, C, d1, ..., dK)``.
            y (Tensor): Class indices of shape ``(N,)`` or
                ``(N, d1, ..., dK)``. Entries equal to ``ignore_index`` are
                dropped before the loss is computed.

        Returns:
            Tensor: Scalar for ``"mean"`` / ``"sum"`` reduction, otherwise
            one loss value per non-ignored element. When every element is
            ignored, a scalar zero that lives on ``x``'s device and remains
            attached to the autograd graph.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous targets are accepted
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        x = x[unignored_mask]
        # Cast before the NLL call: both NLLLoss and gather index with int64.
        y = y[unignored_mask].long()

        if y.numel() == 0:
            # Sum over the empty selection: a zero on x's device/dtype that
            # still supports backward(), unlike a fresh CPU constant.
            return x.sum()

        # log-probabilities, reused for the focal term
        log_p = F.log_softmax(x, dim=-1)
        ce = self.ce_loss(log_p, y)

        # get true class column from each row
        log_pt = torch.gather(log_p, 1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
class BCELossWithIgnore(nn.Module):
    """
    Binary cross-entropy that skips targets equal to a sentinel value.

    The element-wise BCE is computed, entries whose target equals
    ``ignore_index`` are zeroed out, and the remaining losses are averaged
    over the number of non-ignored elements.
    """

    def __init__(self, ignore_index=-1):
        """
        Parameters:
            ignore_index (int, optional): Target value to exclude from the
                loss. (default: -1)
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.bce_loss = nn.BCELoss(reduction="none")

    def forward(self, input, target):
        """
        Parameters:
            input (Tensor): Predicted probabilities, as expected by
                ``nn.BCELoss``.
            target (Tensor): Targets; entries equal to ``ignore_index`` do
                not contribute to the loss.

        Returns:
            Tensor: Sum of the kept element-wise losses divided by the number
            of kept elements (clamped to at least ``1e-8``).
        """
        # 1.0 where the target counts, 0.0 where it is ignored
        keep = target.ne(self.ignore_index).float()

        # ignored targets are replaced by 0 before they reach BCELoss
        per_element = self.bce_loss(input, target * keep)

        # the clamp guards the division when nothing is kept
        n_kept = keep.sum().clamp(min=1e-8)
        return (per_element * keep).sum() / n_kept
class PeakWeightedMSELoss(nn.Module):
    """
    Mean squared error that emphasizes high-value (peak) target positions.

    Every element is weighted by ``1 + peak_weight * target``, so positions
    with larger target values (peaks in a heatmap) contribute more to the
    loss. Designed for 1D keypoint heatmap regression.
    """

    def __init__(self, ignore_index=-1, peak_weight=3.0):
        """
        Parameters:
            ignore_index (int, optional): Reserved for API consistency;
                currently unused in the forward pass. (default: -1)
            peak_weight (float, optional): Additional weight multiplier for
                high-value target positions. (default: 3.0)
        """
        super().__init__()
        self.peak_weight = peak_weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        Parameters:
            input (Tensor): Predicted values.
            target (Tensor): Target values; also used to build the weights.

        Returns:
            Tensor: Mean of the weighted squared errors over all elements.
        """
        squared_error = (input - target) ** 2
        # larger targets get proportionally larger weights
        peak_weights = self.peak_weight * target + 1.0
        return torch.mean(peak_weights * squared_error)