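# BiC ("Bias Correction") learner for class-incremental learning, following
# Wu et al., "Large Scale Incremental Learning" (CVPR 2019): a two-stage
# procedure that first trains the network with knowledge distillation and then
# corrects the classifier's bias toward newly seen classes.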
import logging
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNetWithBias
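
# Training hyperparameters. T is the softmax temperature for knowledge
# distillation; split_ratio is the fraction of the per-class memory budget held
# out as a class-balanced validation set for the stage-2 bias correction.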
epochs = 170
lrate = 0.1
milestones = [60, 100, 140]
lrate_decay = 0.1
batch_size = 128
split_ratio = 0.1
T = 2
weight_decay = 2e-4
num_workers = 8
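

# Two-stage learner: stage 1 trains backbone and classifier on new data plus
# stored exemplars (with distillation against the frozen old network); stage 2
# tunes only a small bias layer for the newest classes on a held-out validation
# split. In BiC this layer is an affine correction (scale * logit + shift) on
# the new-class logits; the module itself lives in IncrementalNetWithBias.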
class BiC(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNetWithBias(
            args, False, bias_correction=True
        )
        self._class_means = None
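
    # After a task finishes, keep a frozen copy of the current network; it acts
    # as the distillation teacher for the next task.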
    def after_task(self):
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))
    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task >= 1:
            train_dset, val_dset = data_manager.get_dataset_with_split(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
                val_samples_per_class=int(
                    split_ratio * self._memory_size / self._known_classes
                ),
            )
            self.val_loader = DataLoader(
                val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
            )
            logging.info(
                "Stage1 dset: {}, Stage2 dset: {}".format(
                    len(train_dset), len(val_dset)
                )
            )
            self.lamda = self._known_classes / self._total_classes
            logging.info("Lambda: {:.3f}".format(self.lamda))
        else:
            train_dset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
            )
        test_dset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )

        self.train_loader = DataLoader(
            train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self._log_bias_params()
        self._stage1_training(self.train_loader, self.test_loader)
        if self._cur_task >= 1:
            self._stage2_bias_correction(self.val_loader, self.test_loader)

        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
        self._log_bias_params()
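
    # Shared training loop for both stages. In the "training" stage the loss is
    # lamda * distillation + (1 - lamda) * cross-entropy; in "bias_correction"
    # only the parameters handed to the optimizer (the newest bias layer, see
    # _stage2_bias_correction) are updated. Note that in that stage the logits
    # are passed through softmax before F.cross_entropy, which itself applies
    # log-softmax internally.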
    def _run(self, train_loader, test_loader, optimizer, scheduler, stage):
        for epoch in range(1, epochs + 1):
            self._network.train()
            losses = 0.0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)

                logits = self._network(inputs)["logits"]
                if stage == "training":
                    clf_loss = F.cross_entropy(logits, targets)
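                    # Knowledge distillation on the old-class logits: the frozen
                    # old network provides soft targets at temperature T, and the
                    # loss is the mean over the batch of -sum_k p_hat_k * log p_k.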
                    if self._old_network is not None:
                        old_logits = self._old_network(inputs)["logits"].detach()
                        hat_pai_k = F.softmax(old_logits / T, dim=1)
                        log_pai_k = F.log_softmax(
                            logits[:, : self._known_classes] / T, dim=1
                        )
                        distill_loss = -torch.mean(
                            torch.sum(hat_pai_k * log_pai_k, dim=1)
                        )
                        loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda)
                    else:
                        loss = clf_loss
                elif stage == "bias_correction":
                    loss = F.cross_entropy(torch.softmax(logits, dim=1), targets)
                else:
                    raise NotImplementedError()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

            scheduler.step()
            train_acc = self._compute_accuracy(self._network, train_loader)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format(
                stage,
                self._cur_task,
                epoch,
                epochs,
                losses / len(train_loader),
                train_acc,
                test_acc,
            )
            logging.info(info)
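
    # Stage 1: train every parameter except the bias layers (their learning rate
    # and weight decay are set to 0 in a separate parameter group) with SGD plus
    # momentum and a multi-step learning-rate schedule.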
    def _stage1_training(self, train_loader, test_loader):
        """
        if self._cur_task == 0:
            loaded_dict = torch.load('./dict_0.pkl')
            self._network.load_state_dict(loaded_dict['model_state_dict'])
            self._network.to(self._device)
            return
        """
        ignored_params = list(map(id, self._network.bias_layers.parameters()))
        base_params = filter(
            lambda p: id(p) not in ignored_params, self._network.parameters()
        )
        network_params = [
            {"params": base_params, "lr": lrate, "weight_decay": weight_decay},
            {
                "params": self._network.bias_layers.parameters(),
                "lr": 0,
                "weight_decay": 0,
            },
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        self._run(train_loader, test_loader, optimizer, scheduler, stage="training")
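
    # Stage 2: optimize only the newest task's bias layer on the held-out
    # validation split; no other parameters are passed to the optimizer, so SGD
    # leaves the rest of the network's weights unchanged.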
    def _stage2_bias_correction(self, val_loader, test_loader):
        if isinstance(self._network, nn.DataParallel):
            self._network = self._network.module
        network_params = [
            {
                "params": self._network.bias_layers[-1].parameters(),
                "lr": lrate,
                "weight_decay": weight_decay,
            }
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)

        self._run(
            val_loader, test_loader, optimizer, scheduler, stage="bias_correction"
        )
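
    # Log the two parameters of each bias layer (the per-task scale and shift in
    # the BiC formulation).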
    def _log_bias_params(self):
        logging.info("Parameters of bias layer:")
        params = self._network.get_bias_params()
        for i, param in enumerate(params):
            logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1]))