import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedDiceBCELoss(nn.Module):
    def __init__(self, 
                 weight_1=4.0, 
                 weight_2=333.3, 
                 weight_3=1.0, 
                 bce_weight=1.0, 
                 grad_weight=1.0,
                 use_focal=False,
                 use_gradient_loss=False,
                 use_dist_loss=False,
                 focal_alpha=0.25,
                 focal_gamma=2.0,
                 boundary_dist_weight=1.0,
                 epsilon=1e-6,
    ):
        """
        Args:
            weight_1: Weight for object foreground in Dice loss.
            weight_2: Weight for boundary foreground in Dice loss.
            weight_3: Weight for boundary channel weight.
            bce_weight: Weight for BCE loss.
            epsilon: Small constant to prevent division by zero.
            use_focal: Whether to use Focal Loss for boundary channel.
            focal_gamma: Focal Loss gamma parameter.
            focal_alpha: Focal Loss alpha parameter.
            use_gradient_loss: Whether to use Sobel gradient loss.
            grad_weight: Weight for gradient loss term.
        """
        super(WeightedDiceBCELoss, self).__init__()
        self.weight_object_foreground = weight_1
        self.weight_boundary_foreground = weight_2
        self.weight_boundary_channel = weight_3
        self.bce_weight = bce_weight
        self.epsilon = epsilon

        # Focal Loss params
        self.use_focal = use_focal
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha

        # Gradient loss params
        self.use_gradient_loss = use_gradient_loss
        self.grad_weight = grad_weight

        # Boundary distance weight
        self.use_dist_loss = use_dist_loss
        self.boundary_dist_weight = boundary_dist_weight
        
        # Sobel filter for gradient loss
        if use_gradient_loss:
            sobel_kernel = torch.tensor([
                [[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],  # x
                [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]   # y
            ], dtype=torch.float32)
            self.sobel_filter = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=False)
            with torch.no_grad():
                self.sobel_filter.weight.copy_(sobel_kernel)
            self.sobel_filter.requires_grad_(False)

    def forward(self, logits, targets, sdm_tensor):
        """
        Args:
            logits: Model outputs with sigmoid already applied (probabilities),
                shape (B, 2, H, W); channel 0 is the boundary map, channel 1
                the object map.
            targets: Ground-truth binary masks, shape (B, 2, H, W), same
                channel order as `logits`.
            sdm_tensor: Signed distance map for the boundary distance penalty,
                shape (B, H, W).
        """
        probs = logits  # sigmoid already applied upstream, so these are probabilities

        # Per-pixel weights: upweight foreground pixels in each channel
        weights_obj = torch.where(targets[:,1,:,:] == 1, self.weight_object_foreground, 1.0)
        weights_bnd = torch.where(targets[:,0,:,:] == 1, self.weight_boundary_foreground, 1.0)

        # ----- Dice Loss -----
        boundary_channel_dice = self.dice_loss(
            probs[:,0,:,:], 
            targets[:,0,:,:], 
            weights_bnd, 
            self.epsilon
        )
        object_channel_dice = self.dice_loss(
            probs[:,1,:,:], 
            targets[:,1,:,:], 
            weights_obj, 
            self.epsilon
        )

        # ----- BCE or Focal Loss -----
        if self.use_focal:
            boundary_bce = self.focal_loss(
                probs[:,0,:,:], targets[:,0,:,:], weights_bnd,
                gamma=self.focal_gamma, alpha=self.focal_alpha
            )
        else:
            boundary_bce = F.binary_cross_entropy(probs[:,0,:,:], targets[:,0,:,:], weights_bnd)

        object_bce = F.binary_cross_entropy(probs[:,1,:,:], targets[:,1,:,:], weights_obj)

        # ----- Gradient Loss -----
        grad_loss = 0.0
        if self.use_gradient_loss:
            grad_loss = self.gradient_loss(probs[:,0,:,:], targets[:,0,:,:])

        # ----- Boundary Distance Loss -----
        # Penalize boundary probability mass far from the true boundary, as
        # measured by the precomputed signed distance map.
        boundary_dist_loss = 0.0
        if self.use_dist_loss:
            boundary_dist_loss = torch.mean(probs[:,0,:,:] * sdm_tensor.to(probs.device))

        # ----- Total Loss -----
        total_loss = (self.weight_boundary_channel * boundary_channel_dice + object_channel_dice) + \
                     self.bce_weight * (self.weight_boundary_channel * boundary_bce + object_bce) + \
                     self.grad_weight * grad_loss + \
                     self.boundary_dist_weight * boundary_dist_loss

        return total_loss

    @staticmethod
    def dice_loss(probs, targets, weights=None, epsilon=1e-6):
        # Weighted soft Dice: 1 - (2*sum(w*p*t) + eps) / (sum(w*p) + sum(w*t) + eps)
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        weights = torch.ones_like(probs) if weights is None else weights.reshape(-1)

        intersection = torch.sum(weights * probs * targets)
        denominator = torch.sum(weights * probs) + torch.sum(weights * targets)
        dice_score = (2. * intersection + epsilon) / (denominator + epsilon)
        return 1. - dice_score

    @staticmethod
    def focal_loss(probs, targets, weights=None, gamma=2.0, alpha=0.25):
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)
        weights = torch.ones_like(probs) if weights is None else weights.reshape(-1)

        # Standard focal loss: down-weight easy examples via (1 - p_t)^gamma
        bce = F.binary_cross_entropy(probs, targets, reduction='none')
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_term = (1 - p_t) ** gamma
        loss = alpha * focal_term * bce * weights
        return loss.mean()

    def gradient_loss(self, pred, target):
        # Add a channel dimension for the single-input-channel Sobel convolution
        pred = pred.unsqueeze(1)
        target = target.unsqueeze(1)

        # Move sobel filter to the same device as input
        self.sobel_filter = self.sobel_filter.to(pred.device)
        
        pred_grad = self.sobel_filter(pred)
        target_grad = self.sobel_filter(target)
        return F.l1_loss(pred_grad, target_grad)
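
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original training
# code): it assumes sigmoid-activated 2-channel predictions (boundary, object),
# binary ground-truth masks of the same shape, and a precomputed signed
# distance map. All shapes and values below are dummies for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = WeightedDiceBCELoss(
        use_focal=True,
        use_gradient_loss=True,
        use_dist_loss=True,
    )

    batch, height, width = 2, 64, 64
    probs = torch.sigmoid(torch.randn(batch, 2, height, width))    # model output after sigmoid
    targets = (torch.rand(batch, 2, height, width) > 0.5).float()  # binary GT masks
    sdm = torch.randn(batch, height, width)                        # dummy signed distance map

    loss = criterion(probs, targets, sdm)
    print(f"total loss: {loss.item():.4f}")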