import cv2
import sys

sys.path.insert(0, "FaceBoxesV2")
sys.path.insert(0, "..")
from math import floor

from PIL import Image  # needed for Image.fromarray() in demo_image()
from faceboxes_detector import *
import torch
import torch.nn.parallel
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.models as models
from networks import *
from functions import *
from PIPNet.reverse_index import ri1, ri2
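# Note (not in the original file): the sys.path.insert() calls above assume
# this script sits next to a FaceBoxesV2/ checkout and one level below the
# PIPNet repo root, so that faceboxes_detector, networks, and functions
# resolve. Adjust the paths if your layout differs.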
class Config:
    def __init__(self):
        self.det_head = "pip"
        self.net_stride = 32
        self.batch_size = 16
        self.init_lr = 0.0001
        self.num_epochs = 60
        self.decay_steps = [30, 50]
        self.input_size = 256
        self.backbone = "resnet101"
        self.pretrained = True
        self.criterion_cls = "l2"
        self.criterion_reg = "l1"
        self.cls_loss_weight = 10
        self.reg_loss_weight = 1
        self.num_lms = 98
        self.save_interval = self.num_epochs
        self.num_nb = 10
        self.use_gpu = True
        self.gpu_id = 3
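# Note (not in the original file): these values appear to mirror the snapshot
# name loaded below (pip_32_16_60_r101_l2_l1_10_1_nb10): net_stride 32,
# batch_size 16, 60 epochs, ResNet-101 backbone, L2 classification / L1
# regression losses with weights 10/1, and 10 neighbor landmarks. Keep them
# in sync with whichever checkpoint you load.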
def get_lmk_model():
    cfg = Config()
    resnet101 = models.resnet101(pretrained=cfg.pretrained)
    net = Pip_resnet101(
        resnet101,
        cfg.num_nb,
        num_lms=cfg.num_lms,
        input_size=cfg.input_size,
        net_stride=cfg.net_stride,
    )
    if cfg.use_gpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    net = net.to(device)

    # Hardcoded cluster path; point this at your local copy of the WFLW
    # PIPNet snapshot.
    weight_file = "/apdcephfs/share_1290939/ahbanliang/codes/PIPNet/snapshots/WFLW/pip_32_16_60_r101_l2_l1_10_1_nb10/epoch59.pth"
    state_dict = torch.load(weight_file, map_location=device)
    net.load_state_dict(state_dict)

    # Reuse the device resolved above instead of hardcoding "cuda:0", so the
    # detector also falls back to CPU when CUDA is unavailable.
    detector = FaceBoxesDetector(
        "FaceBoxes",
        "FaceBoxesV2/weights/FaceBoxesV2.pth",
        use_gpu=(device.type == "cuda"),
        device=str(device),
    )
    return net, detector
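# Hedged helper (not in the original file): PIPNet snapshots trained under
# nn.DataParallel carry a "module." prefix on every state_dict key. If
# net.load_state_dict() above raises missing/unexpected-key errors, stripping
# the prefix like this usually resolves it.
def strip_module_prefix(state_dict):
    return {
        k[len("module."):] if k.startswith("module.") else k: v
        for k, v in state_dict.items()
    }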
def demo_image(
    image_file,
    net,
    detector,
    input_size=256,
    net_stride=32,
    num_nb=10,
    use_gpu=True,
    device="cuda:0",
):
    my_thresh = 0.6
    det_box_scale = 1.2
    net.eval()
    # The Resize below is redundant when input_size == 256, since each crop
    # is already resized to (input_size, input_size) before preprocessing.
    preprocess = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    reverse_index1, reverse_index2, max_len = ri1, ri2, 17
    image = cv2.imread(image_file)
    image_height, image_width, _ = image.shape
    detections, _ = detector.detect(image, my_thresh, 1)
    for i in range(len(detections)):
        det_xmin = detections[i][2]
        det_ymin = detections[i][3]
        det_width = detections[i][4]
        det_height = detections[i][5]
        det_xmax = det_xmin + det_width - 1
        det_ymax = det_ymin + det_height - 1
        # Expand the detection box on the sides and bottom by det_box_scale.
        det_xmin -= int(det_width * (det_box_scale - 1) / 2)
        # remove a part of top area for alignment, see paper for details
        det_ymin += int(det_height * (det_box_scale - 1) / 2)
        det_xmax += int(det_width * (det_box_scale - 1) / 2)
        det_ymax += int(det_height * (det_box_scale - 1) / 2)
        det_xmin = max(det_xmin, 0)
        det_ymin = max(det_ymin, 0)
        det_xmax = min(det_xmax, image_width - 1)
        det_ymax = min(det_ymax, image_height - 1)
        det_width = det_xmax - det_xmin + 1
        det_height = det_ymax - det_ymin + 1
        # Crop before drawing the box so the red border does not bleed into
        # the network input.
        det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
        cv2.rectangle(image, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
        det_crop = cv2.resize(det_crop, (input_size, input_size))
        inputs = Image.fromarray(det_crop[:, :, ::-1].astype("uint8"), "RGB")  # BGR -> RGB
        inputs = preprocess(inputs).unsqueeze(0)
        inputs = inputs.to(device)
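        # Note (not in the original file): forward_pip comes from PIPNet's
        # functions module; it runs the classification, offset, and neighbor
        # heads and, as used below, returns landmark coordinates normalized
        # to [0, 1] relative to the crop.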
        (
            lms_pred_x,
            lms_pred_y,
            lms_pred_nb_x,
            lms_pred_nb_y,
            outputs_cls,
            max_cls,
        ) = forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb)
        lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()
        tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(98, max_len)
        tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(98, max_len)
        tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
        tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)
        lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
        lms_pred = lms_pred.cpu().numpy()
        lms_pred_merge = lms_pred_merge.cpu().numpy()
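        # Note (not in the original file): PIPNet predicts each of the 98
        # WFLW landmarks twice, once directly and once via each of its num_nb
        # neighboring landmarks. The reverse-index gather above collects up
        # to max_len neighbor estimates per landmark; averaging them with the
        # direct prediction (lms_pred_merge) is what stabilizes the output.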
        # Use a separate loop variable so the outer detection index i is not
        # shadowed.
        for j in range(98):
            x_pred = lms_pred_merge[j * 2] * det_width
            y_pred = lms_pred_merge[j * 2 + 1] * det_height
            cv2.circle(
                image,
                (int(x_pred) + det_xmin, int(y_pred) + det_ymin),
                1,
                (0, 0, 255),
                2,
            )
    # Write once, after all detections have been drawn.
    cv2.imwrite("images/1_out.jpg", image)
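# Hedged sketch (not in the original file): if you need landmark coordinates
# rather than a rendered image, the merge step above already has them; a
# variant of demo_image() could collect a (98, 2) array per face instead of
# drawing circles:
#
#     pts = lms_pred_merge.reshape(98, 2)            # normalized to the crop
#     pts[:, 0] = pts[:, 0] * det_width + det_xmin   # back to image x
#     pts[:, 1] = pts[:, 1] * det_height + det_ymin  # back to image y
#     all_faces.append(pts)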
if __name__ == "__main__":
    net, detector = get_lmk_model()
    demo_image(
        "/apdcephfs/private_ahbanliang/codes/Real-ESRGAN-master/tmp_frames/yanikefu/frame00000046.png",
        net,
        detector,
    )