Spaces:
Running
Running
| import torch.utils.data as data | |
| import torch | |
| from PIL import Image, ImageFilter | |
| import os, cv2 | |
| import numpy as np | |
| import random | |
| from scipy.stats import norm | |
| from math import floor | |
def random_translate(image, target):
    """Randomly translate the image and shift its landmarks to match.

    With probability 0.5 an integer offset between -30 and 29 pixels is drawn
    independently for each axis, the image is translated by it, and the
    normalised landmark coordinates are moved by the same amount and clipped
    to ``[0, 1]``.

    Args:
        image: PIL image (the code uses its ``size`` and ``transform``).
        target: Flat array-like ``[x0, y0, x1, y1, ...]`` of landmark
            coordinates expressed as fractions of the image width / height.

    Returns:
        ``(image, target)``. When no translation is applied both inputs are
        returned unchanged; otherwise a new image and a new flat array are
        returned and the input ``target`` is left unmodified.
    """
    if random.random() <= 0.5:
        return image, target

    # PIL reports ``size`` as (width, height); unpacking it the other way
    # round normalises the x shift by the height on non-square images.
    image_width, image_height = image.size

    shift_x = int((random.random() - 0.5) * 60)  # left/right offset in pixels
    shift_y = int((random.random() - 0.5) * 60)  # up/down offset in pixels

    # Affine coefficients (a, b, c, d, e, f): identity plus an offset (c, f).
    image = image.transform(
        image.size, Image.AFFINE, (1, 0, shift_x, 0, 1, shift_y)
    )

    # np.array() copies, so the caller's landmarks are never mutated.
    shifted = np.array(target).reshape(-1, 2)
    shifted[:, 0] -= 1. * shift_x / image_width
    shifted[:, 1] -= 1. * shift_y / image_height

    # Landmarks pushed outside the frame are pinned to its border.
    return image, np.clip(shifted, 0, 1).flatten()
def random_blur(image):
    """Blur the image with a Gaussian of random radius about 30% of the time.

    Args:
        image: PIL image (the code uses its ``filter`` method).

    Returns:
        The input image itself when the random gate is not passed, otherwise
        the result of ``image.filter`` with a ``GaussianBlur`` whose radius is
        drawn uniformly from ``[0, 5)``.
    """
    if random.random() <= 0.7:
        return image

    radius = random.random() * 5
    return image.filter(ImageFilter.GaussianBlur(radius))
def random_occlusion(image):
    """Paint a random solid-colour rectangle over the image half of the time.

    The rectangle is at most 40% of the image in each dimension and its
    top-left corner is drawn from the area that keeps it at least 10 pixels
    away from the right / bottom edges.

    Args:
        image: PIL image; it is converted to a 3-D ``uint8`` array, so it
            must have a channel axis.

    Returns:
        The input image itself when the random gate is not passed, otherwise
        a new RGB PIL image containing the occluding rectangle.
    """
    if random.random() <= 0.5:
        return image

    # Work on a channel-reversed view; writes go through to the copied array.
    bgr = np.array(image).astype(np.uint8)[:, :, ::-1]
    height, width, _ = bgr.shape

    patch_h = int(height * 0.4 * random.random())
    patch_w = int(width * 0.4 * random.random())
    left = int((width - patch_w - 10) * random.random())
    top = int((height - patch_h - 10) * random.random())

    rows = slice(top, top + patch_h)
    cols = slice(left, left + patch_w)
    # One independent random intensity per channel, drawn in view order.
    for channel in range(3):
        bgr[rows, cols, channel] = int(random.random() * 255)

    # Undo the channel reversal before handing the pixels back to PIL.
    return Image.fromarray(bgr[:, :, ::-1].astype('uint8'), 'RGB')
def random_flip(image, target, points_flip):
    """Mirror the image left-right half of the time and remap the landmarks.

    Args:
        image: PIL image (the code uses its ``transpose`` method).
        target: Flat array-like ``[x0, y0, x1, y1, ...]`` of normalised
            landmark coordinates.
        points_flip: Index used to reorder the landmark rows after the flip,
            so that row ``i`` of the result is original row
            ``points_flip[i]``.

    Returns:
        ``(image, target)``. Both inputs are returned unchanged when no flip
        happens; otherwise the mirrored image and a new flat array whose rows
        are reordered and whose x coordinates are replaced by ``1 - x``.
    """
    if random.random() <= 0.5:
        return image, target

    mirrored = image.transpose(Image.FLIP_LEFT_RIGHT)

    # Reorder the landmark rows first, then reflect x about the vertical axis.
    landmarks = np.array(target).reshape(-1, 2)[points_flip, :]
    landmarks[:, 0] = 1 - landmarks[:, 0]
    return mirrored, landmarks.flatten()
def random_rotate(image, target, angle_max):
    """Rotate image and landmarks by a random angle half of the time.

    The angle is drawn uniformly from ``[-angle_max, angle_max]`` degrees.
    The image is rotated with ``image.rotate(angle)`` and the landmarks are
    rotated about the point ``(0.5, 0.5)`` in normalised coordinates.

    Args:
        image: Object exposing ``rotate(angle_in_degrees)`` (a PIL image).
        target: Flat array-like ``[x0, y0, x1, y1, ...]`` of normalised
            landmark coordinates.
        angle_max: Maximum absolute rotation, in degrees.

    Returns:
        ``(image, target)``. Both inputs are returned unchanged when no
        rotation happens; otherwise the rotated image and a new flat array of
        rotated landmarks (not clipped to ``[0, 1]``).
    """
    if random.random() <= 0.5:
        return image, target

    centre = np.array([0.5, 0.5])
    landmark_num = len(target) // 2
    # NOTE(review): rotating in normalised coordinates matches the pixel-space
    # rotation only when the image is square - confirm against callers.
    offsets = np.array(target).reshape(landmark_num, 2) - centre

    theta_max = np.radians(angle_max)
    theta = random.uniform(-theta_max, theta_max)
    image = image.rotate(np.degrees(theta))

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    # Row vectors times the matrix: (x, y) -> (x*cos + y*sin, -x*sin + y*cos).
    rotated = offsets @ rotation

    return image, (rotated + centre).reshape(-1)
def gen_target_pip(target, meanface_indices, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y):
    """Fill the per-landmark score, offset and neighbour-offset maps in place.

    For every landmark the grid cell containing it (clamped to the map) is
    marked with 1 in ``target_map``, the landmark's offset from that cell's
    integer index is stored in ``target_local_x`` / ``target_local_y``, and
    the offsets of its neighbour landmarks from the same cell are stored in
    ``target_nb_x`` / ``target_nb_y``.

    Args:
        target: Array of normalised landmark coordinates; it is reshaped to
            ``(-1, 2)`` as ``(x, y)`` rows.
        meanface_indices: ``meanface_indices[i][j]`` is the landmark index of
            the ``j``-th neighbour of landmark ``i``; the neighbour count is
            taken from the first entry.
        target_map: Array of shape ``(num_landmarks, height, width)``.
        target_local_x: Array indexed like ``target_map``.
        target_local_y: Array indexed like ``target_map``.
        target_nb_x: Array whose first axis is indexed by
            ``num_nb * landmark + neighbour``.
        target_nb_y: Array indexed like ``target_nb_x``.

    Returns:
        The five map arrays, in the order they were passed in, after being
        modified in place.

    Raises:
        AssertionError: If the number of landmarks in ``target`` differs from
            the first dimension of ``target_map``.
    """
    num_nb = len(meanface_indices[0])
    map_channel, map_height, map_width = target_map.shape
    landmarks = target.reshape(-1, 2)
    assert map_channel == landmarks.shape[0]

    for lm in range(map_channel):
        # Landmark position expressed in grid-cell units.
        grid_x = landmarks[lm][0] * map_width
        grid_y = landmarks[lm][1] * map_height

        # Containing cell, clamped so out-of-range landmarks stay on the map.
        col = min(max(0, int(floor(grid_x))), map_width - 1)
        row = min(max(0, int(floor(grid_y))), map_height - 1)

        target_map[lm, row, col] = 1
        target_local_x[lm, row, col] = grid_x - col
        target_local_y[lm, row, col] = grid_y - row

        for j in range(num_nb):
            neighbour = landmarks[meanface_indices[lm][j]]
            channel = num_nb * lm + j
            target_nb_x[channel, row, col] = neighbour[0] * map_width - col
            target_nb_y[channel, row, col] = neighbour[1] * map_height - row

    return target_map, target_local_x, target_local_y, target_nb_x, target_nb_y
class ImageFolder_pip(data.Dataset):
    """Dataset yielding augmented images with their landmark target maps.

    Each item is loaded from ``root``, passed through random translation,
    occlusion, flip, rotation (up to 30 degrees) and blur, and then converted
    into five target maps of spatial size ``int(input_size / net_stride)``.
    """

    def __init__(self, root, imgs, input_size, num_lms, net_stride, points_flip, meanface_indices, transform=None, target_transform=None):
        """Store the dataset configuration.

        Args:
            root: Directory that image names are joined onto.
            imgs: Indexable collection of ``(img_name, target)`` pairs.
            input_size: Network input size; divided by ``net_stride`` to get
                the side length of the target maps.
            num_lms: Number of landmarks (channels of the score/offset maps).
            net_stride: Divisor applied to ``input_size``.
            points_flip: Landmark reordering handed to ``random_flip``.
            meanface_indices: Per-landmark neighbour indices; the neighbour
                count is taken from the first entry.
            transform: Optional callable applied to the augmented image.
            target_transform: Optional callable applied to each target map.
        """
        # Data location and samples.
        self.root = root
        self.imgs = imgs
        # Geometry of the generated target maps.
        self.input_size = input_size
        self.net_stride = net_stride
        self.num_lms = num_lms
        # Landmark topology.
        self.points_flip = points_flip
        self.meanface_indices = meanface_indices
        self.num_nb = len(meanface_indices[0])
        # Optional post-processing hooks.
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(img, map, local_x, local_y, nb_x, nb_y)`` for one sample.

        The five target entries are float tensors; ``img`` is the augmented
        PIL image, or whatever ``transform`` returns when one is set.
        """
        img_name, target = self.imgs[index]
        img = Image.open(os.path.join(self.root, img_name)).convert('RGB')

        # Augmentations; the geometric ones keep the landmarks in sync.
        img, target = random_translate(img, target)
        img = random_occlusion(img)
        img, target = random_flip(img, target, self.points_flip)
        img, target = random_rotate(img, target, 30)
        img = random_blur(img)

        grid = int(self.input_size / self.net_stride)
        nb_channels = self.num_nb * self.num_lms
        # Score map, x/y offsets, then neighbour x/y offsets.
        channel_counts = (self.num_lms, self.num_lms, self.num_lms, nb_channels, nb_channels)
        blank_maps = [np.zeros((channels, grid, grid)) for channels in channel_counts]

        filled_maps = gen_target_pip(target, self.meanface_indices, *blank_maps)
        tensors = [torch.from_numpy(filled).float() for filled in filled_maps]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            tensors = [self.target_transform(tensor) for tensor in tensors]

        target_map, target_local_x, target_local_y, target_nb_x, target_nb_y = tensors
        return img, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y

    def __len__(self):
        """Return the number of ``(img_name, target)`` samples."""
        return len(self.imgs)
# Running this file as a script does nothing; the guard body is a no-op.
if __name__ == '__main__':
    pass