import numpy as np import torch import torch.distributions as tdist # from utils import fn_timer def random_ccm(camera_type='IMX686'): """Generates random RGB -> Camera color correction matrices.""" # Takes a random convex combination of XYZ -> Camera CCMs. xyz2cams = [[[1.0234, -0.2969, -0.2266], [-0.5625, 1.6328, -0.0469], [-0.0703, 0.2188, 0.6406]], [[0.4913, -0.0541, -0.0202], [-0.613, 1.3513, 0.2906], [-0.1564, 0.2151, 0.7183]], [[0.838, -0.263, -0.0639], [-0.2887, 1.0725, 0.2496], [-0.0627, 0.1427, 0.5438]], [[0.6596, -0.2079, -0.0562], [-0.4782, 1.3016, 0.1933], [-0.097, 0.1581, 0.5181]]] num_ccms = len(xyz2cams) xyz2cams = torch.FloatTensor(xyz2cams) weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8) weights_sum = torch.sum(weights, dim=0) xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375], [0.2126729, 0.7151522, 0.0721750], [0.0193339, 0.1191920, 0.9503041]]) rgb2cam = torch.mm(xyz2cam, rgb2xyz) # if camera_type == 'SonyA7S2': # # SonyA7S2 ccm's inv # rgb2cam = [[1.,0.,0.], # [0.,1.,0.], # [0.,0.,1.]] # elif camera_type == 'IMX686': # # RedMi K30 ccm's inv # rgb2cam = [[0.61093086,0.31565922,0.07340994], # [0.09433191,0.7658969,0.1397712 ], # [0.03532438,0.3020709,0.6626047 ]] # rgb2cam = torch.FloatTensor(rgb2cam) # Normalizes each row. rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True) return rgb2cam def random_gains(): """Generates random gains for brightening and white balance.""" # RGB gain represents brightening. n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1])) rgb_gain = 1.0 / n.sample() if torch.rand(1) < 0.9 else 5 / n.sample() # Red and blue gains represent white balance. red_gain = torch.FloatTensor(1).uniform_(1.4, 2.5)#(1.9, 2.4) blue_gain = torch.FloatTensor(1).uniform_(1.5, 2.4)#(1.5, 1.9) return rgb_gain, red_gain, blue_gain # def random_gains(camera_type='SonyA7S2'): # # return torch.FloatTensor(np.array([[1.],[1.],[1.]])) # n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1])) # rgb_gain = 1.0 / n.sample() # # SonyA7S2 # if camera_type == 'SonyA7S2': # red_gain = np.random.uniform(1.75, 2.65) # ployfit = [14.65 ,-9.63942308, 1.80288462 ] # blue_gain= ployfit[0] + ployfit[1] * red_gain + ployfit[2] * red_gain ** 2# + np.random.uniform(0, 0.4)686 # elif camera_type == 'IMX686': # red_gain = np.random.uniform(1.4, 2.3) # ployfit = [6.14381188, -3.65620261, 0.70205967] # blue_gain= ployfit[0] + ployfit[1] * red_gain + ployfit[2] * red_gain ** 2# + np.random.uniform(0, 0.4) # else: # raise NotImplementedError # red_gain = torch.FloatTensor(np.array([red_gain])).view(1) # blue_gain = torch.FloatTensor(np.array([blue_gain])).view(1) # return rgb_gain, red_gain, blue_gain def inverse_smoothstep(image): """Approximately inverts a global tone mapping curve.""" #image = image.permute(1, 2, 0) # Permute the image tensor to HxWxC format from CxHxW format image = torch.clamp(image, min=0.0, max=1.0) out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) #out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format return out def gamma_expansion(image): """Converts from gamma to linear space.""" # Clamps to prevent numerical instability of gradients near zero. #image = image.permute(1, 2, 0) # Permute the image tensor to HxWxC format from CxHxW format out = torch.clamp(image, min=1e-8) ** 2.2 #out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format return out def apply_ccm(image, ccm): """Applies a color correction matrix.""" shape = image.size() image = torch.reshape(image, [-1, 3]) image = torch.tensordot(image, ccm, dims=[[-1], [-1]]) out = torch.reshape(image, shape) return out def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, use_gpu=False): """Inverts gains while safely handling saturated pixels.""" # H, W, C green = torch.tensor([1.0]) if use_gpu: green = green.cuda() gains = torch.stack((1.0 / red_gain, green, 1.0 / blue_gain)) / rgb_gain new_shape = (1,) * (len(image.shape) - 1) + (-1,) gains = gains.view(new_shape) #gains = gains[None, None, :] # Prevents dimming of saturated pixels by smoothly masking gains near white. gray = torch.mean(image, dim=-1, keepdim=True) inflection = 0.9 mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0 safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) out = image * safe_gains return out def mosaic(image): """Extracts RGGB Bayer planes from an RGB image.""" if image.size() == 3: # image = image.permute(1, 2, 0) # Permute the image tensor to HxWxC format from CxHxW format shape = image.size() red = image[0::2, 0::2, 0] green_red = image[0::2, 1::2, 1] green_blue = image[1::2, 0::2, 1] blue = image[1::2, 1::2, 2] out = torch.stack((red, green_red, green_blue, blue), dim=-1) # out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4)) # out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format else: # [crops, t, h, w, c] shape = image.size() red = image[..., 0::2, 0::2, 0] green_red = image[..., 0::2, 1::2, 1] green_blue = image[..., 1::2, 0::2, 1] blue = image[..., 1::2, 1::2, 2] out = torch.stack((red, green_red, green_blue, blue), dim=-1) # out = torch.reshape(out, (shape[0], shape[1], shape[-3] // 2, shape[-2] // 2, 4)) # out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format return out # def mosaic(image, mode=0): # """Extracts Random Bayer planes from an RGB image.""" # if mode == 0: # RGGB # R, Gr, Gb, B = (0,0), (0,1), (1,0), (0,0) # elif mode == 1: # GRBG # Gr, R, B, Gb = (0,0), (0,1), (1,0), (0,0) # elif mode == 2: # GBRG # Gb, B, R, Gr = (0,0), (0,1), (1,0), (0,0) # elif mode == 3: # BGGR # B, Gb, Gr, R = (0,0), (0,1), (1,0), (0,0) # shape = image.size() # red = image[..., R[0]::2, R[1]::2, 0] # green_red = image[..., Gr[0]::2, Gr[1]::2, 1] # green_blue = image[..., Gb[0]::2, Gb[1]::2, 1] # blue = image[..., B[0]::2, B[1]::2, 2] # out = torch.stack((red, green_red, green_blue, blue), dim=-1) # # out = torch.reshape(out, (shape[0], shape[1], shape[-3] // 2, shape[-2] // 2, 4)) # # out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format # return out # @ fn_timer def unprocess(image, lock_wb=False, use_gpu=False, camera_type='IMX686', seed=None): """Unprocesses an image from sRGB to realistic raw data.""" # Randomly creates image metadata. rgb2cam = random_ccm() cam2rgb = torch.inverse(rgb2cam) # rgb_gain, red_gain, blue_gain = random_gains() if lock_wb is False else torch.FloatTensor(np.array([[1.],[2.],[2.]])) rgb_gain, red_gain, blue_gain = random_gains() if lock_wb is False else torch.FloatTensor(np.array(lock_wb)) if use_gpu: rgb_gain, red_gain, blue_gain = rgb_gain.cuda(), red_gain.cuda(), blue_gain.cuda() if len(image.size()) >= 4: res = image.clone() for i in range(image.size()[0]): temp = image[i] temp = inverse_smoothstep(temp) temp = gamma_expansion(temp) temp = apply_ccm(temp, rgb2cam) temp = safe_invert_gains(temp, rgb_gain, red_gain, blue_gain, use_gpu) temp = torch.clamp(temp, min=0.0, max=1.0) res[i]= temp.clone() metadata = { 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain, 'blue_gain': blue_gain, } return res, metadata else: # Approximately inverts global tone mapping. image = inverse_smoothstep(image) # Inverts gamma compression. image = gamma_expansion(image) # Inverts color correction. image = apply_ccm(image, rgb2cam) # Approximately inverts white balance and brightening. image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, use_gpu) # Clips saturated pixels. image = torch.clamp(image, min=0.0, max=1.0) # Applies a Bayer mosaic. #image = mosaic(image) metadata = { 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain, 'blue_gain': blue_gain, } return image, metadata def unprocess_rpdc(image, lock_wb=False, use_gpu=False, camera_type='IMX686', known=None): """Unprocesses an image from sRGB to realistic raw data.""" # Randomly creates image metadata. if known is not None: cam2rgb = known['cam2rgb'] rgb2cam = known['rgb2cam'] rgb_gain = known['rgb_gain'] red_gain = known['red_gain'] blue_gain = known['blue_gain'] else: rgb2cam = random_ccm() cam2rgb = torch.inverse(rgb2cam) rgb_gain, red_gain, blue_gain = random_gains() if lock_wb is False else torch.FloatTensor(np.array(lock_wb)) if use_gpu: rgb_gain, red_gain, blue_gain = rgb_gain.cuda(), red_gain.cuda(), blue_gain.cuda() res = image.clone() for i in range(image.size()[0]): temp = image[i] temp = inverse_smoothstep(temp) temp = gamma_expansion(temp) temp = apply_ccm(temp, rgb2cam) temp = safe_invert_gains(temp, rgb_gain, red_gain, blue_gain, use_gpu) temp = torch.clamp(temp, min=0.0, max=1.0) res[i]= temp.clone() metadata = { 'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain, 'blue_gain': blue_gain, } return res, metadata def random_noise_levels(): """Generates random noise levels from a log-log linear distribution.""" log_min_shot_noise = np.log(0.0001) log_max_shot_noise = np.log(0.012) log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise) shot_noise = torch.exp(log_shot_noise) line = lambda x: 2.18 * x + 1.20 n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26])) log_read_noise = line(log_shot_noise) + n.sample() read_noise = torch.exp(log_read_noise) return shot_noise, read_noise def add_noise(image, shot_noise=0.01, read_noise=0.0005): """Adds random shot (proportional to image) and read (independent) noise.""" image = image.permute(1, 2, 0) # Permute the image tensor to HxWxC format from CxHxW format variance = image * shot_noise + read_noise n = tdist.Normal(loc=torch.zeros_like(variance), scale=torch.sqrt(variance)) noise = n.sample() out = image + noise out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format return out if __name__ == '__main__': m = tdist.Poisson(torch.tensor([10.,100.,1000.])) for i in range(10): s = m.sample() print(s.numpy())