import os
import random
import copy
import time
import sys
import shutil
import argparse
import errno
import math
import numpy as np
from collections import defaultdict, OrderedDict
import tensorboardX
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from lib.utils.tools import *
from lib.model.loss import *
from lib.model.loss_mesh import *
from lib.utils.utils_mesh import *
from lib.utils.utils_smpl import *
from lib.utils.utils_data import *
from lib.utils.learning import *
from lib.data.dataset_mesh import MotionSMPL
from lib.model.model_mesh import MeshRegressor
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-freq', '--print_freq', default=100, type=int, help='print frequency (in batches)')
    parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)')
    parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
    opts = parser.parse_args()
    return opts
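
# Example invocation (illustrative; the config path below is just the argparse default,
# substitute the mesh config and checkpoint directory you actually use):
#   python <this_script>.py --config configs/pretrain.yaml --checkpoint checkpoint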
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
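
# Validation: run the model over a test loader, optionally with test-time flip augmentation,
# accumulate the weighted loss terms, and report MPJPE / PA-MPJPE / MPVE via evaluate_mesh().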
def validate(test_loader, model, criterion, dataset_name='h36m'):
    model.eval()
    print(f'===========> validating {dataset_name}')
    batch_time = AverageMeter()
    losses = AverageMeter()
    losses_dict = {
        'loss_3d_pos': AverageMeter(),
        'loss_3d_scale': AverageMeter(),
        'loss_3d_velocity': AverageMeter(),
        'loss_lv': AverageMeter(),
        'loss_lg': AverageMeter(),
        'loss_a': AverageMeter(),
        'loss_av': AverageMeter(),
        'loss_pose': AverageMeter(),
        'loss_shape': AverageMeter(),
        'loss_norm': AverageMeter(),
    }
    mpjpes = AverageMeter()
    mpves = AverageMeter()
    results = defaultdict(list)
    smpl = SMPL(args.data_root, batch_size=1).cuda()
    J_regressor = smpl.J_regressor_h36m
    with torch.no_grad():
        end = time.time()
        for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)):
            batch_size, clip_len = batch_input.shape[:2]
            if torch.cuda.is_available():
                batch_gt['theta'] = batch_gt['theta'].cuda().float()
                batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
                batch_gt['verts'] = batch_gt['verts'].cuda().float()
                batch_input = batch_input.cuda().float()
            output = model(batch_input)
            output_final = output
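            # Test-time flip augmentation: run the horizontally flipped clip through the model,
            # flip the predicted SMPL thetas back, rebuild vertices and 3D joints with the SMPL
            # layer, and average the result with the un-flipped prediction.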
            if args.flip:
                batch_input_flip = flip_data(batch_input)
                output_flip = model(batch_input_flip)
                output_flip_pose = output_flip[0]['theta'][:, :, :72]
                output_flip_shape = output_flip[0]['theta'][:, :, 72:]
                output_flip_pose = flip_thetas_batch(output_flip_pose)
                output_flip_pose = output_flip_pose.reshape(-1, 72)
                output_flip_shape = output_flip_shape.reshape(-1, 10)
                output_flip_smpl = smpl(
                    betas=output_flip_shape,
                    body_pose=output_flip_pose[:, 3:],
                    global_orient=output_flip_pose[:, :3],
                    pose2rot=True
                )
                output_flip_verts = output_flip_smpl.vertices.detach() * 1000.0
                J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
                output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts)  # (NT, 17, 3)
                output_flip_back = [{
                    'theta': torch.cat((output_flip_pose.reshape(batch_size, clip_len, -1), output_flip_shape.reshape(batch_size, clip_len, -1)), dim=-1),
                    'verts': output_flip_verts.reshape(batch_size, clip_len, -1, 3),
                    'kp_3d': output_flip_kp3d.reshape(batch_size, clip_len, -1, 3),
                }]
                output_final = [{}]
                for k, v in output_flip[0].items():
                    output_final[0][k] = (output[0][k] + output_flip_back[0][k]) * 0.5
                output = output_final
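            # Weighted sum of all loss terms; the lambda weights come from the config.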
            loss_dict = criterion(output, batch_gt)
            loss = args.lambda_3d * loss_dict['loss_3d_pos'] + \
                   args.lambda_scale * loss_dict['loss_3d_scale'] + \
                   args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
                   args.lambda_lv * loss_dict['loss_lv'] + \
                   args.lambda_lg * loss_dict['loss_lg'] + \
                   args.lambda_a * loss_dict['loss_a'] + \
                   args.lambda_av * loss_dict['loss_av'] + \
                   args.lambda_shape * loss_dict['loss_shape'] + \
                   args.lambda_pose * loss_dict['loss_pose'] + \
                   args.lambda_norm * loss_dict['loss_norm']
            # Update metrics.
            losses.update(loss.item(), batch_size)
            loss_str = ''
            for k, v in loss_dict.items():
                losses_dict[k].update(v.item(), batch_size)
                loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
            mpjpe, mpve = compute_error(output, batch_gt)
            mpjpes.update(mpjpe, batch_size)
            mpves.update(mpve, batch_size)
            for keys in output[0].keys():
                output[0][keys] = output[0][keys].detach().cpu().numpy()
                batch_gt[keys] = batch_gt[keys].detach().cpu().numpy()
            results['kp_3d'].append(output[0]['kp_3d'])
            results['verts'].append(output[0]['verts'])
            results['kp_3d_gt'].append(batch_gt['kp_3d'])
            results['verts_gt'].append(batch_gt['verts'])
            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()
            if idx % int(opts.print_freq) == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      '{2}'
                      'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                      'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                          idx, len(test_loader), loss_str, batch_time=batch_time,
                          loss=losses, mpves=mpves, mpjpes=mpjpes))
    print(f'==> start concatenating results of {dataset_name}')
    for term in results.keys():
        results[term] = np.concatenate(results[term])
    print(f'==> start evaluating {dataset_name}...')
    error_dict = evaluate_mesh(results)
    err_str = ''
    for err_key, err_val in error_dict.items():
        err_str += '{}: {:.2f}mm \t'.format(err_key, err_val)
    print(f'=======================> {dataset_name} validation done: ', loss_str)
    print(f'=======================> {dataset_name} validation done: ', err_str)
    return losses.avg, error_dict['mpjpe'], error_dict['pa_mpjpe'], error_dict['mpve'], losses_dict
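
# One training epoch: forward each batch, compute the weighted mesh loss, update the running
# meters, backprop, and log progress every opts.print_freq batches.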
def train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch):
    model.train()
    end = time.time()
    for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)):
        data_time.update(time.time() - end)
        batch_size = len(batch_input)
        if torch.cuda.is_available():
            batch_gt['theta'] = batch_gt['theta'].cuda().float()
            batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float()
            batch_gt['verts'] = batch_gt['verts'].cuda().float()
            batch_input = batch_input.cuda().float()
        output = model(batch_input)
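        # Zero the gradients, then compute the weighted multi-term mesh loss
        # (same lambda weighting as in validation).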
        optimizer.zero_grad()
        loss_dict = criterion(output, batch_gt)
        loss_train = args.lambda_3d * loss_dict['loss_3d_pos'] + \
                     args.lambda_scale * loss_dict['loss_3d_scale'] + \
                     args.lambda_3dv * loss_dict['loss_3d_velocity'] + \
                     args.lambda_lv * loss_dict['loss_lv'] + \
                     args.lambda_lg * loss_dict['loss_lg'] + \
                     args.lambda_a * loss_dict['loss_a'] + \
                     args.lambda_av * loss_dict['loss_av'] + \
                     args.lambda_shape * loss_dict['loss_shape'] + \
                     args.lambda_pose * loss_dict['loss_pose'] + \
                     args.lambda_norm * loss_dict['loss_norm']
        losses_train.update(loss_train.item(), batch_size)
        loss_str = ''
        for k, v in loss_dict.items():
            losses_dict[k].update(v.item(), batch_size)
            loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k])
        mpjpe, mpve = compute_error(output, batch_gt)
        mpjpes.update(mpjpe, batch_size)
        mpves.update(mpve, batch_size)
        loss_train.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if idx % int(opts.print_freq) == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  '{3}'
                  'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t'
                  'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), loss_str, batch_time=batch_time,
                      data_time=data_time, loss=losses_train, mpves=mpves, mpjpes=mpjpes))
            sys.stdout.flush()
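
# Full training driver: build the model and datasets from the config, optionally load a
# pretrained backbone or resume from a checkpoint, then alternate training and validation
# per epoch while writing TensorBoard logs and saving latest/periodic/best checkpoints.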
def train_with_config(args, opts):
    print(args)
    try:
        os.makedirs(opts.checkpoint)
        shutil.copy(opts.config, opts.checkpoint)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise RuntimeError(f'Unable to create checkpoint directory: {opts.checkpoint}')
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs"))
    model_backbone = load_backbone(args)
    if args.finetune:
        if opts.resume or opts.evaluate:
            pass
        else:
            chk_filename = os.path.join(opts.pretrained, opts.selection)
            print('Loading backbone', chk_filename)
            checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos']
            model_backbone = load_pretrained_weights(model_backbone, checkpoint)
    if args.partial_train:
        model_backbone = partial_train_layers(model_backbone, args.partial_train)
    model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout, num_joints=args.num_joints)
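    # MeshRegressor wraps the (optionally pretrained) motion backbone with an SMPL parameter head;
    # dim_rep / hidden_dim / dropout / num_joints all come from the config.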
    criterion = MeshLoss(loss_type=args.loss_type)
    best_jpe = 9999.0
    model_params = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            model_params = model_params + parameter.numel()
    print('INFO: Trainable parameter count:', model_params)
    print('Loading dataset...')
    trainloader_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
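    # Build datasets and loaders for whichever sources the config provides
    # (Human3.6M sequences, 3DPW sequences, and single-image COCO batches).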
    if hasattr(args, "dt_file_h36m"):
        mesh_train = MotionSMPL(args, data_split='train', dataset="h36m")
        mesh_val = MotionSMPL(args, data_split='test', dataset="h36m")
        train_loader = DataLoader(mesh_train, **trainloader_params)
        test_loader = DataLoader(mesh_val, **testloader_params)
        print('INFO: Training on {} batches (h36m)'.format(len(train_loader)))
    if hasattr(args, "dt_file_pw3d"):
        if args.train_pw3d:
            mesh_train_pw3d = MotionSMPL(args, data_split='train', dataset="pw3d")
            train_loader_pw3d = DataLoader(mesh_train_pw3d, **trainloader_params)
            print('INFO: Training on {} batches (pw3d)'.format(len(train_loader_pw3d)))
        mesh_val_pw3d = MotionSMPL(args, data_split='test', dataset="pw3d")
        test_loader_pw3d = DataLoader(mesh_val_pw3d, **testloader_params)
    trainloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    testloader_img_params = {
        'batch_size': args.batch_size_img,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': True,
        'prefetch_factor': 4,
        'persistent_workers': True
    }
    if hasattr(args, "dt_file_coco"):
        mesh_train_coco = MotionSMPL(args, data_split='train', dataset="coco")
        mesh_val_coco = MotionSMPL(args, data_split='test', dataset="coco")
        train_loader_coco = DataLoader(mesh_train_coco, **trainloader_img_params)
        test_loader_coco = DataLoader(mesh_val_coco, **testloader_img_params)
        print('INFO: Training on {} batches (coco)'.format(len(train_loader_coco)))
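    # Move the model to GPU (DataParallel across available devices) and auto-resume from
    # latest_epoch.bin if it already exists in the checkpoint directory.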
    if torch.cuda.is_available():
        model = nn.DataParallel(model)
        model = model.cuda()
    chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin")
    if os.path.exists(chk_filename):
        opts.resume = chk_filename
    if opts.resume or opts.evaluate:
        chk_filename = opts.evaluate if opts.evaluate else opts.resume
        print('Loading checkpoint', chk_filename)
        checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=True)
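    # Training mode: AdamW with separate learning rates for the backbone and the regression
    # head, decayed once per epoch by StepLR with factor args.lr_decay.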
    if not opts.evaluate:
        optimizer = optim.AdamW(
            [{"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone},
             {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}],
            lr=args.lr_backbone,
            weight_decay=args.weight_decay
        )
        scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay)
        st = 0
        if opts.resume:
            st = checkpoint['epoch']
            if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
            lr = checkpoint['lr']
            if 'best_jpe' in checkpoint and checkpoint['best_jpe'] is not None:
                best_jpe = checkpoint['best_jpe']
        # Training
        for epoch in range(st, args.epochs):
            print('Training epoch %d.' % epoch)
            losses_train = AverageMeter()
            losses_dict = {
                'loss_3d_pos': AverageMeter(),
                'loss_3d_scale': AverageMeter(),
                'loss_3d_velocity': AverageMeter(),
                'loss_lv': AverageMeter(),
                'loss_lg': AverageMeter(),
                'loss_a': AverageMeter(),
                'loss_av': AverageMeter(),
                'loss_pose': AverageMeter(),
                'loss_shape': AverageMeter(),
                'loss_norm': AverageMeter(),
            }
            mpjpes = AverageMeter()
            mpves = AverageMeter()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            if hasattr(args, "dt_file_h36m") and epoch < args.warmup_h36m:
                train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, test_losses_dict = validate(test_loader, model, criterion, 'h36m')
                for k, v in test_losses_dict.items():
                    train_writer.add_scalar('test_loss/' + k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss', test_loss, epoch + 1)
                train_writer.add_scalar('test_mpjpe', test_mpjpe, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe', test_pa_mpjpe, epoch + 1)
                train_writer.add_scalar('test_mpve', test_mpve, epoch + 1)
            if hasattr(args, "dt_file_coco") and epoch < args.warmup_coco:
                train_epoch(args, opts, model, train_loader_coco, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
            if hasattr(args, "dt_file_pw3d"):
                if args.train_pw3d:
                    train_epoch(args, opts, model, train_loader_pw3d, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch)
                test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, test_losses_dict_pw3d = validate(test_loader_pw3d, model, criterion, 'pw3d')
                for k, v in test_losses_dict_pw3d.items():
                    train_writer.add_scalar('test_loss_pw3d/' + k, v.avg, epoch + 1)
                train_writer.add_scalar('test_loss_pw3d', test_loss_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpjpe_pw3d', test_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_pa_mpjpe_pw3d', test_pa_mpjpe_pw3d, epoch + 1)
                train_writer.add_scalar('test_mpve_pw3d', test_mpve_pw3d, epoch + 1)
            for k, v in losses_dict.items():
                train_writer.add_scalar('train_loss/' + k, v.avg, epoch + 1)
            train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1)
            train_writer.add_scalar('train_mpjpe', mpjpes.avg, epoch + 1)
            train_writer.add_scalar('train_mpve', mpves.avg, epoch + 1)
            # Decay learning rate exponentially.
            scheduler.step()
            # Save latest checkpoint.
            chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin')
            print('Saving checkpoint to', chk_path)
            torch.save({
                'epoch': epoch + 1,
                'lr': scheduler.get_last_lr(),
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'best_jpe': best_jpe
            }, chk_path)
            # Save periodic checkpoint if necessary.
            if (epoch + 1) % args.checkpoint_frequency == 0:
                chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch))
                print('Saving checkpoint to', chk_path)
                torch.save({
                    'epoch': epoch + 1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe': best_jpe
                }, chk_path)
            # Track the epoch's JPE on 3DPW if available, otherwise on Human3.6M.
            if hasattr(args, "dt_file_pw3d"):
                best_jpe_cur = test_mpjpe_pw3d
            else:
                best_jpe_cur = test_mpjpe
            # Save best checkpoint.
            best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin')
            if best_jpe_cur < best_jpe:
                best_jpe = best_jpe_cur
                print("save best checkpoint")
                torch.save({
                    'epoch': epoch + 1,
                    'lr': scheduler.get_last_lr(),
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'best_jpe': best_jpe
                }, best_chk_path)
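    # Evaluation-only mode: skip training and just run validation on the available test sets.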
    if opts.evaluate:
        if hasattr(args, "dt_file_h36m"):
            test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, _ = validate(test_loader, model, criterion, 'h36m')
        if hasattr(args, "dt_file_pw3d"):
            test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, _ = validate(test_loader_pw3d, model, criterion, 'pw3d')

if __name__ == "__main__":
    opts = parse_args()
    set_random_seed(opts.seed)
    args = get_config(opts.config)
    train_with_config(args, opts)