import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from losses import ramps


class consistency_weight(object):
    """
    ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup']
    """
    def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
        self.final_w = final_w
        self.iters_per_epoch = iters_per_epoch
        self.rampup_starts = rampup_starts * iters_per_epoch
        self.rampup_ends = rampup_ends * iters_per_epoch
        self.rampup_length = (self.rampup_ends - self.rampup_starts)
        self.rampup_func = getattr(ramps, ramp_type)
        self.current_rampup = 0

    def __call__(self, epoch, curr_iter):
        cur_total_iter = self.iters_per_epoch * epoch + curr_iter
        if cur_total_iter < self.rampup_starts:
            return 0
        self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
        return self.final_w * self.current_rampup

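# Usage sketch (illustrative only, not part of the original module). The weight object is
# typically queried once per training iteration to scale the unsupervised loss; the values
# of `final_w` and `iters_per_epoch` below are assumptions, and the exact ramp shape depends
# on the functions defined in losses/ramps.
def _example_consistency_weight_usage():
    cons_w = consistency_weight(final_w=30, iters_per_epoch=100,
                                rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup')
    weight_at_start = cons_w(epoch=0, curr_iter=0)  # close to 0: ramp-up has just begun
    weight_at_end = cons_w(epoch=7, curr_iter=0)    # close to final_w: ramp-up finished
    return weight_at_start, weight_at_end
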
def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
    return F.cross_entropy(input_logits / temperature, target_targets, ignore_index=ignore_index)

# for FocalLoss
def softmax_helper(x):
    # copied from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)

def get_alpha(supervised_loader):
    # get the number of classes
    num_labels = 0
    for image_batch, label_batch in supervised_loader:
        label_batch.data[label_batch.data == 255] = 0  # pixels of ignore class added to background
        l_unique = torch.unique(label_batch.data)
        list_unique = [element.item() for element in l_unique.flatten()]
        num_labels = max(max(list_unique), num_labels)
    num_classes = num_labels + 1

    # count class occurrences
    alpha = [0 for i in range(num_classes)]
    for image_batch, label_batch in supervised_loader:
        label_batch.data[label_batch.data == 255] = 0  # pixels of ignore class added to background
        l_unique = torch.unique(label_batch.data)
        list_unique = [element.item() for element in l_unique.flatten()]
        l_unique_count = torch.stack([(label_batch.data == x_u).sum() for x_u in l_unique])  # e.g. tensor([65920, 36480])
        list_count = [count.item() for count in l_unique_count.flatten()]
        for index in list_unique:
            alpha[index] += list_count[list_unique.index(index)]
    return alpha

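# Usage sketch (illustrative only; `supervised_loader` is an assumed DataLoader yielding
# (image, label) batches). The per-class pixel counts returned by get_alpha are passed to
# FocalLoss as its `alpha` argument, where they are normalised and inverted into class weights.
def _example_get_alpha_usage(supervised_loader):
    alpha = get_alpha(supervised_loader)  # list of per-class pixel counts
    criterion = FocalLoss(apply_nonlin=softmax_helper, ignore_index=255,
                          alpha=alpha, gamma=2, smooth=1e-5)
    return criterion
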
class FocalLoss(nn.Module):
    """
    copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    An implementation of Focal Loss with smooth-label cross entropy support, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param apply_nonlin: (callable, optional) non-linearity applied to the logits, e.g. softmax_helper
    :param ignore_index: (int, optional) label value that is excluded from the loss
    :param alpha: (tensor) 3D or 4D, the scalar factor for this criterion
    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
        putting more focus on hard, misclassified examples
    :param smooth: (float, double) smoothing value used in the cross entropy
    :param balance_index: (int) balanced class index, should be specified when alpha is a float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """
    def __init__(self, apply_nonlin=None, ignore_index=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5,
                 size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.ignore_index = ignore_index  # forward() relies on this attribute being set
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m = d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)

        valid_mask = None
        if self.ignore_index is not None:
            valid_mask = target != self.ignore_index
            target = target * valid_mask

        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
            alpha = 1 / alpha  # inverse of class frequency
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        # map any remaining ignore-index (255) pixels to the background class so that
        # the indices passed to scatter_ stay within [0, num_class)
        idx[idx == 255] = 0
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)
        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
        if valid_mask is not None:
            loss = loss * valid_mask.squeeze()

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

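# Usage sketch (illustrative only; the shapes, class count and the ignore value of 255 are
# assumptions based on how this module is used for semantic segmentation).
def _example_focal_loss_usage():
    logits = torch.randn(2, 21, 8, 8, requires_grad=True)  # N x C x H x W logits
    labels = torch.randint(0, 21, (2, 8, 8))               # N x H x W integer labels
    criterion = FocalLoss(apply_nonlin=softmax_helper, ignore_index=255, gamma=2)
    return criterion(logits, labels)
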
class abCE_loss(nn.Module):
    """
    Annealed-Bootstrapped cross-entropy loss
    """
    def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
                 reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'):
        super(abCE_loss, self).__init__()
        self.weight = torch.FloatTensor(weight) if weight is not None else weight
        self.reduction = reduction
        self.thresh = thresh
        self.min_kept = min_kept
        self.ramp_type = ramp_type

        if ramp_type is not None:
            self.rampup_func = getattr(ramps, ramp_type)
            self.iters_per_epoch = iters_per_epoch
            self.num_classes = num_classes
            self.start = 1 / num_classes
            self.end = 0.9
            self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch

    def threshold(self, curr_iter, epoch):
        cur_total_iter = self.iters_per_epoch * epoch + curr_iter
        current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
        return current_rampup * (self.end - self.start) + self.start

    def forward(self, predict, target, ignore_index, curr_iter, epoch):
        batch_kept = self.min_kept * target.size(0)
        prob_out = F.softmax(predict, dim=1)
        tmp_target = target.clone()
        tmp_target[tmp_target == ignore_index] = 0
        prob = prob_out.gather(1, tmp_target.unsqueeze(1))
        mask = target.contiguous().view(-1, ) != ignore_index
        sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()

        if self.ramp_type is not None:
            thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
        else:
            thresh = self.thresh

        min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
        threshold = max(min_threshold, thresh)

        loss_matrix = F.cross_entropy(predict, target,
                                      weight=self.weight.to(predict.device) if self.weight is not None else None,
                                      ignore_index=ignore_index, reduction='none')
        loss_matrix = loss_matrix.contiguous().view(-1, )
        sort_loss_matrix = loss_matrix[mask][sort_indices]
        select_loss_matrix = sort_loss_matrix[sort_prob < threshold]
        if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
            return select_loss_matrix.sum()
        elif self.reduction == 'mean':
            return select_loss_matrix.mean()
        else:
            raise NotImplementedError('Reduction Error!')

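# Usage sketch (illustrative only; the class count, schedule lengths and ignore value are
# assumptions). The confidence threshold is annealed from 1/num_classes towards 0.9 over
# training, so progressively fewer (harder) pixels contribute to the loss.
def _example_abce_loss_usage():
    criterion = abCE_loss(iters_per_epoch=100, epochs=80, num_classes=21,
                          thresh=0.7, min_kept=1, ramp_type='log_rampup')
    predict = torch.randn(2, 21, 8, 8)
    target = torch.randint(0, 21, (2, 8, 8))
    return criterion(predict, target, ignore_index=255, curr_iter=10, epoch=0)
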
def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
    assert inputs.requires_grad == True and targets.requires_grad == False
    assert inputs.size() == targets.size()  # (batch_size * num_classes * H * W)
    inputs = F.softmax(inputs, dim=1)
    if use_softmax:
        targets = F.softmax(targets, dim=1)

    if conf_mask:
        loss_mat = F.mse_loss(inputs, targets, reduction='none')
        mask = (targets.max(1)[0] > threshold)
        loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
        if loss_mat.shape.numel() == 0:
            loss_mat = torch.tensor([0.]).to(inputs.device)
        return loss_mat.mean()
    else:
        return F.mse_loss(inputs, targets, reduction='mean')  # take the mean over the batch_size

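# Usage sketch (illustrative only). In the consistency setting the target predictions come
# from auxiliary/teacher decoders and are detached, so only `inputs` carries gradients; the
# shapes and the confidence threshold below are assumptions.
def _example_softmax_mse_usage():
    student_logits = torch.randn(2, 21, 8, 8, requires_grad=True)
    teacher_logits = torch.randn(2, 21, 8, 8).detach()
    # only pixels whose teacher confidence exceeds the threshold contribute to the loss
    return softmax_mse_loss(student_logits, teacher_logits,
                            conf_mask=True, threshold=0.5, use_softmax=True)
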
def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
    assert inputs.requires_grad == True and targets.requires_grad == False
    assert inputs.size() == targets.size()
    input_log_softmax = F.log_softmax(inputs, dim=1)
    if use_softmax:
        targets = F.softmax(targets, dim=1)

    if conf_mask:
        loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
        mask = (targets.max(1)[0] > threshold)
        loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
        if loss_mat.shape.numel() == 0:
            loss_mat = torch.tensor([0.]).to(inputs.device)
        return loss_mat.sum() / mask.shape.numel()
    else:
        return F.kl_div(input_log_softmax, targets, reduction='mean')

def softmax_js_loss(inputs, targets, **_):
    assert inputs.requires_grad == True and targets.requires_grad == False
    assert inputs.size() == targets.size()
    epsilon = 1e-5

    M = (F.softmax(inputs, dim=1) + targets) * 0.5
    kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
    kl2 = F.kl_div(torch.log(targets + epsilon), M, reduction='mean')
    return (kl1 + kl2) * 0.5

def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
    """
    Pair-wise loss in the sup. mat.
    """
    if isinstance(unsup_outputs, list):
        unsup_outputs = torch.stack(unsup_outputs)

    # Only use a subset of the aux outputs to reduce computation and memory
    unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
    unsup_outputs = unsup_outputs[:nbr_of_pairs]

    temp = torch.zeros_like(unsup_outputs)  # For grad purposes
    for i, u in enumerate(unsup_outputs):
        temp[i] = F.softmax(u, dim=1)
    mean_prediction = temp.mean(0).unsqueeze(0)  # Mean over the auxiliary outputs
    pw_loss = ((temp - mean_prediction) ** 2).mean(0)  # Variance
    pw_loss = pw_loss.sum(1)  # Sum over classes

    if size_average:
        return pw_loss.mean()
    return pw_loss.sum()

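# Usage sketch (illustrative only). `unsup_outputs` is assumed to be a list of logits from
# several auxiliary decoders for the same unlabeled batch; the loss penalises the variance
# of their softmax predictions around the mean prediction.
def _example_pair_wise_loss_usage():
    aux_outputs = [torch.randn(2, 21, 8, 8, requires_grad=True) for _ in range(4)]
    return pair_wise_loss(aux_outputs, size_average=True, nbr_of_pairs=8)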