import torch
import torch.nn as nn
import torch.nn.functional as F
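
# Combined loss for two-channel (boundary, object) segmentation masks:
# weighted soft Dice plus weighted BCE (or Focal) per channel, with optional
# Sobel-gradient and signed-distance boundary terms.
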
class WeightedDiceBCELoss(nn.Module):
    def __init__(self,
                 weight_1=4.0,
                 weight_2=333.3,
                 weight_3=1.0,
                 bce_weight=1.0,
                 grad_weight=1.0,
                 use_focal=False,
                 use_gradient_loss=False,
                 use_dist_loss=False,
                 focal_alpha=0.25,
                 focal_gamma=2.0,
                 boundary_dist_weight=1.0,
                 epsilon=1e-6,
                 ):
| """ | |
| Args: | |
| weight_1: Weight for object foreground in Dice loss. | |
| weight_2: Weight for boundary foreground in Dice loss. | |
| weight_3: Weight for boundary channel weight. | |
| bce_weight: Weight for BCE loss. | |
| epsilon: Small constant to prevent division by zero. | |
| use_focal: Whether to use Focal Loss for boundary channel. | |
| focal_gamma: Focal Loss gamma parameter. | |
| focal_alpha: Focal Loss alpha parameter. | |
| use_gradient_loss: Whether to use Sobel gradient loss. | |
| grad_weight: Weight for gradient loss term. | |
| """ | |
        super().__init__()
        self.weight_object_foreground = weight_1
        self.weight_boundary_foreground = weight_2
        self.weight_boundary_channel = weight_3
        self.bce_weight = bce_weight
        self.epsilon = epsilon
        # Focal Loss params
        self.use_focal = use_focal
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha
        # Gradient loss params
        self.use_gradient_loss = use_gradient_loss
        self.grad_weight = grad_weight
        # Boundary distance weight
        self.use_dist_loss = use_dist_loss
        self.boundary_dist_weight = boundary_dist_weight
        # Fixed (non-trainable) Sobel filters for the gradient loss
        if use_gradient_loss:
            sobel_kernel = torch.tensor([
                [[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],   # x
                [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]],   # y
            ], dtype=torch.float32)
            self.sobel_filter = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=False)
            self.sobel_filter.weight.data = sobel_kernel
            self.sobel_filter.requires_grad_(False)

    def forward(self, probs, targets, sdm_tensor):
        """
        Args:
            probs: Model outputs with sigmoid already applied, shape (B, 2, H, W);
                channel 0 is the boundary mask, channel 1 the object mask.
            targets: Ground truth masks, shape (B, 2, H, W).
            sdm_tensor: Signed-distance map used by the boundary distance loss,
                broadcastable against probs[:, 0, :, :], e.g. shape (B, H, W).
        """
        # Per-pixel weights: up-weight foreground pixels in each channel
        weights_obj = torch.where(targets[:, 1, :, :] == 1, self.weight_object_foreground, 1.0)
        weights_bnd = torch.where(targets[:, 0, :, :] == 1, self.weight_boundary_foreground, 1.0)

        # ----- Dice Loss -----
        boundary_channel_dice = self.dice_loss(
            probs[:, 0, :, :],
            targets[:, 0, :, :],
            weights_bnd,
            self.epsilon,
        )
        object_channel_dice = self.dice_loss(
            probs[:, 1, :, :],
            targets[:, 1, :, :],
            weights_obj,
            self.epsilon,
        )
        # ----- BCE or Focal Loss -----
        if self.use_focal:
            boundary_bce = self.focal_loss(
                probs[:, 0, :, :], targets[:, 0, :, :], weights_bnd,
                gamma=self.focal_gamma, alpha=self.focal_alpha,
            )
        else:
            boundary_bce = F.binary_cross_entropy(probs[:, 0, :, :], targets[:, 0, :, :], weights_bnd)
        object_bce = F.binary_cross_entropy(probs[:, 1, :, :], targets[:, 1, :, :], weights_obj)
        # ----- Gradient Loss (boundary channel only) -----
        grad_loss = 0.0
        if self.use_gradient_loss:
            grad_loss = self.gradient_loss(probs[:, 0, :, :], targets[:, 0, :, :])

        # ----- Boundary Distance Loss -----
        boundary_dist_loss = 0.0
        if self.use_dist_loss:
            boundary_dist_loss = torch.mean(probs[:, 0, :, :] * sdm_tensor.to(probs.device))
        # ----- Total Loss -----
        total_loss = (self.weight_boundary_channel * boundary_channel_dice + object_channel_dice) \
            + self.bce_weight * (self.weight_boundary_channel * boundary_bce + object_bce) \
            + self.grad_weight * grad_loss \
            + self.boundary_dist_weight * boundary_dist_loss
        return total_loss

    def dice_loss(self, probs, targets, weights, epsilon=1e-6):
        # Weighted soft Dice: 1 - (2*sum(w*p*t) + eps) / (sum(w*p) + sum(w*t) + eps)
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        weights = weights.reshape(-1)
        intersection = torch.sum(weights * probs * targets)
        denominator = torch.sum(weights * probs) + torch.sum(weights * targets)
        dice_score = (2. * intersection + epsilon) / (denominator + epsilon)
        return 1. - dice_score

    def focal_loss(self, probs, targets, weights, gamma=2.0, alpha=0.25):
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        weights = weights.reshape(-1)
        bce = F.binary_cross_entropy(probs, targets, reduction='none')
        # p_t is the probability assigned to the true class
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_term = (1 - p_t) ** gamma
        loss = alpha * focal_term * bce * weights
        return loss.mean()

    def gradient_loss(self, pred, target):
        pred = pred.unsqueeze(1)
        target = target.unsqueeze(1)
        # Move the Sobel filter to the same device as the input
        self.sobel_filter = self.sobel_filter.to(pred.device)
        pred_grad = self.sobel_filter(pred)
        target_grad = self.sobel_filter(target)
        return F.l1_loss(pred_grad, target_grad)
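

# --- Usage sketch (illustrative, not from the original file) ---
# Minimal smoke test under assumed shapes: probabilities in (0, 1) of shape
# (B, 2, H, W) with channel 0 = boundary and channel 1 = object, binary
# targets of the same shape, and a placeholder signed-distance map `sdm`
# of shape (B, H, W) standing in for a real one.
if __name__ == "__main__":
    torch.manual_seed(0)
    criterion = WeightedDiceBCELoss(use_focal=True, use_gradient_loss=True, use_dist_loss=True)
    probs = torch.rand(2, 2, 64, 64)                     # sigmoid assumed applied upstream
    targets = (torch.rand(2, 2, 64, 64) > 0.5).float()   # random binary masks
    sdm = torch.randn(2, 64, 64)                         # placeholder signed-distance map
    loss = criterion(probs, targets, sdm)
    print(f"total loss: {loss.item():.4f}")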