# YOND/archs/Unet.py
import torch
import torch.nn as nn

from .modules import *  # expected to provide conv3x3, GuidedResidualBlock, WaveletDecompose, WaveletReconstruct
def data_normalize(data, lower=None, upper=None):
    # Per-sample max normalization to [0, 1]; the lower/upper arguments are
    # currently overridden by values computed from the batch itself.
    lower = 0
    # Clip the per-sample maximum away from zero; no frame should be as absurdly dark as 1e-5.
    upper = torch.amax(data, dim=(1, 2, 3), keepdim=True).clip(1e-5, 1)
    data = (data - lower) / (upper - lower)
    return data, lower, upper
def data_inv_normalize(data, lower, upper):
    # Undo data_normalize given the bounds it returned.
    data = data * (upper - lower) + lower
    return data
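# Round-trip sketch (illustrative): inversion is algebraically exact even when
# `upper` is clipped, since the same bounds are reused.
#   x = torch.rand(2, 4, 8, 8)
#   x_n, lb, ub = data_normalize(x)
#   x_rec = data_inv_normalize(x_n, lb, ub)   # x_rec == x up to float error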
# SID U-Net ("Learning to See in the Dark", Chen et al., CVPR 2018)
class UNetSeeInDark(nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.args = args
        self.cf = 0  # frame index used for the residual skip in forward()
        self.nframes = nframes = args.get('nframes', 1)
        self.res = args.get('res', False)    # add a residual skip from the input
        self.norm = args.get('norm', False)  # per-sample max normalization
        nf = args['nf']
        in_nc = args['in_nc']
        out_nc = args['out_nc']
self.conv1_1 = nn.Conv2d(in_nc*nframes, nf, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2_1 = nn.Conv2d(nf, nf*2, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.conv3_1 = nn.Conv2d(nf*2, nf*4, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)
self.pool3 = nn.MaxPool2d(kernel_size=2)
self.conv4_1 = nn.Conv2d(nf*4, nf*8, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)
self.pool4 = nn.MaxPool2d(kernel_size=2)
self.conv5_1 = nn.Conv2d(nf*8, nf*16, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(nf*16, nf*16, kernel_size=3, stride=1, padding=1)
self.upv6 = nn.ConvTranspose2d(nf*16, nf*8, 2, stride=2)
self.conv6_1 = nn.Conv2d(nf*16, nf*8, kernel_size=3, stride=1, padding=1)
self.conv6_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)
self.upv7 = nn.ConvTranspose2d(nf*8, nf*4, 2, stride=2)
self.conv7_1 = nn.Conv2d(nf*8, nf*4, kernel_size=3, stride=1, padding=1)
self.conv7_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)
self.upv8 = nn.ConvTranspose2d(nf*4, nf*2, 2, stride=2)
self.conv8_1 = nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1)
self.conv8_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)
self.upv9 = nn.ConvTranspose2d(nf*2, nf, 2, stride=2)
self.conv9_1 = nn.Conv2d(nf*2, nf, kernel_size=3, stride=1, padding=1)
self.conv9_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
self.conv10_1 = nn.Conv2d(nf, out_nc, kernel_size=1, stride=1)
self.relu = nn.LeakyReLU(0.2, inplace=True)
def forward(self, x):
if self.norm:
x, lb, ub = data_normalize(x)
conv1 = self.relu(self.conv1_1(x))
conv1 = self.relu(self.conv1_2(conv1))
pool1 = self.pool1(conv1)
conv2 = self.relu(self.conv2_1(pool1))
conv2 = self.relu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)
conv3 = self.relu(self.conv3_1(pool2))
conv3 = self.relu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)
conv4 = self.relu(self.conv4_1(pool3))
conv4 = self.relu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)
conv5 = self.relu(self.conv5_1(pool4))
conv5 = self.relu(self.conv5_2(conv5))
up6 = self.upv6(conv5)
up6 = torch.cat([up6, conv4], 1)
conv6 = self.relu(self.conv6_1(up6))
conv6 = self.relu(self.conv6_2(conv6))
up7 = self.upv7(conv6)
up7 = torch.cat([up7, conv3], 1)
conv7 = self.relu(self.conv7_1(up7))
conv7 = self.relu(self.conv7_2(conv7))
up8 = self.upv8(conv7)
up8 = torch.cat([up8, conv2], 1)
conv8 = self.relu(self.conv8_1(up8))
conv8 = self.relu(self.conv8_2(conv8))
up9 = self.upv9(conv8)
up9 = torch.cat([up9, conv1], 1)
conv9 = self.relu(self.conv9_1(up9))
conv9 = self.relu(self.conv9_2(conv9))
out = self.conv10_1(conv9)
        if self.res:
            # Residual skip from the cf-th 4-channel frame of the input.
            out = out + x[:, self.cf*4:self.cf*4+4]
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
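# Minimal usage sketch (illustrative; the argument values below are assumptions,
# not defaults of this repo):
#   args = {'nframes': 1, 'res': True, 'norm': True,
#           'nf': 32, 'in_nc': 4, 'out_nc': 4}
#   model = UNetSeeInDark(args)
#   x = torch.rand(2, 4, 64, 64)   # H and W must be divisible by 16 (4 pooling stages)
#   y = model(x)                   # -> (2, 4, 64, 64)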
def get_updown_module(nf, updown_type='conv', mode='up'):
    # Factory for the U-Net scale-transition blocks: 'down' doubles the channel
    # count, 'up' halves it, alongside a 2x change in spatial scale.
if updown_type == 'conv':
if mode == 'down':
return conv3x3(nf, nf*2)
elif mode == 'up':
return nn.ConvTranspose2d(nf, nf//2, 2, stride=2)
elif updown_type in ['bilinear', 'bicubic', 'nearest']:
if mode == 'down':
return nn.Sequential(
                nn.Upsample(scale_factor=0.5, mode=updown_type),
nn.Conv2d(nf, nf*2, kernel_size=3, stride=1, padding=1),
)
if mode == 'up':
return nn.Sequential(
                nn.Upsample(scale_factor=2, mode=updown_type),
nn.Conv2d(nf, nf//2, kernel_size=3, stride=1, padding=1),
)
elif updown_type == 'shuffle':
if mode == 'down':
return nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1),
)
if mode == 'up':
return nn.Sequential(
nn.PixelShuffle(2),
nn.Conv2d(nf//4, nf//2, kernel_size=3, stride=1, padding=1),
)
elif updown_type in ['haar','db1','db2','db3']:
if mode == 'down':
return nn.Sequential(
WaveletDecompose(updown_type),
nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1),
)
        if mode == 'up':
            return nn.Sequential(
                WaveletReconstruct(updown_type),
                nn.Conv2d(nf//4, nf//2, kernel_size=3, stride=1, padding=1),
            )
    # Fall through: no branch matched, so fail loudly instead of returning None.
    raise ValueError(f'unsupported updown_type={updown_type!r}, mode={mode!r}')
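# Shape sketch for the factory above (illustrative):
#   down = get_updown_module(32, 'shuffle', mode='down')  # (B, 32, H, W) -> (B, 64, H/2, W/2)
#   up   = get_updown_module(64, 'shuffle', mode='up')    # (B, 64, H, W) -> (B, 32, 2H, 2W)
# The wavelet variants follow the same channel bookkeeping, with
# WaveletDecompose/WaveletReconstruct (from .modules) in place of pixel shuffling.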
class GuidedResUnet(nn.Module):
def __init__(self, args=None):
super().__init__()
self.args = args
self.cf = 0
self.nframes = nframes = args.get('nframes', 1)
self.res = args.get('res', False)
self.norm = args.get('norm', False)
self.updown_type = args.get('updown_type', 'conv')
self.downsample = args.get('downsample', False)
        if self.downsample == 'shuffle':
            self.down_fn = nn.PixelUnshuffle(2)
            self.up_fn = nn.PixelShuffle(2)
        elif self.downsample:  # otherwise a wavelet name, e.g. 'haar' or 'db2'
            self.down_fn = WaveletDecompose(mode=self.downsample)
            self.up_fn = WaveletReconstruct(mode=self.downsample)
        ext = 4 if self.downsample else 1  # a 2x (un)shuffle packs channels 4x
nf = args.get('nf', 32)
in_nc = args.get('in_nc', 4)
out_nc = args.get('out_nc', 4)
self.conv_in = nn.Conv2d(in_nc*nframes*ext, nf, kernel_size=3, stride=1, padding=1)
self.conv1 = GuidedResidualBlock(nf, nf, is_activate=False)
self.pool1 = get_updown_module(nf, self.updown_type, mode='down')
self.conv2 = GuidedResidualBlock(nf*2, nf*2, is_activate=False)
self.pool2 = get_updown_module(nf*2, self.updown_type, mode='down')
self.conv3 = GuidedResidualBlock(nf*4, nf*4, is_activate=False)
self.pool3 = get_updown_module(nf*4, self.updown_type, mode='down')
self.conv4 = GuidedResidualBlock(nf*8, nf*8, is_activate=False)
self.pool4 = get_updown_module(nf*8, self.updown_type, mode='down')
self.conv5 = GuidedResidualBlock(nf*16, nf*16, is_activate=False)
self.upv6 = get_updown_module(nf*16, self.updown_type, mode='up')
self.conv6 = GuidedResidualBlock(nf*16, nf*8, is_activate=False)
self.upv7 = get_updown_module(nf*8, self.updown_type, mode='up')
self.conv7 = GuidedResidualBlock(nf*8, nf*4, is_activate=False)
self.upv8 = get_updown_module(nf*4, self.updown_type, mode='up')
self.conv8 = GuidedResidualBlock(nf*4, nf*2, is_activate=False)
self.upv9 = get_updown_module(nf*2, self.updown_type, mode='up')
self.conv9 = GuidedResidualBlock(nf*2, nf, is_activate=False)
self.conv10 = nn.Conv2d(nf, out_nc*ext, kernel_size=1, stride=1)
self.lrelu = nn.LeakyReLU(0.01, inplace=True)
    def forward(self, x, t):
        if self.norm:
            x, lb, ub = data_normalize(x)
            t = t / (ub - lb)  # rescale the guidance to match the normalized input
if self.downsample:
x = self.down_fn(x)
conv_in = self.lrelu(self.conv_in(x))
conv1 = self.conv1(conv_in, t)
pool1 = self.pool1(conv1)
conv2 = self.conv2(pool1, t)
pool2 = self.pool2(conv2)
conv3 = self.conv3(pool2, t)
pool3 = self.pool3(conv3)
conv4 = self.conv4(pool3, t)
pool4 = self.pool4(conv4)
conv5 = self.conv5(pool4, t)
up6 = self.upv6(conv5)
up6 = torch.cat([up6, conv4], 1)
conv6 = self.conv6(up6, t)
up7 = self.upv7(conv6)
up7 = torch.cat([up7, conv3], 1)
conv7 = self.conv7(up7, t)
up8 = self.upv8(conv7)
up8 = torch.cat([up8, conv2], 1)
conv8 = self.conv8(up8, t)
up9 = self.upv9(conv8)
up9 = torch.cat([up9, conv1], 1)
conv9 = self.conv9(up9, t)
out = self.conv10(conv9)
        if self.res:
            # Residual skip from the cf-th 4-channel frame of the (possibly downsampled) input.
            out = out + x[:, self.cf*4:self.cf*4+4]
if self.downsample:
out = self.up_fn(out)
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
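# Minimal usage sketch (illustrative). The shape of the guidance tensor `t` is
# an assumption here; it must match whatever GuidedResidualBlock (from .modules)
# broadcasts against its feature maps.
#   args = {'nf': 32, 'in_nc': 4, 'out_nc': 4, 'updown_type': 'shuffle'}
#   model = GuidedResUnet(args)
#   x = torch.rand(2, 4, 64, 64)    # H and W must be divisible by 16
#   t = torch.rand(2, 1, 1, 1)      # assumed per-sample guidance (e.g. a noise level)
#   y = model(x, t)                 # -> (2, 4, 64, 64)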