Spaces:
Configuration error
Configuration error
| import os | |
| import numpy as np | |
| import time | |
| import sys | |
| import argparse | |
| import errno | |
| from collections import OrderedDict | |
| import tensorboardX | |
| from tqdm import tqdm | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torch.optim.lr_scheduler import StepLR | |
| from torch.utils.data import DataLoader | |
| from lib.utils.tools import * | |
| from lib.utils.learning import * | |
| from lib.model.loss import * | |
| from lib.data.dataset_action import NTURGBD, NTURGBD1Shot | |
| from lib.model.model_action import ActionNet | |
| from lib.model.loss_supcon import SupConLoss | |
| from pytorch_metric_learning import samplers | |
# Seed PyTorch's, NumPy's and Python's global RNGs with the same fixed value
# so that repeated runs draw the same random numbers.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
def parse_args():
    """Parse the command-line options of this training / evaluation script.

    Returns:
        argparse.Namespace: with the attributes ``config``, ``checkpoint``,
        ``pretrained``, ``resume``, ``evaluate``, ``print_freq`` (int) and
        ``selection``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    # type=int matters: without it a value passed on the command line stays a
    # string and the `(idx + 1) % print_freq` check in the training loop fails.
    parser.add_argument('-freq', '--print_freq', default=100, type=int, metavar='N', help='print training statistics every N batches')
    parser.add_argument('-ms', '--selection', default='best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
    return parser.parse_args()
def extract_feats(dataloader_x, model):
    """Run ``model`` over every batch of ``dataloader_x`` without gradients.

    Args:
        dataloader_x: iterable yielding ``(batch_input, batch_gt)`` pairs.
            The original author's note gives the input as (N, 2, T, 17, 3);
            that shape is not checked here.
        model: callable applied to each (possibly CUDA-resident) input batch.

    Returns:
        Tuple ``(features, labels)``: the per-batch model outputs and the
        per-batch ground-truth tensors, each concatenated along dim 0.
        Inputs are moved to the GPU when CUDA is available; labels are
        returned on whatever device the loader produced them.
    """
    on_gpu = torch.cuda.is_available()
    feat_chunks, gt_chunks = [], []
    with torch.no_grad():
        for _, (clips, labels) in tqdm(enumerate(dataloader_x)):
            clips = clips.cuda() if on_gpu else clips
            feat_chunks.append(model(clips))
            gt_chunks.append(labels)
    return torch.cat(feat_chunks), torch.cat(gt_chunks)
def validate(anchor_loader, test_loader, model):
    """Return nearest-anchor top-1 accuracy of the test set.

    Features are extracted for both loaders with ``extract_feats``. Every test
    sample is assigned the label of the anchor sample whose feature has the
    highest cosine similarity to it, and the fraction of correct assignments
    is returned.

    Args:
        anchor_loader: loader of the labelled anchor (reference) samples.
        test_loader: loader of the samples to classify.
        model: the feature extractor; it is switched to eval mode for the
            duration of the call and its previous train/eval mode is restored
            afterwards.

    Returns:
        0-dim tensor holding the accuracy in [0, 1].
    """
    # The model is built with a dropout ratio, so evaluation must run in eval
    # mode; otherwise the features (and the accuracy) are randomly perturbed.
    was_training = model.training
    model.eval()
    try:
        train_feats, train_labels = extract_feats(anchor_loader, model)
        test_feats, test_labels = extract_feats(test_loader, model)
    finally:
        model.train(was_training)

    # (M, 1, ...) against (1, N, ...) broadcasts to an (M, N) similarity matrix.
    dis = F.cosine_similarity(train_feats.unsqueeze(1), test_feats.unsqueeze(0), dim=-1)
    # Features may live on the GPU while labels stay where the loader put
    # them, so bring the indices to the labels' device before indexing.
    nearest_anchor = torch.argmax(dis, dim=0).to(train_labels.device)
    pred = train_labels[nearest_anchor]
    assert len(pred) == len(test_labels)
    acc = (pred == test_labels).sum() / len(pred)
    return acc
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from its keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the prefix is really there; blindly dropping the
        # first characters would corrupt keys saved without DataParallel.
        cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
    return cleaned


def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_acc):
    """Write the state needed to resume training after ``epoch`` finished epochs."""
    torch.save({
        'epoch': epoch,
        'lr': scheduler.get_last_lr(),
        'optimizer': optimizer.state_dict(),
        'model': model.state_dict(),
        'best_acc': best_acc,
    }, path)


def train_with_config(args, opts):
    """Train (or only evaluate) the one-shot action model described by ``args``.

    The function builds the backbone and ``ActionNet`` head, optionally loads a
    pretrained backbone or a full checkpoint, trains with ``SupConLoss`` on the
    auxiliary set, validates after every epoch with ``validate`` and writes
    ``latest_epoch.bin`` / ``best_epoch.bin`` plus TensorBoard logs into
    ``opts.checkpoint``.

    Args:
        args: configuration object; the attributes read here are ``finetune``,
            ``partial_train``, ``dim_rep``, ``dropout_ratio``,
            ``model_version``, ``hidden_dim``, ``num_joints``, ``temp``,
            ``batch_size``, ``clip_len``, ``scale_range_test``,
            ``scale_range_train``, ``random_move``, ``n_views``,
            ``lr_backbone``, ``lr_head``, ``weight_decay``, ``lr_decay`` and
            ``epochs``.
        opts: command-line options as returned by ``parse_args``
            (``checkpoint``, ``pretrained``, ``resume``, ``evaluate``,
            ``print_freq``, ``selection``). ``opts.resume`` is overwritten
            when ``latest_epoch.bin`` already exists in ``opts.checkpoint``.

    Raises:
        RuntimeError: if the checkpoint directory cannot be created.
    """
    print(args)
    try:
        os.makedirs(opts.checkpoint, exist_ok=True)
    except OSError as e:
        raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) from e
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
    use_cuda = torch.cuda.is_available()
    # Tolerate a string value (e.g. from a parser without type=int).
    print_freq = int(opts.print_freq)

    # Detect an interrupted run first, so that the pretrained backbone is not
    # required (nor loaded in vain) when its weights would be overwritten by
    # the resumed checkpoint a few lines below anyway.
    latest_path = os.path.join(opts.checkpoint, "latest_epoch.bin")
    if os.path.exists(latest_path):
        opts.resume = latest_path

    model_backbone = load_backbone(args)
    if args.finetune and not (opts.resume or opts.evaluate):
        # Honour -ms/--selection; its default is the previously hard-coded name.
        chk_filename = os.path.join(opts.pretrained, getattr(opts, 'selection', 'best_epoch.bin'))
        print('Loading backbone', chk_filename)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model_backbone.load_state_dict(_strip_module_prefix(checkpoint['model_pos']), strict=True)
    if args.partial_train:
        # Presumably restricts which backbone layers are trainable (see
        # lib.utils.learning); the optimizer below only takes requires_grad params.
        model_backbone = partial_train_layers(model_backbone, args.partial_train)
    model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints)
    criterion = SupConLoss(temperature=args.temp)
    if use_cuda:
        model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()

    checkpoint = None
    if opts.resume or opts.evaluate:
        chk_filename = opts.evaluate if opts.evaluate else opts.resume
        print('Loading checkpoint', chk_filename)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=True)

    best_acc = 0
    # Count only parameters that will actually be updated, as the message says.
    model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('INFO: Trainable parameter count:', model_params)
    print('Loading dataset...')

    # Anchor and test loaders share the same deterministic settings.
    eval_loader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True,
    }
    data_path_1shot = 'data/action/ntu120_hrnet_oneshot.pkl'
    anchor_set = NTURGBD(data_path=data_path_1shot, data_split='oneshot_train', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
    test_set = NTURGBD(data_path=data_path_1shot, data_split='oneshot_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test)
    anchor_loader = DataLoader(anchor_set, **eval_loader_params)
    test_loader = DataLoader(test_set, **eval_loader_params)

    if not opts.evaluate:
        # Load training data (auxiliary set)
        data_path = 'data/action/ntu120_hrnet.pkl'
        train_set = NTURGBD1Shot(data_path=data_path, data_split='', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train, check_split=False)
        sampler = samplers.MPerClassSampler(train_set.labels, m=args.n_views, batch_size=args.batch_size, length_before_new_iter=len(train_set))
        train_loader = DataLoader(train_set, **dict(eval_loader_params, sampler=sampler))

        # The model is only wrapped in DataParallel when CUDA is available.
        net = model.module if isinstance(model, nn.DataParallel) else model
        optimizer = optim.AdamW(
            [
                {"params": [p for p in net.backbone.parameters() if p.requires_grad], "lr": args.lr_backbone},
                {"params": [p for p in net.head.parameters() if p.requires_grad], "lr": args.lr_head},
            ],
            lr=args.lr_backbone,
            weight_decay=args.weight_decay,
        )
        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
        start_epoch = 0
        print('INFO: Training on {} batches'.format(len(train_loader)))
        if opts.resume:
            start_epoch = checkpoint['epoch']
            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
            if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None:
                # Older checkpoints may store a 0-dim tensor here.
                best_acc = float(checkpoint['best_acc'])

        latest_chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
        best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin')

        # Training
        for epoch in range(start_epoch, args.epochs):
            print('Training epoch %d.' % epoch)
            losses_train = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            model.train()
            end = time.time()
            for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
                data_time.update(time.time() - end)
                batch_size = len(batch_input)
                if use_cuda:
                    batch_gt = batch_gt.cuda()
                    batch_input = batch_input.cuda()
                feat = model(batch_input)
                feat = feat.reshape(batch_size, -1, args.hidden_dim)
                optimizer.zero_grad()
                loss_train = criterion(feat, batch_gt)
                losses_train.update(loss_train.item(), batch_size)
                loss_train.backward()
                optimizer.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (idx + 1) % print_freq == 0:
                    print('Train: [{0}][{1}/{2}]\t'
                          'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
                              epoch, idx + 1, len(train_loader), batch_time=batch_time,
                              data_time=data_time, loss=losses_train))
                    sys.stdout.flush()

            # Evaluate with dropout disabled; model.train() is re-applied at
            # the top of the next epoch.
            model.eval()
            test_top1 = float(validate(anchor_loader, test_loader, model))
            train_writer.add_scalar('train_loss_supcon', losses_train.avg, epoch + 1)
            train_writer.add_scalar('test_top1', test_top1, epoch + 1)
            scheduler.step()

            # Update best_acc BEFORE writing the latest checkpoint so that a
            # resumed run knows about this epoch's result and cannot replace
            # best_epoch.bin with a worse model.
            is_best = test_top1 > best_acc
            if is_best:
                best_acc = test_top1

            # Save latest checkpoint.
            print('Saving checkpoint to', latest_chk_path)
            _save_checkpoint(latest_chk_path, epoch + 1, model, optimizer, scheduler, best_acc)

            # Save best checkpoint
            if is_best:
                print("save best checkpoint")
                _save_checkpoint(best_chk_path, epoch + 1, model, optimizer, scheduler, best_acc)

    if opts.evaluate:
        model.eval()
        test_top1 = validate(anchor_loader, test_loader, model)
        print(test_top1)

    # Flush pending TensorBoard events to disk.
    train_writer.close()
if __name__ == "__main__":
    # Script entry point: the CLI options name the config file, and the
    # parsed config together with those options drives training/evaluation.
    opts = parse_args()
    train_with_config(get_config(opts.config), opts)