# YOND / data_process / unprocess.py
# Committed by hansen97 in 0e07d71 ("Initial clean commit").
import numpy as np
import torch
import torch.distributions as tdist
# from utils import fn_timer
def random_ccm(camera_type='IMX686'):
    """Sample a random, row-normalised RGB -> camera colour-correction matrix.

    A random convex combination of four fixed XYZ -> camera matrices is
    composed with the RGB -> XYZ matrix (standard sRGB/D65 coefficients).
    Each row of the product is then rescaled so that it sums to one.

    Args:
        camera_type: Currently ignored; kept for API compatibility.

    Returns:
        A 3x3 float32 tensor.
    """
    # Reference XYZ -> camera matrices that are blended together.
    candidates = torch.FloatTensor([
        [[1.0234, -0.2969, -0.2266],
         [-0.5625, 1.6328, -0.0469],
         [-0.0703, 0.2188, 0.6406]],
        [[0.4913, -0.0541, -0.0202],
         [-0.613, 1.3513, 0.2906],
         [-0.1564, 0.2151, 0.7183]],
        [[0.838, -0.263, -0.0639],
         [-0.2887, 1.0725, 0.2496],
         [-0.0627, 0.1427, 0.5438]],
        [[0.6596, -0.2079, -0.0562],
         [-0.4782, 1.3016, 0.1933],
         [-0.097, 0.1581, 0.5181]],
    ])
    # One positive weight per candidate, broadcast over its 3x3 entries;
    # dividing by the weight total makes the blend convex.
    blend = torch.FloatTensor(len(candidates), 1, 1).uniform_(1e-8, 1e8)
    xyz_to_cam = (candidates * blend).sum(dim=0) / blend.sum(dim=0)

    rgb_to_xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375],
                                    [0.2126729, 0.7151522, 0.0721750],
                                    [0.0193339, 0.1191920, 0.9503041]])
    rgb_to_cam = xyz_to_cam.mm(rgb_to_xyz)
    # Earlier experiments replaced rgb_to_cam with fixed per-camera matrices
    # (identity for 'SonyA7S2'; an inverse RedMi K30 CCM for 'IMX686'):
    #   [[0.61093086, 0.31565922, 0.07340994],
    #    [0.09433191, 0.7658969, 0.1397712],
    #    [0.03532438, 0.3020709, 0.6626047]]
    return rgb_to_cam / rgb_to_cam.sum(dim=-1, keepdim=True)
def random_gains():
    """Sample random brightness and white-balance gains.

    Returns:
        (rgb_gain, red_gain, blue_gain), each a 1-element float tensor.

        - rgb_gain is ``1 / x`` with probability 0.9, otherwise ``5 / x``
          (a much stronger brightening), where ``x ~ N(0.8, 0.1)``.
        - red_gain is drawn from U(1.4, 2.5).
        - blue_gain is drawn from U(1.5, 2.4).
    """
    brightness = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1]))
    # The coin flip is drawn before the normal sample; keep this order so
    # seeded runs reproduce the same gains.
    numerator = 1.0 if torch.rand(1) < 0.9 else 5
    rgb_gain = numerator / brightness.sample()

    # Red and blue gains model the white balance (previous ranges were
    # (1.9, 2.4) and (1.5, 1.9) respectively).
    red_gain = torch.FloatTensor(1).uniform_(1.4, 2.5)
    blue_gain = torch.FloatTensor(1).uniform_(1.5, 2.4)
    return rgb_gain, red_gain, blue_gain
# def random_gains(camera_type='SonyA7S2'):
# # return torch.FloatTensor(np.array([[1.],[1.],[1.]]))
# n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1]))
# rgb_gain = 1.0 / n.sample()
# # SonyA7S2
# if camera_type == 'SonyA7S2':
# red_gain = np.random.uniform(1.75, 2.65)
# ployfit = [14.65 ,-9.63942308, 1.80288462 ]
# blue_gain= ployfit[0] + ployfit[1] * red_gain + ployfit[2] * red_gain ** 2# + np.random.uniform(0, 0.4)686
# elif camera_type == 'IMX686':
# red_gain = np.random.uniform(1.4, 2.3)
# ployfit = [6.14381188, -3.65620261, 0.70205967]
# blue_gain= ployfit[0] + ployfit[1] * red_gain + ployfit[2] * red_gain ** 2# + np.random.uniform(0, 0.4)
# else:
# raise NotImplementedError
# red_gain = torch.FloatTensor(np.array([red_gain])).view(1)
# blue_gain = torch.FloatTensor(np.array([blue_gain])).view(1)
# return rgb_gain, red_gain, blue_gain
def inverse_smoothstep(image):
    """Approximately undo a smoothstep global tone curve.

    Inputs are clamped to [0, 1]. The element-wise result
    ``0.5 - sin(asin(1 - 2x) / 3)`` is the inverse of ``3y^2 - 2y^3`` on
    that interval. Works on tensors of any shape.
    """
    x = image.clamp(0.0, 1.0)
    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * x) / 3.0)
def gamma_expansion(image):
    """Map gamma-encoded values to linear space with a 2.2 power law.

    Values below 1e-8 (including negatives) are raised to 1e-8 first. This
    keeps the gradient of ``x ** 2.2`` well behaved near zero. Works
    element-wise on any shape.
    """
    return image.clamp(min=1e-8).pow(2.2)
def apply_ccm(image, ccm):
    """Apply a 3x3 colour-correction matrix to every RGB triplet.

    The last axis of ``image`` must have size 3. Each pixel vector ``p``
    becomes ``ccm @ p``, and the output has the same shape as ``image``.
    """
    original_shape = image.size()
    pixels = image.reshape(-1, 3)
    # Contract the pixel channel axis with the CCM's column axis (p @ ccm.T).
    transformed = torch.tensordot(pixels, ccm, dims=([-1], [-1]))
    return transformed.reshape(original_shape)
def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, use_gpu=False):
    """Invert white-balance and brightness gains without darkening highlights.

    Args:
        image: Tensor whose last axis holds the R, G, B channels
            (e.g. HxWx3; leading axes are broadcast over).
        rgb_gain: Overall brightness gain; every channel is divided by it.
        red_gain: White-balance gain for the red channel (inverted here).
        blue_gain: White-balance gain for the blue channel (inverted here).
        use_gpu: Kept for backward compatibility. The unit green gain is
            now created alongside ``red_gain`` (same device, dtype and
            shape), so this flag is no longer needed.

    Returns:
        ``image`` multiplied by the inverted gains. Pixels whose channel
        mean exceeds 0.9 are smoothly blended toward a gain of 1, so
        saturated pixels are not dimmed.
    """
    # Build the unit green gain from red_gain so torch.stack sees matching
    # device/dtype/shape. A hard-coded CPU tensor of shape (1,) broke 0-d
    # gains and gains living on a different device than `use_gpu` implied.
    green = torch.ones_like(red_gain)
    gains = torch.stack((1.0 / red_gain, green, 1.0 / blue_gain)) / rgb_gain
    # Reshape to (1, ..., 1, 3) so the gains broadcast over the channel axis.
    gains = gains.reshape((1,) * (image.dim() - 1) + (-1,))

    # Mask ramps quadratically from 0 (gray <= 0.9) to 1 (gray == 1).
    gray = torch.mean(image, dim=-1, keepdim=True)
    inflection = 0.9
    mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0
    # Near white, the effective gain is pulled toward 1 but never below the
    # plain inverted gain.
    safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
    return image * safe_gains
def mosaic(image):
    """Extract the four RGGB Bayer planes from an RGB image.

    The last three axes must be (H, W, 3). Any leading axes, e.g. a
    [crops, t, H, W, 3] batch, are preserved. H and W are expected to be
    even; otherwise the even/odd slices differ in size and ``torch.stack``
    raises.

    Returns:
        Tensor of shape (..., H/2, W/2, 4), with planes ordered
        (R, G-on-red-row, G-on-blue-row, B).
    """
    # A single ellipsis-indexed path covers both single images and batches.
    # The previous `image.size() == 3` check compared a torch.Size to an int,
    # so it was never true and its branch was dead code.
    red = image[..., 0::2, 0::2, 0]
    green_red = image[..., 0::2, 1::2, 1]
    green_blue = image[..., 1::2, 0::2, 1]
    blue = image[..., 1::2, 1::2, 2]
    return torch.stack((red, green_red, green_blue, blue), dim=-1)
# def mosaic(image, mode=0):
# """Extracts Random Bayer planes from an RGB image."""
# if mode == 0: # RGGB
# R, Gr, Gb, B = (0,0), (0,1), (1,0), (0,0)
# elif mode == 1: # GRBG
# Gr, R, B, Gb = (0,0), (0,1), (1,0), (0,0)
# elif mode == 2: # GBRG
# Gb, B, R, Gr = (0,0), (0,1), (1,0), (0,0)
# elif mode == 3: # BGGR
# B, Gb, Gr, R = (0,0), (0,1), (1,0), (0,0)
# shape = image.size()
# red = image[..., R[0]::2, R[1]::2, 0]
# green_red = image[..., Gr[0]::2, Gr[1]::2, 1]
# green_blue = image[..., Gb[0]::2, Gb[1]::2, 1]
# blue = image[..., B[0]::2, B[1]::2, 2]
# out = torch.stack((red, green_red, green_blue, blue), dim=-1)
# # out = torch.reshape(out, (shape[0], shape[1], shape[-3] // 2, shape[-2] // 2, 4))
# # out = out.permute(2, 0, 1) # Re-Permute the tensor back to CxHxW format
# return out
# @ fn_timer
def _unprocess_pipeline(image, ccm, rgb_gain, red_gain, blue_gain, use_gpu):
    """Run the sRGB -> clipped linear camera-RGB inversion chain on one tensor."""
    # Approximately inverts global tone mapping.
    image = inverse_smoothstep(image)
    # Inverts gamma compression.
    image = gamma_expansion(image)
    # Inverts color correction.
    image = apply_ccm(image, ccm)
    # Approximately inverts white balance and brightening.
    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, use_gpu)
    # Clips saturated pixels.
    return torch.clamp(image, min=0.0, max=1.0)


def unprocess(image, lock_wb=False, use_gpu=False, camera_type='IMX686', seed=None):
    """Unprocesses an image from sRGB to realistic raw data.

    A random CCM and random gains are sampled once per call. If
    ``image.dim() >= 4``, each ``image[i]`` is processed separately with
    that same metadata. No Bayer mosaic is applied, so the output keeps the
    RGB last axis.

    Args:
        image: Float tensor whose last axis is RGB. Values are clamped to
            [0, 1] by the tone-curve inversion.
        lock_wb: ``False`` to sample random gains. Otherwise an array-like
            of three entries unpacked into (rgb_gain, red_gain, blue_gain),
            e.g. ``[[1.], [2.], [2.]]``.
        use_gpu: Move the gains to CUDA (image assumed to be on CUDA too).
        camera_type: Currently unused.
        seed: Currently unused.

    Returns:
        ``(raw, metadata)``. ``raw`` has the same shape as ``image``;
        ``metadata`` has keys 'cam2rgb', 'rgb_gain', 'red_gain' and
        'blue_gain'.
    """
    # Randomly creates image metadata.
    rgb2cam = random_ccm()
    cam2rgb = torch.inverse(rgb2cam)
    if lock_wb is False:
        rgb_gain, red_gain, blue_gain = random_gains()
    else:
        rgb_gain, red_gain, blue_gain = torch.FloatTensor(np.array(lock_wb))
    if use_gpu:
        rgb_gain, red_gain, blue_gain = rgb_gain.cuda(), red_gain.cuda(), blue_gain.cuda()

    # random_ccm() always returns a CPU tensor; move it next to the image so
    # apply_ccm does not mix devices when the image is on the GPU.
    ccm = rgb2cam.to(image.device)

    if image.dim() >= 4:
        # Write each processed item into a copy so the result keeps the
        # input's dtype and device.
        raw = image.clone()
        for i in range(image.size(0)):
            raw[i] = _unprocess_pipeline(image[i], ccm, rgb_gain, red_gain, blue_gain, use_gpu)
    else:
        raw = _unprocess_pipeline(image, ccm, rgb_gain, red_gain, blue_gain, use_gpu)

    metadata = {
        'cam2rgb': cam2rgb,
        'rgb_gain': rgb_gain,
        'red_gain': red_gain,
        'blue_gain': blue_gain,
    }
    return raw, metadata
def unprocess_rpdc(image, lock_wb=False, use_gpu=False, camera_type='IMX686', known=None):
    """Unprocesses a batch of sRGB images to realistic raw data.

    Every ``image[i]`` is processed with the same metadata, which is either
    taken from ``known`` or sampled at random.

    Args:
        image: Batched float tensor; each ``image[i]`` has RGB on its last
            axis.
        lock_wb: ``False`` to sample random gains. Otherwise an array-like
            of three entries unpacked into (rgb_gain, red_gain, blue_gain).
            Ignored when ``known`` is given.
        use_gpu: Move the gains to CUDA (image assumed to be on CUDA too).
        camera_type: Currently unused.
        known: Optional dict with keys 'cam2rgb', 'rgb2cam', 'rgb_gain',
            'red_gain' and 'blue_gain' to reuse instead of sampling. For
            example, the metadata returned by a previous call.

    Returns:
        ``(raw, metadata)``. ``raw`` has the same shape as ``image``;
        ``metadata`` has keys 'rgb2cam', 'cam2rgb', 'rgb_gain', 'red_gain'
        and 'blue_gain'.
    """
    if known is not None:
        cam2rgb = known['cam2rgb']
        rgb2cam = known['rgb2cam']
        rgb_gain = known['rgb_gain']
        red_gain = known['red_gain']
        blue_gain = known['blue_gain']
    else:
        # Randomly creates image metadata.
        rgb2cam = random_ccm()
        cam2rgb = torch.inverse(rgb2cam)
        rgb_gain, red_gain, blue_gain = random_gains() if lock_wb is False else torch.FloatTensor(np.array(lock_wb))
    if use_gpu:
        # Applied to supplied and sampled gains alike so both end up on the
        # same device; .cuda() is a no-op for tensors already on CUDA.
        rgb_gain, red_gain, blue_gain = rgb_gain.cuda(), red_gain.cuda(), blue_gain.cuda()

    # The CCM is usually a CPU tensor; move it next to the image so
    # apply_ccm does not mix devices when the image is on the GPU.
    ccm = rgb2cam.to(image.device)

    res = image.clone()
    for i in range(image.size(0)):
        temp = inverse_smoothstep(image[i])
        temp = gamma_expansion(temp)
        temp = apply_ccm(temp, ccm)
        temp = safe_invert_gains(temp, rgb_gain, red_gain, blue_gain, use_gpu)
        res[i] = torch.clamp(temp, min=0.0, max=1.0)

    metadata = {
        'rgb2cam': rgb2cam,
        'cam2rgb': cam2rgb,
        'rgb_gain': rgb_gain,
        'red_gain': red_gain,
        'blue_gain': blue_gain,
    }
    return res, metadata
def random_noise_levels():
    """Sample shot and read noise levels whose logarithms are linearly related.

    ``log(shot)`` is drawn from U(log 1e-4, log 0.012), and
    ``log(read) = 2.18 * log(shot) + 1.20 + eps`` with ``eps ~ N(0, 0.26)``.

    Returns:
        ``(shot_noise, read_noise)``, each a 1-element float tensor.
    """
    lo, hi = np.log(0.0001), np.log(0.012)
    log_shot = torch.FloatTensor(1).uniform_(lo, hi)
    jitter = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26])).sample()
    log_read = 2.18 * log_shot + 1.20 + jitter
    return torch.exp(log_shot), torch.exp(log_read)
def add_noise(image, shot_noise=0.01, read_noise=0.0005):
    """Adds random shot (proportional to image) and read (independent) noise.

    Each element gets zero-mean Gaussian noise with variance
    ``image * shot_noise + read_noise``. Works on tensors of any shape; the
    computation is element-wise, so no channel permutation is needed.

    Args:
        image: Tensor of (linear) intensities.
        shot_noise: Signal-dependent variance factor (scalar or
            broadcastable tensor).
        read_noise: Signal-independent variance (scalar or broadcastable
            tensor).

    Returns:
        A noisy tensor with the same shape as ``image``.
    """
    variance = image * shot_noise + read_noise
    # Clamp so zero or negative variances (e.g. zero pixels with
    # read_noise=0, or negative pixel values) yield zero noise instead of
    # making the Normal distribution reject a non-positive scale.
    std = torch.sqrt(torch.clamp(variance, min=0.0))
    noise = torch.randn_like(variance) * std
    return image + noise
if __name__ == '__main__':
    # Smoke test: print ten Poisson draws at rates 10, 100 and 1000.
    poisson = tdist.Poisson(torch.tensor([10., 100., 1000.]))
    for _ in range(10):
        print(poisson.sample().numpy())