Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from losses.consistency_loss import * | |
class MultiConLoss(nn.Module):
    """Consistency loss between several predictors' outputs on unlabeled data.

    For every output slot, each predictor's output is compared against the
    mean of all predictors' outputs for that slot. The mean is computed with
    gradient tracking disabled, so it acts as a fixed target. The squared
    difference is weighted by ``exp(-kl)``, where ``kl`` is the value returned
    by ``softmax_kl_loss(prediction, mean)``, and ``kl`` itself is added to
    the term.
    """

    def __init__(self):
        super(MultiConLoss, self).__init__()
        # Not used by forward(); kept because it is a public attribute.
        self.countloss_criterion = nn.MSELoss(reduction='sum')
        # Result of the most recent forward() call.
        self.multiconloss = 0.0
        # Per-slot loss terms recorded by the most recent forward() call.
        self.losses = {}

    def forward(self, unlabeled_results):
        """Return the averaged consistency loss for ``unlabeled_results``.

        Args:
            unlabeled_results: ``None``, or a list with one entry per
                predictor. Each entry is indexed by output slot; the number
                of slots is taken from the first entry, and every other
                entry is assumed to provide at least that many
                (NOTE(review): confirm against the caller).

        Returns:
            ``0.0`` when the input is ``None``, is not a list, is empty, or
            has no output slots. Otherwise the sum of all per-(slot,
            predictor) terms divided by the number of terms.
        """
        self.multiconloss = 0.0
        self.losses = {}

        if not isinstance(unlabeled_results, list) or not unlabeled_results:
            return self.multiconloss

        num_predictors = len(unlabeled_results)
        count = 0
        for i in range(len(unlabeled_results[0])):
            slot_preds = [result[i] for result in unlabeled_results]

            # Average over however many predictors were supplied instead of
            # assuming exactly three. Seeding the sum with the first element
            # keeps the addition order (p0 + p1) + p2 for the usual case.
            with torch.set_grad_enabled(False):
                preds_mean = sum(slot_preds[1:], slot_preds[0]) / num_predictors

            for pred in slot_preds:
                var_sel = softmax_kl_loss(pred, preds_mean)
                exp_var = torch.exp(-var_sel)
                consistency_dist = (preds_mean - pred) ** 2
                temploss = (torch.mean(consistency_dist * exp_var)
                            / (exp_var + 1e-8) + var_sel)

                # The key depends on the slot only, so each predictor
                # overwrites the previous one and the entry ends up holding
                # the last predictor's term for this slot.
                self.losses.update({'unlabel_{}_loss'.format(str(i + 1)): temploss})
                self.multiconloss += temploss
                count += 1

        if count > 0:
            self.multiconloss = self.multiconloss / count

        return self.multiconloss