import copy

import torch
from torch import nn

__all__ = ['build_optimizer']
def param_groups_weight_decay(model: nn.Module,
                              weight_decay=1e-5,
                              no_weight_decay_list=()):
    """Split model parameters into two groups: one that receives weight
    decay and one that does not (biases, 1-D tensors such as norm scales,
    and any parameter whose name matches ``no_weight_decay_list``)."""
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (norm weights, biases) and explicitly listed names
        # are excluded from weight decay.
        if param.ndim <= 1 or name.endswith('.bias') or any(
                nd in name for nd in no_weight_decay_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': decay, 'weight_decay': weight_decay},
    ]
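# A minimal illustration (assumes nothing beyond torch itself): for a
# single nn.Linear, the 1-D bias lands in the zero-decay group and the
# 2-D weight matrix in the decayed group:
#
#   groups = param_groups_weight_decay(nn.Linear(4, 2), weight_decay=0.05)
#   groups[0]['weight_decay']  # -> 0.0   (bias)
#   groups[1]['weight_decay']  # -> 0.05  (weight)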
def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
                    model):
    """Build a ``torch.optim`` optimizer and a learning-rate scheduler
    from config dicts. ``model`` may be an ``nn.Module`` or an iterable
    of parameters / parameter groups."""
    from . import lr
    config = copy.deepcopy(optim_config)
    if isinstance(model, nn.Module):
        # A model was passed in: extract its parameters and apply weight
        # decay only to the appropriate layers.
        weight_decay = config.get('weight_decay', 0.0)
        filter_bias_and_bn = config.pop('filter_bias_and_bn', False)
        if weight_decay > 0.0 and filter_bias_and_bn:
            no_weight_decay = {}
            if hasattr(model, 'no_weight_decay'):
                no_weight_decay = model.no_weight_decay()
            parameters = param_groups_weight_decay(model, weight_decay,
                                                   no_weight_decay)
            # Decay is now set per parameter group, so zero out the
            # optimizer-level default.
            config['weight_decay'] = 0.0
        else:
            parameters = model.parameters()
    else:
        # An iterable of parameters or parameter groups was passed in.
        parameters = model

    optim = getattr(torch.optim, config.pop('name'))(params=parameters,
                                                     **config)

    lr_config = copy.deepcopy(lr_scheduler_config)
    lr_config.update({
        'epochs': epochs,
        'step_each_epoch': step_each_epoch,
        'lr': config['lr'],
    })
    # The classes in the local `lr` module are factories: instantiating one
    # with the config returns a callable that, given the optimizer, builds
    # the actual scheduler.
    lr_scheduler = getattr(lr, lr_config.pop('name'))(**lr_config)(
        optimizer=optim)
    return optim, lr_scheduler
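# Usage sketch. 'AdamW' and its keys are real torch.optim names; the
# scheduler name 'CosineLR' and the 'warmup_epoch' key are hypothetical
# stand-ins for whatever the local `lr` module actually exposes:
#
#   optim_cfg = {'name': 'AdamW', 'lr': 1e-3, 'weight_decay': 0.05,
#                'filter_bias_and_bn': True}
#   sched_cfg = {'name': 'CosineLR', 'warmup_epoch': 2}  # hypothetical keys
#   optimizer, scheduler = build_optimizer(optim_cfg, sched_cfg,
#                                          epochs=100, step_each_epoch=500,
#                                          model=my_model)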