Spaces:
Running
on
Zero
Running
on
Zero
| from .modules import * | |
def data_normalize(data, lower=None, upper=None):
    """Rescale each sample of a batch by its own peak value.

    By default ``lower`` is 0 and ``upper`` is the per-sample maximum over
    dims (1, 2, 3), clipped to ``[1e-5, 1]``. The data is then mapped with
    ``(data - lower) / (upper - lower)``.

    Args:
        data: tensor reduced over dims (1, 2, 3) -- presumably (B, C, H, W);
            confirm against callers.
        lower: optional offset to use instead of the default 0. It used to be
            silently ignored; it is now honored when not None.
        upper: optional scale to use instead of the per-sample maximum. It
            used to be silently ignored; it is now honored when not None.

    Returns:
        ``(normalized, lower, upper)``. Pass ``lower`` and ``upper`` to
        ``data_inv_normalize`` to undo the mapping.
    """
    if lower is None:
        lower = 0
    if upper is None:
        # Clip to [1e-5, 1]. The floor avoids dividing by zero on an all-dark
        # sample (an image is not expected to be that dark). The ceiling
        # leaves samples whose peak already exceeds 1 unscaled.
        upper = torch.amax(data, dim=(1, 2, 3), keepdim=True).clip(1e-5, 1)
    data = (data - lower) / (upper - lower)
    return data, lower, upper
def data_inv_normalize(data, lower, upper):
    """Map normalized values back to their original range.

    This is the inverse of ``data_normalize``: it returns
    ``data * (upper - lower) + lower``.
    """
    span = upper - lower
    restored = data * span + lower
    return restored
| # SID Unet | |
class UNetSeeInDark(nn.Module):
    """Five-level U-Net in the style of "Learning to See in the Dark" (SID).

    ``args`` is read with ``[]`` and ``in``. Recognised keys:
        nframes (optional, default 1): number of stacked input frames; the
            first conv expects ``in_nc * nframes`` channels.
        res (required): if truthy, add 4 input channels to the output.
        norm (optional, default False): normalize the input with
            ``data_normalize`` and undo it on the output.
        nf (required): base width; encoder level k has ``nf * 2**k`` channels.
        in_nc, out_nc (required): channels per input frame / of the output.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        # Read 'nframes' once with a fallback of 1. Previously args['nframes']
        # was read unconditionally first, so a missing key raised KeyError and
        # the fallback was dead code.
        nframes = args['nframes'] if 'nframes' in args else 1
        self.nframes = nframes
        self.cf = 0  # index of the input frame whose channels feed the residual
        self.res = args['res']
        self.norm = args['norm'] if 'norm' in args else False
        nf = args['nf']
        in_nc = args['in_nc']
        out_nc = args['out_nc']

        # Encoder: two 3x3 convs per level, then 2x max-pool.
        self.conv1_1 = nn.Conv2d(in_nc*nframes, nf, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2_1 = nn.Conv2d(nf, nf*2, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3_1 = nn.Conv2d(nf*2, nf*4, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4_1 = nn.Conv2d(nf*4, nf*8, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck.
        self.conv5_1 = nn.Conv2d(nf*8, nf*16, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(nf*16, nf*16, kernel_size=3, stride=1, padding=1)
        # Decoder: 2x transposed-conv upsample. The result is concatenated with
        # the matching encoder feature, so each first conv takes 2x channels.
        self.upv6 = nn.ConvTranspose2d(nf*16, nf*8, 2, stride=2)
        self.conv6_1 = nn.Conv2d(nf*16, nf*8, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)
        self.upv7 = nn.ConvTranspose2d(nf*8, nf*4, 2, stride=2)
        self.conv7_1 = nn.Conv2d(nf*8, nf*4, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)
        self.upv8 = nn.ConvTranspose2d(nf*4, nf*2, 2, stride=2)
        self.conv8_1 = nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)
        self.upv9 = nn.ConvTranspose2d(nf*2, nf, 2, stride=2)
        self.conv9_1 = nn.Conv2d(nf*2, nf, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        # 1x1 projection to the output channels.
        self.conv10_1 = nn.Conv2d(nf, out_nc, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Run the U-Net on ``x``.

        Assumes ``x`` is (B, in_nc * nframes, H, W). H and W must be divisible
        by 16: four 2x max-pools go down and four 2x transposed convs come
        back up, and the skip concatenations need matching sizes.
        """
        if self.norm:
            x, lb, ub = data_normalize(x)

        conv1 = self.relu(self.conv1_1(x))
        conv1 = self.relu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.relu(self.conv2_1(pool1))
        conv2 = self.relu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.relu(self.conv3_1(pool2))
        conv3 = self.relu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.relu(self.conv4_1(pool3))
        conv4 = self.relu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.relu(self.conv5_1(pool4))
        conv5 = self.relu(self.conv5_2(conv5))

        up6 = torch.cat([self.upv6(conv5), conv4], 1)
        conv6 = self.relu(self.conv6_1(up6))
        conv6 = self.relu(self.conv6_2(conv6))

        up7 = torch.cat([self.upv7(conv6), conv3], 1)
        conv7 = self.relu(self.conv7_1(up7))
        conv7 = self.relu(self.conv7_2(conv7))

        up8 = torch.cat([self.upv8(conv7), conv2], 1)
        conv8 = self.relu(self.conv8_1(up8))
        conv8 = self.relu(self.conv8_2(conv8))

        up9 = torch.cat([self.upv9(conv8), conv1], 1)
        conv9 = self.relu(self.conv9_1(up9))
        conv9 = self.relu(self.conv9_2(conv9))

        out = self.conv10_1(conv9)
        if self.res:
            # Add channels [cf*4, cf*4+4) of the (possibly normalized) input.
            # NOTE(review): hard-codes 4 channels per frame; this only lines up
            # when in_nc == out_nc == 4 -- confirm for other configurations.
            out = out + x[:, self.cf*4:self.cf*4+4]
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out
def get_updown_module(nf, updown_type='conv', mode='up'):
    """Build one resampling step of a U-Net.

    ``mode='down'`` maps ``nf`` channels to ``nf * 2`` (spatially halved for
    every type except possibly 'conv', see note below). ``mode='up'`` maps
    ``nf`` channels to ``nf // 2`` at twice the resolution.

    Args:
        nf: number of input channels.
        updown_type: resampling flavour, one of:
            'conv' -- ``conv3x3`` down / 2x2 stride-2 transposed conv up;
            'bilinear' | 'bicubic' | 'nearest' -- interpolate by 0.5 / 2,
                then a 3x3 conv;
            'shuffle' -- PixelUnshuffle(2) / PixelShuffle(2), then a 3x3 conv;
            'haar' | 'db1' | 'db2' | 'db3' -- WaveletDecompose /
                WaveletReconstruct, then a 3x3 conv.
        mode: 'down' or 'up'.

    Returns:
        An ``nn.Module`` performing the step.

    Raises:
        ValueError: for an unknown ``mode`` or ``updown_type``. Previously
            ``None`` was returned silently and only failed once called.
    """
    if mode not in ('down', 'up'):
        raise ValueError(f"mode must be 'down' or 'up', got {mode!r}")
    down = mode == 'down'
    out_nc = nf * 2 if down else nf // 2

    if updown_type == 'conv':
        if down:
            # NOTE(review): conv3x3 comes from .modules. The matching 'up' step
            # doubles the resolution, so this presumably must halve it --
            # confirm conv3x3's default stride.
            return conv3x3(nf, out_nc)
        return nn.ConvTranspose2d(nf, out_nc, 2, stride=2)

    if updown_type in ('bilinear', 'bicubic', 'nearest'):
        # scale_factor must be passed by keyword. nn.Upsample's first
        # positional parameter is `size`, so nn.Upsample(2) resized every
        # input to 2x2, and nn.Upsample(1/2) is not a valid size.
        resample = nn.Upsample(scale_factor=0.5 if down else 2, mode=updown_type)
        conv_in = nf
    elif updown_type == 'shuffle':
        # PixelUnshuffle(2) multiplies channels by 4; PixelShuffle(2) divides by 4.
        resample = nn.PixelUnshuffle(2) if down else nn.PixelShuffle(2)
        conv_in = nf * 4 if down else nf // 4
    elif updown_type in ('haar', 'db1', 'db2', 'db3'):
        # Channel bookkeeping mirrors the pixel-shuffle case (4 sub-bands).
        resample = WaveletDecompose(updown_type) if down else WaveletReconstruct(updown_type)
        conv_in = nf * 4 if down else nf // 4
    else:
        raise ValueError(f"unknown updown_type: {updown_type!r}")

    return nn.Sequential(
        resample,
        nn.Conv2d(conv_in, out_nc, kernel_size=3, stride=1, padding=1),
    )
class GuidedResUnet(nn.Module):
    """Residual U-Net whose blocks all receive an extra guidance input ``t``.

    ``args`` is read with ``.get``. Recognised keys (defaults in parentheses):
        nframes (1): number of stacked input frames.
        res (False): add 4 channels of the (normalized / downsampled) input to
            the output.
        norm (False): normalize the input with ``data_normalize``, divide ``t``
            by the same range, and undo the normalization on the output.
        updown_type ('conv'): resampling flavour handed to
            ``get_updown_module`` for the four down and four up steps.
        downsample (False): if 'shuffle', wrap the network in
            PixelUnshuffle(2) / PixelShuffle(2). Any other value that compares
            unequal to False is passed as ``mode`` to WaveletDecompose /
            WaveletReconstruct. When truthy, the in/out channel counts are
            multiplied by 4.
        nf (32), in_nc (4), out_nc (4): base width and per-frame input /
            output channels.
    """
    def __init__(self, args=None):
        super().__init__()
        self.args = args
        self.cf = 0  # index of the input frame whose channels feed the residual
        self.nframes = nframes = args.get('nframes', 1)
        self.res = args.get('res', False)
        self.norm = args.get('norm', False)
        self.updown_type = args.get('updown_type', 'conv')
        self.downsample = args.get('downsample', False)
        # Optional outer 2x resampling wrapped around the whole U-Net.
        if self.downsample == 'shuffle':
            self.down_fn = nn.PixelUnshuffle(2)
            self.up_fn = nn.PixelShuffle(2)
        elif self.downsample != False:
            # Any other non-False value is treated as a wavelet mode name.
            # NOTE(review): a falsy value other than False (e.g. None) lands in
            # this branch but gives ext == 1 below -- confirm callers pass only
            # False or a mode string.
            self.down_fn = WaveletDecompose(mode=self.downsample)
            self.up_fn = WaveletReconstruct(mode=self.downsample)
        # Channel multiplier of the outer resampling: 4 for PixelUnshuffle(2),
        # presumably also 4 for WaveletDecompose.
        ext = 4 if self.downsample else 1
        nf = args.get('nf', 32)
        in_nc = args.get('in_nc', 4)
        out_nc = args.get('out_nc', 4)
        self.conv_in = nn.Conv2d(in_nc*nframes*ext, nf, kernel_size=3, stride=1, padding=1)
        # Encoder: guided residual block, then a down step that doubles channels.
        self.conv1 = GuidedResidualBlock(nf, nf, is_activate=False)
        self.pool1 = get_updown_module(nf, self.updown_type, mode='down')
        self.conv2 = GuidedResidualBlock(nf*2, nf*2, is_activate=False)
        self.pool2 = get_updown_module(nf*2, self.updown_type, mode='down')
        self.conv3 = GuidedResidualBlock(nf*4, nf*4, is_activate=False)
        self.pool3 = get_updown_module(nf*4, self.updown_type, mode='down')
        self.conv4 = GuidedResidualBlock(nf*8, nf*8, is_activate=False)
        self.pool4 = get_updown_module(nf*8, self.updown_type, mode='down')
        # Bottleneck.
        self.conv5 = GuidedResidualBlock(nf*16, nf*16, is_activate=False)
        # Decoder: an up step halves channels. The result is concatenated with
        # the matching encoder feature, so each block takes twice its output
        # width.
        self.upv6 = get_updown_module(nf*16, self.updown_type, mode='up')
        self.conv6 = GuidedResidualBlock(nf*16, nf*8, is_activate=False)
        self.upv7 = get_updown_module(nf*8, self.updown_type, mode='up')
        self.conv7 = GuidedResidualBlock(nf*8, nf*4, is_activate=False)
        self.upv8 = get_updown_module(nf*4, self.updown_type, mode='up')
        self.conv8 = GuidedResidualBlock(nf*4, nf*2, is_activate=False)
        self.upv9 = get_updown_module(nf*2, self.updown_type, mode='up')
        self.conv9 = GuidedResidualBlock(nf*2, nf, is_activate=False)
        # 1x1 projection. Produces ext times the channels so up_fn can fold
        # them back into space.
        self.conv10 = nn.Conv2d(nf, out_nc*ext, kernel_size=1, stride=1)
        self.lrelu = nn.LeakyReLU(0.01, inplace=True)
    def forward(self, x, t):
        """Run the guided U-Net.

        Args:
            x: input batch -- presumably (B, in_nc * nframes, H, W); confirm.
            t: guidance passed to every GuidedResidualBlock. Its shape is not
                visible here. When ``norm`` is on it is divided by
                ``(ub - lb)``, so it must broadcast against that range tensor.

        Returns:
            Output tensor, inverse-normalized if ``norm`` is on.
        """
        # shape= x.size()
        # x = x.view(-1,shape[-3],shape[-2],shape[-1])
        if self.norm:
            x, lb, ub = data_normalize(x)
            t = t / (ub-lb)
        if self.downsample:
            x = self.down_fn(x)
        conv_in = self.lrelu(self.conv_in(x))
        conv1 = self.conv1(conv_in, t)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1, t)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2, t)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3, t)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4, t)
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.conv6(up6, t)
        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.conv7(up7, t)
        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.conv8(up8, t)
        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.conv9(up9, t)
        out = self.conv10(conv9)
        if self.res:
            # Adds 4 channels of x, where x is already normalized/downsampled.
            # NOTE(review): with downsample on, out has out_nc*4 channels but
            # the slice has 4, so this looks like it will not broadcast for
            # the default out_nc=4 -- confirm res and downsample are never
            # combined.
            out = out + x[:, self.cf*4:self.cf*4+4]
        if self.downsample:
            out = self.up_fn(out)
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out