Spaces:
Sleeping
Sleeping
File size: 5,763 Bytes
ae522ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
class WeightedDiceBCELoss(nn.Module):
    """Combined weighted Dice + BCE/Focal loss for two-channel segmentation.

    Expects probabilities (sigmoid already applied) of shape (B, 2, H, W),
    where channel 0 is the boundary mask and channel 1 the object mask.
    Optional extras, both applied to the boundary channel only: a Sobel
    gradient L1 term and a signed-distance-map (SDM) penalty.
    """

    def __init__(self,
                 weight_1=4.0,
                 weight_2=333.3,
                 weight_3=1.0,
                 bce_weight=1.0,
                 grad_weight=1.0,
                 use_focal=False,
                 use_gradient_loss=False,
                 use_dist_loss=False,
                 focal_alpha=0.25,
                 focal_gamma=2.0,
                 boundary_dist_weight=1.0,
                 epsilon=1e-6,
                 ):
        """
        Args:
            weight_1: Per-pixel weight for object foreground in Dice/BCE.
            weight_2: Per-pixel weight for boundary foreground in Dice/BCE.
            weight_3: Multiplier on the boundary channel's Dice and BCE terms.
            bce_weight: Weight of the BCE/Focal part of the total loss.
            grad_weight: Weight of the Sobel gradient loss term.
            use_focal: Use Focal Loss instead of BCE for the boundary channel.
            use_gradient_loss: Enable the Sobel gradient loss term.
            use_dist_loss: Enable the SDM distance penalty term.
            focal_alpha: Focal Loss alpha parameter.
            focal_gamma: Focal Loss gamma parameter.
            boundary_dist_weight: Weight of the SDM distance term.
            epsilon: Small constant to prevent division by zero in Dice.
        """
        super().__init__()
        self.weight_object_foreground = weight_1
        self.weight_boundary_foreground = weight_2
        self.weight_boundary_channel = weight_3
        self.bce_weight = bce_weight
        self.epsilon = epsilon
        # Focal Loss params
        self.use_focal = use_focal
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha
        # Gradient loss params
        self.use_gradient_loss = use_gradient_loss
        self.grad_weight = grad_weight
        # Boundary distance (SDM) params
        self.use_dist_loss = use_dist_loss
        self.boundary_dist_weight = boundary_dist_weight
        # Fixed (non-trainable) Sobel filters; built only when needed.
        if use_gradient_loss:
            sobel_kernel = torch.tensor([
                [[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],  # d/dx
                [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]],  # d/dy
            ], dtype=torch.float32)
            self.sobel_filter = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=False)
            self.sobel_filter.weight.data = sobel_kernel
            self.sobel_filter.requires_grad_(False)

    def forward(self, logits, targets, sdm_tensor=None):
        """Compute the total loss.

        Args:
            logits: Model outputs with sigmoid already applied, shape (B, 2, H, W).
            targets: Ground-truth binary masks, shape (B, 2, H, W).
            sdm_tensor: Signed-distance map for the boundary channel. Only
                required when ``use_dist_loss=True``; previously a mandatory
                positional argument, now optional (backward compatible).

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``use_dist_loss`` is enabled but ``sdm_tensor`` is None.
        """
        probs = logits  # sigmoid assumed applied upstream, per the original note
        # Per-pixel weights: boost the foreground pixels of each channel.
        weights_obj = torch.where(targets[:, 1, :, :] == 1, self.weight_object_foreground, 1.0)
        weights_bnd = torch.where(targets[:, 0, :, :] == 1, self.weight_boundary_foreground, 1.0)
        # ----- Dice Loss -----
        boundary_channel_dice = self.dice_loss(
            probs[:, 0, :, :],
            targets[:, 0, :, :],
            weights_bnd,
            self.epsilon,
        )
        object_channel_dice = self.dice_loss(
            probs[:, 1, :, :],
            targets[:, 1, :, :],
            weights_obj,
            self.epsilon,
        )
        # ----- BCE or Focal Loss -----
        if self.use_focal:
            boundary_bce = self.focal_loss(
                probs[:, 0, :, :], targets[:, 0, :, :], weights_bnd,
                gamma=self.focal_gamma, alpha=self.focal_alpha,
            )
        else:
            boundary_bce = F.binary_cross_entropy(probs[:, 0, :, :], targets[:, 0, :, :], weights_bnd)
        object_bce = F.binary_cross_entropy(probs[:, 1, :, :], targets[:, 1, :, :], weights_obj)
        # ----- Gradient Loss (optional) -----
        grad_loss = 0.0
        if self.use_gradient_loss:
            grad_loss = self.gradient_loss(probs[:, 0, :, :], targets[:, 0, :, :])
        # ----- Boundary Distance Loss (optional) -----
        boundary_dist_loss = 0.0
        if self.use_dist_loss:
            if sdm_tensor is None:
                raise ValueError("sdm_tensor is required when use_dist_loss=True")
            boundary_dist_loss = torch.mean(probs[:, 0, :, :] * sdm_tensor.to(probs.device))
        # ----- Total Loss: weighted sum of all enabled terms -----
        total_loss = (self.weight_boundary_channel * boundary_channel_dice + object_channel_dice) + \
            self.bce_weight * (self.weight_boundary_channel * boundary_bce + object_bce) + \
            self.grad_weight * grad_loss + \
            self.boundary_dist_weight * boundary_dist_loss
        return total_loss

    @staticmethod
    def dice_loss(probs, targets, weights=1.0, epsilon=1e-6):
        """Weighted soft-Dice loss; ``weights`` may be a tensor or a scalar."""
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        # Accept python-scalar weights: the default used to crash on .reshape.
        weights = torch.as_tensor(weights, dtype=probs.dtype, device=probs.device)
        if weights.dim() > 0:
            weights = weights.reshape(-1)
        intersection = torch.sum(weights * probs * targets)
        denominator = torch.sum(weights * probs) + torch.sum(weights * targets)
        dice_score = (2. * intersection + epsilon) / (denominator + epsilon)
        return 1. - dice_score

    @staticmethod
    def focal_loss(probs, targets, weights=1.0, gamma=2.0, alpha=0.25):
        """Weighted focal loss on probabilities (not raw logits)."""
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        # Accept python-scalar weights: the default used to crash on .reshape.
        weights = torch.as_tensor(weights, dtype=probs.dtype, device=probs.device)
        if weights.dim() > 0:
            weights = weights.reshape(-1)
        bce = F.binary_cross_entropy(probs, targets, reduction='none')
        p_t = probs * targets + (1 - probs) * (1 - targets)  # prob of the true class
        focal_term = (1 - p_t) ** gamma
        loss = alpha * focal_term * bce * weights
        return loss.mean()

    def gradient_loss(self, pred, target):
        """L1 distance between Sobel gradients of prediction and target."""
        pred = pred.unsqueeze(1)      # (B, H, W) -> (B, 1, H, W)
        target = target.unsqueeze(1)
        # Keep the (non-trainable) filter on the same device as the input.
        self.sobel_filter = self.sobel_filter.to(pred.device)
        pred_grad = self.sobel_filter(pred)
        target_grad = self.sobel_filter(target)
        return F.l1_loss(pred_grad, target_grad)
|