"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import time
import json
import datetime
from pathlib import Path
import torch
import torch.nn as nn
from ..misc import dist_utils
from ._solver import BaseSolver
from .clas_engine import train_one_epoch, evaluate
class ClasSolver(BaseSolver):

    def fit(self):
        print("Start training")
        self.train()
        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('Number of params:', n_parameters)

        output_dir = Path(args.output_dir)
        output_dir.mkdir(exist_ok=True)

        start_time = time.time()
        start_epoch = self.last_epoch + 1
        # 'epoches' (sic) matches the spelling used by the RT-DETR config files
        for epoch in range(start_epoch, args.epoches):
            # reshuffle the distributed sampler so each epoch sees a different ordering
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(self.model,
                                          self.criterion,
                                          self.train_dataloader,
                                          self.optimizer,
                                          self.ema,
                                          epoch=epoch,
                                          device=self.device)
            self.lr_scheduler.step()
            self.last_epoch += 1

            if output_dir:
                # always refresh the rolling checkpoint, plus a numbered
                # snapshot every `checkpoint_freq` epochs
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(epoch), checkpoint_path)

            # evaluate the EMA weights when available, otherwise the raw model
            module = self.ema.module if self.ema else self.model
            test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if output_dir and dist_utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
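
    # A one-shot validation entry point (a minimal sketch, NOT part of the
    # original file): it mirrors the per-epoch evaluation step in fit(),
    # preferring the EMA weights when available so results match what fit()
    # logs. `self.eval()` is assumed to be BaseSolver's counterpart to the
    # `self.train()` setup call used in fit().
    def val(self):
        self.eval()
        module = self.ema.module if self.ema else self.model
        test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device)
        if dist_utils.is_main_process():
            print(json.dumps(test_stats))
        return test_stats


# Usage sketch (hypothetical; the YAMLConfig import and config path below are
# assumptions modeled on RT-DETR's tools/train.py -- this module itself is
# only importable as part of the package):
#
#   from src.core import YAMLConfig
#   cfg = YAMLConfig('configs/rtdetr/rtdetr_r50vd_clas.yml', output_dir='./output')
#   solver = ClasSolver(cfg)
#   solver.fit()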