from .modules import *


class DnCNN(nn.Module):
    """DnCNN-style network: a plain stack of 3x3 convolutions with ReLU (and optional BN).

    ``args`` is a mapping; the keys read are ``'res'``, ``'in_nc'``, ``'out_nc'``,
    ``'nf'``, ``'depth'`` and ``'use_bn'``.

    * raw2rgb mode (``in_nc == 4`` and ``out_nc == 3``): the last convolution emits
      ``4 * out_nc`` channels, which a 2x pixel shuffle rearranges into ``out_nc``
      channels at twice the spatial resolution.
    * otherwise, if ``res`` is truthy the network output is subtracted from the
      input (``x - f(x)``); else ``f(x)`` is returned directly.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        self.res = args['res']
        self.raw2rgb = args['in_nc'] == 4 and args['out_nc'] == 3
        nf = args['nf']
        in_nc = args['in_nc']
        out_nc = args['out_nc']
        depth = args['depth']
        use_bn = args['use_bn']
        # pixel_shuffle(., 2) divides the channel count by 4, so in raw2rgb mode the
        # final conv must produce 4 * out_nc channels to end up with out_nc channels.
        last_nc = out_nc * 4 if self.raw2rgb else out_nc

        layers = [
            nn.Conv2d(in_channels=in_nc, out_channels=nf, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        ]
        for _ in range(depth - 2):
            layers.append(nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, padding=1, bias=False))
            if use_bn:
                layers.append(nn.BatchNorm2d(nf, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=nf, out_channels=last_nc, kernel_size=3, padding=1, bias=False))
        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        out = self.dncnn(x)
        if self.raw2rgb:
            out = nn.functional.pixel_shuffle(out, 2)
        elif self.res:
            out = x - out
        return out


def conv33(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
    """Return a 3x3 ``nn.Conv2d``; with the default stride 1 / padding 1 it keeps H and W."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Return a layer that doubles the spatial size of its input.

    ``mode='transpose'`` gives a 2x2, stride-2 ``nn.ConvTranspose2d``. Any other value
    gives bilinear ``nn.Upsample(scale_factor=2)`` followed by a 1x1 convolution
    (``conv1x1``) mapping ``in_channels`` to ``out_channels``.
    """
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2),
        conv1x1(in_channels, out_channels))


class DownConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU, then an optional 2x2 max-pool.

    ``forward`` returns ``(x, before_pool)``: the (possibly pooled) output and the
    activation before pooling. When ``pooling`` is False both are the same tensor.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv33(self.in_channels, self.out_channels)
        self.conv2 = conv33(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """Upsample the decoder tensor, merge it with the encoder skip, then two 3x3 conv + ReLU.

    With ``merge_mode='concat'`` the upsampled tensor and the skip are concatenated on
    dim 1 (so ``conv1`` sees ``2 * out_channels`` inputs); otherwise they are added.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv33(2 * self.out_channels, self.out_channels)
        else:
            # addition keeps the channel count unchanged
            self.conv1 = conv33(self.out_channels, self.out_channels)
        self.conv2 = conv33(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """Merge the encoder skip ``from_down`` with the upsampled ``from_up``.

        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway (upsampled here)
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class est_UNet(nn.Module):
    """U-Net regressor returning one spatially averaged value per output channel.

    Based on https://arxiv.org/abs/1505.04597 with these modifications:
    (1) 3x3 convolutions are padded, so border pixels are not lost;
    (2) hence merging skip connections needs no cropping;
    (3) skips are merged by addition (``merge_mode='add'``, hard-coded here);
    (4) upsampling is a 2x2 transpose convolution (``up_mode='transpose'``,
        hard-coded here), which also halves the channel count.

    ``args`` keys read:
        'out_nc': number of output channels (values returned per sample).
        'in_nc': number of input channels.
        'depth': number of encoder blocks (``depth - 1`` max-pools).
        'nf': filters in the first encoder block; doubled at every level.
    """

    def __init__(self, args):
        super(est_UNet, self).__init__()
        num_classes = args['out_nc']
        in_channels = args['in_nc']
        depth = args['depth']
        start_filts = args['nf']
        up_mode = 'transpose'
        merge_mode = 'add'
        use_type = 'optimize_gat'
        self.use_type = use_type

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # Report the offending merge_mode value.
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # Learnable scalar stored in log space (initialised to log(0.5)). It is not
        # read by forward(); presumably used by external training code - TODO confirm.
        self.noiseSTD = nn.Parameter(data=torch.log(torch.tensor(0.5)))

        # Encoder: channels double per level; the deepest block does not pool.
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = i < depth - 1
            self.down_convs.append(DownConv(ins, outs, pooling=pooling))

        # Decoder: depth - 1 blocks, each halving the channel count.
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            self.up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                        merge_mode=merge_mode))

        self.conv_final = conv1x1(outs, self.num_classes)
        # Parameter-free; no .cuda() needed - it follows the module's device.
        self.sigmoid = nn.Sigmoid()

        # Register the blocks as submodules.
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal weights and zero bias for every ``nn.Conv2d``."""
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def reset_params(self):
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """Return the spatial mean of the final 1x1-conv output (squared, see note).

        The result is ``.squeeze()``d, so any size-1 dimension - including a batch
        of one - is dropped.
        """
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        before_x = self.conv_final(x)
        # NOTE(review): use_type is 'optimize_gat' but the comparison string is
        # 'optimze_gat', so this branch never runs and the output is always squared.
        # Kept as-is to avoid changing model outputs; confirm the intended behaviour.
        if self.use_type == 'optimze_gat':
            x = before_x
        else:
            x = before_x ** 2
        return torch.mean(x, dim=(2, 3)).squeeze()
def _apply_weight_mask(conv, mask):
    """Zero the kernel taps of ``conv`` where ``mask`` is 0.

    The write goes through ``.data`` so autograd does not record it, matching the
    original in-forward re-masking of the weights.
    """
    conv.weight.data = conv.weight.data * mask


def _upsample_nearest2x(x):
    """Return the 4-D tensor ``x`` upsampled 2x along dims 2 and 3 by pixel repetition."""
    s = x.shape
    x = x.reshape(s[0], s[1], s[2], 1, s[3], 1)
    x = x.repeat(1, 1, 1, 2, 1, 2)
    return x.reshape(s[0], s[1], s[2] * 2, s[3] * 2)


class New1(nn.Module):
    """3x3 convolution (padding 1) whose centre tap is held at zero.

    The mask is re-applied to the weights at the start of every forward pass, so
    optimizer updates cannot re-enable the centre tap.
    """

    def __init__(self, in_ch, out_ch):
        super(New1, self).__init__()
        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module but is
        # not stored in state_dict, so checkpoint keys are unchanged.
        self.register_buffer('mask', torch.tensor(
            [[1, 1, 1],
             [1, 0, 1],
             [1, 1, 1]], dtype=torch.float32), persistent=False)
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, padding=1, kernel_size=3)

    def forward(self, x):
        _apply_weight_mask(self.conv1, self.mask)
        return self.conv1(x)


class New2(nn.Module):
    """5x5 convolution (padding 2) restricted to the taps set in ``mask``.

    The centre tap and the ring of taps adjacent to it are zero.
    """

    def __init__(self, in_ch, out_ch):
        super(New2, self).__init__()
        self.register_buffer('mask', torch.tensor(
            [[0, 1, 0, 1, 0],
             [1, 0, 0, 0, 1],
             [0, 0, 1, 0, 0],
             [1, 0, 0, 0, 1],
             [0, 1, 0, 1, 0]], dtype=torch.float32), persistent=False)
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, padding=2, kernel_size=5)

    def forward(self, x):
        _apply_weight_mask(self.conv1, self.mask)
        return self.conv1(x)


class New3(nn.Module):
    """Dilated 3x3 convolution keeping only the four corners and the centre tap."""

    def __init__(self, in_ch, out_ch, dilated_value):
        super(New3, self).__init__()
        self.register_buffer('mask', torch.tensor(
            [[1, 0, 1],
             [0, 1, 0],
             [1, 0, 1]], dtype=torch.float32), persistent=False)
        # padding == dilation keeps the spatial size for a 3x3 kernel
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                               padding=dilated_value, dilation=dilated_value)

    def forward(self, x):
        _apply_weight_mask(self.conv1, self.mask)
        return self.conv1(x)


class Residual_module(nn.Module):
    """1x1 conv bottleneck (width ``in_ch * mul``) averaged with its input, then PReLU."""

    def __init__(self, in_ch, mul=1):
        super(Residual_module, self).__init__()
        self.activation1 = nn.PReLU(in_ch * mul, 0)
        self.activation2 = nn.PReLU(in_ch, 0)
        self.conv1_1by1 = nn.Conv2d(in_channels=in_ch, out_channels=in_ch * mul, kernel_size=1)
        self.conv2_1by1 = nn.Conv2d(in_channels=in_ch * mul, out_channels=in_ch, kernel_size=1)

    def forward(self, input):
        output_residual = self.conv1_1by1(input)
        output_residual = self.activation1(output_residual)
        output_residual = self.conv2_1by1(output_residual)
        output = (input + output_residual) / 2.
        return self.activation2(output)


class Gaussian(nn.Module):
    """Element-wise ``exp(-x**2)``."""

    def forward(self, input):
        return torch.exp(-torch.mul(input, input))


class Receptive_attention(nn.Module):
    """Weights 9 stacked candidate tensors by globally pooled per-sample scores.

    Scores come from a 1x1-conv MLP on ``input`` followed by global average pooling,
    normalised with softmax over the 9 scores (``at_type='softmax'``) or with
    ``sigmoid(exp(-s**2))`` otherwise.
    """

    def __init__(self, in_ch, at_type='softmax'):
        super(Receptive_attention, self).__init__()
        self.activation1 = nn.ReLU()
        self.activation2 = nn.ReLU()
        self.activation3 = nn.PReLU(in_ch, 0)
        self.conv1_1by1 = nn.Conv2d(in_channels=in_ch, out_channels=in_ch * 4, kernel_size=1)
        self.conv2_1by1 = nn.Conv2d(in_channels=in_ch * 4, out_channels=in_ch * 4, kernel_size=1)
        self.conv3_1by1 = nn.Conv2d(in_channels=in_ch * 4, out_channels=9, kernel_size=1)
        self.at_type = at_type
        if at_type == 'softmax':
            # dim=1 is what the implicit-dim form picked for 4-D input; now explicit.
            self.softmax = nn.Softmax(dim=1)
        else:
            self.gaussian = Gaussian()
            self.sigmoid = nn.Sigmoid()

    def forward(self, input, receptive):
        # NOTE(review): ``receptive`` presumably stacks 9 tensors on dim 0 - TODO confirm.
        weights = self.conv1_1by1(input)
        weights = self.activation1(weights)
        weights = self.conv2_1by1(weights)
        weights = self.activation2(weights)
        weights = self.conv3_1by1(weights)
        weights = F.adaptive_avg_pool2d(weights, (1, 1))
        if self.at_type == 'softmax':
            weights = self.softmax(weights)
        else:
            weights = self.sigmoid(self.gaussian(weights))
        # (N, 9, 1, 1) -> (9, N, 1, 1, 1) so it broadcasts against ``receptive``
        weights = weights.permute((1, 0, 2, 3)).unsqueeze(-1)
        output = torch.sum(receptive * weights, dim=0)
        return self.activation3(output)


# Cases whose masked conv reads the previous layer's main output ``x``; every other
# case (including 'FBI_Net') reads the previous masked-conv branch ``output_new``.
_CASES_CONV_ON_X = ('case2', 'case4', 'case5', 'case7')


def _masked_layer_forward(layer, conv, x, output_new):
    """Shared forward of New2_layer / New3_layer.

    Returns ``(output, branch)``; ``branch`` is the activated masked-conv output and
    ``output`` is built from it according to ``layer.case``.
    """
    case = layer.case
    source = x if case in _CASES_CONV_ON_X else output_new
    branch = layer.activation_new1(conv(source))
    if case in ('case2', 'case7'):
        output = layer.residual_module(branch)
    elif case in ('case4', 'case5'):
        output = branch
    elif case in ('case3', 'case6'):
        output = layer.activation_new2((branch + x) / 2.)
    else:  # 'case1', 'FBI_Net' and any other value
        output = layer.residual_module(layer.activation_new2((branch + x) / 2.))
    return output, branch


class New1_layer(nn.Module):
    """First FBI-Net layer: centre-masked 3x3 conv + PReLU (+ residual module)."""

    def __init__(self, in_ch, out_ch, case='FBI_Net', mul=1):
        super(New1_layer, self).__init__()
        self.case = case
        self.new1 = New1(in_ch, out_ch)
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
            self.residual_module = Residual_module(out_ch, mul)
        # NOTE(review): sized by in_ch although it is applied to the out_ch-channel conv
        # output; PReLU only accepts this when in_ch == 1 (shared slope) or
        # in_ch == out_ch. Kept because changing it alters checkpoint parameter shapes.
        self.activation_new1 = nn.PReLU(in_ch, 0)

    def forward(self, x):
        conv_out = self.new1(x)
        activated = self.activation_new1(conv_out)
        if self.case in ('case1', 'case2', 'case7', 'FBI_Net'):
            # with residual module
            return self.residual_module(activated), activated
        # without residual module; the second value is the pre-activation conv output
        return activated, conv_out


class New2_layer(nn.Module):
    """Second FBI-Net layer built on the 5x5 masked conv (New2)."""

    def __init__(self, in_ch, out_ch, case='FBI_Net', mul=1):
        super(New2_layer, self).__init__()
        self.case = case
        self.new2 = New2(in_ch, out_ch)
        self.activation_new1 = nn.PReLU(in_ch, 0)
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
            self.residual_module = Residual_module(out_ch, mul)
        if case in ('case1', 'case3', 'case6', 'FBI_Net'):
            self.activation_new2 = nn.PReLU(in_ch, 0)

    def forward(self, x, output_new):
        return _masked_layer_forward(self, self.new2, x, output_new)


class New3_layer(nn.Module):
    """Later FBI-Net layers built on the dilated masked conv (New3)."""

    def __init__(self, in_ch, out_ch, dilated_value=3, case='FBI_Net', mul=1):
        super(New3_layer, self).__init__()
        self.case = case
        self.new3 = New3(in_ch, out_ch, dilated_value)
        self.activation_new1 = nn.PReLU(in_ch, 0)
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
            self.residual_module = Residual_module(out_ch, mul)
        if case in ('case1', 'case3', 'case6', 'FBI_Net'):
            self.activation_new2 = nn.PReLU(in_ch, 0)

    def forward(self, x, output_new):
        return _masked_layer_forward(self, self.new3, x, output_new)


class AttrProxy(object):
    """Translates index lookups into attribute lookups."""

    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))


class FBI_Net(nn.Module):
    """Stack of masked-convolution layers (New1_layer, New2_layer, New3_layer x (L-2)).

    ``args`` keys read: 'channel', 'output_channel', 'nf', 'mul', 'num_of_layers'
    (L, must be >= 2), 'case', 'output_type', 'sigmoid_value', 'res'.

    For case 'FBI_Net', 'case2', 'case3' and 'case4' the outputs of all L layers are
    averaged; otherwise only the last layer's output is used. With
    ``output_type == 'sigmoid'`` channel 0 becomes ``sigmoid_value * sigmoid(ch0)``.
    With ``res`` the result is ``out[:, :1] * x + out[:, 1:]``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        channel = args['channel']
        output_channel = args['output_channel']
        filters = args['nf']
        mul = args['mul']
        num_of_layers = args['num_of_layers']
        case = args['case']
        output_type = args['output_type']
        sigmoid_value = args['sigmoid_value']
        if num_of_layers < 2:
            raise ValueError("num_of_layers must be >= 2, got {}".format(num_of_layers))
        self.res = args['res']
        self.case = case

        self.new1 = New1_layer(channel, filters, mul=mul, case=case)
        self.new2 = New2_layer(filters, filters, mul=mul, case=case)

        self.num_layers = num_of_layers
        self.output_type = output_type
        self.sigmoid_value = sigmoid_value

        dilated_value = 3
        for layer in range(num_of_layers - 2):
            self.add_module('new_' + str(layer),
                            New3_layer(filters, filters, dilated_value, mul=mul, case=case))

        self.residual_module = Residual_module(filters, mul)
        self.activation = nn.PReLU(filters, 0)
        self.output_layer = nn.Conv2d(in_channels=filters, out_channels=output_channel, kernel_size=1)

        if self.output_type == 'sigmoid':
            self.sigmoid = nn.Sigmoid()

        self.new = AttrProxy(self, 'new_')

    def forward(self, x):
        # Explicit count: iterating AttrProxy relies on the legacy __getitem__
        # protocol, which ends with AttributeError rather than IndexError.
        n_dilated = self.num_layers - 2
        output, output_new = self.new1(x)
        if self.case in ('FBI_Net', 'case2', 'case3', 'case4'):
            output_sum = output
            output, output_new = self.new2(output, output_new)
            output_sum = output + output_sum
            for i in range(n_dilated):
                output, output_new = self.new[i](output, output_new)
                output_sum = output + output_sum
            final_output = self.activation(output_sum / self.num_layers)
        else:
            output, output_new = self.new2(output, output_new)
            for i in range(n_dilated):
                output, output_new = self.new[i](output, output_new)
            final_output = self.activation(output)
        final_output = self.residual_module(final_output)
        final_output = self.output_layer(final_output)

        if self.output_type == 'sigmoid':
            final_output[:, 0] = (torch.ones_like(final_output[:, 0]) * self.sigmoid_value) \
                * self.sigmoid(final_output[:, 0])

        if self.res:
            final_output = final_output[:, :1] * x + final_output[:, 1:]

        return final_output


class SelfSupUNet(nn.Module):
    def __init__(self, args):
        """U-Net of LeakyReLU conv layers with nearest-neighbour upsampling.

        ``args`` keys: 'in_nc', 'out_nc' (required); 'depth' (default 5), 'nf'
        (default 32), 'slope' (default 0.1), 'norm' (default False), 'res'
        (default False, adds the network input to the output).
        """
        super().__init__()
        in_channels = args['in_nc']
        out_channels = args['out_nc']
        depth = args['depth'] if 'depth' in args else 5
        nf = args['nf'] if 'nf' in args else 32
        slope = args['slope'] if 'slope' in args else 0.1
        self.norm = args['norm'] if 'norm' in args else False
        self.res = args['res'] if 'res' in args else False

        self.depth = depth
        self.head = nn.Sequential(
            LR(in_channels, nf, 3, slope),
            LR(nf, nf, 3, slope))
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(LR(nf, nf, 3, slope))

        self.up_path = nn.ModuleList()
        for i in range(depth):
            if i != depth - 1:
                self.up_path.append(UP(nf * 2 if i == 0 else nf * 3, nf * 2, slope))
            else:
                self.up_path.append(UP(nf * 2 + in_channels, nf * 2, slope))

        self.last = nn.Sequential(LR(2 * nf, 2 * nf, 1, slope),
                                  LR(2 * nf, 2 * nf, 1, slope),
                                  conv1x1(2 * nf, out_channels, bias=True))

    def forward(self, x):
        if self.norm:
            x, lb, ub = data_normalize(x)
        inp = x
        blocks = [x]
        x = self.head(x)
        last_down = len(self.down_path) - 1
        for i, down in enumerate(self.down_path):
            x = F.max_pool2d(x, 2)
            if i != last_down:
                blocks.append(x)
            x = down(x)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        out = self.last(x)
        if self.res:
            # residual w.r.t. the (normalised) input, not the decoder features in x
            out = out + inp
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out


class LR(nn.Module):
    """Conv2d (same padding for odd ``ksize``) followed by LeakyReLU(``slope``)."""

    def __init__(self, in_size, out_size, ksize=3, slope=0.1):
        super(LR, self).__init__()
        block = [
            nn.Conv2d(in_size, out_size, kernel_size=ksize, padding=ksize // 2, bias=True),
            nn.LeakyReLU(slope, inplace=False),
        ]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        return self.block(x)


class UP(nn.Module):
    """2x nearest upsample, concat with the skip tensor, then two LR layers."""

    def __init__(self, in_size, out_size, slope=0.1):
        super(UP, self).__init__()
        self.conv_1 = LR(in_size, out_size, slope=slope)
        self.conv_2 = LR(out_size, out_size, slope=slope)

    def up(self, x):
        return _upsample_nearest2x(x)

    def forward(self, x, pool):
        x = self.up(x)
        x = torch.cat([x, pool], 1)
        x = self.conv_1(x)
        x = self.conv_2(x)
        return x


class SelfResUNet(nn.Module):
    def __init__(self, args):
        """U-Net built from residual blocks (Res / RUP).

        ``args`` keys: 'in_nc', 'out_nc' (required); 'depth' (default 5), 'nf'
        (default 32), 'slope' (default 0.1), 'norm' (default False), 'res'
        (default False, adds the network input to the output).
        """
        super().__init__()
        in_channels = args['in_nc']
        out_channels = args['out_nc']
        depth = args['depth'] if 'depth' in args else 5
        nf = args['nf'] if 'nf' in args else 32
        slope = args['slope'] if 'slope' in args else 0.1
        self.norm = args['norm'] if 'norm' in args else False
        self.res = args['res'] if 'res' in args else False

        self.depth = depth
        self.head = Res(in_channels, nf, slope)
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(Res(nf, nf, slope, ksize=3))

        self.up_path = nn.ModuleList()
        for i in range(depth):
            if i != depth - 1:
                self.up_path.append(RUP(nf * 2 if i == 0 else nf * 3, nf * 2, slope))
            else:
                self.up_path.append(RUP(nf * 2 + in_channels, nf * 2, slope))

        self.last = Res(2 * nf, 2 * nf, slope, ksize=1)
        self.out = conv1x1(2 * nf, out_channels, bias=True)

    def forward(self, x):
        if self.norm:
            x, lb, ub = data_normalize(x)
        inp = x
        blocks = [x]
        x = self.head(x)
        last_down = len(self.down_path) - 1
        for i, down in enumerate(self.down_path):
            x = F.max_pool2d(x, 2)
            if i != last_down:
                blocks.append(x)
            x = down(x)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        out = self.last(x)
        out = self.out(out)
        if self.res:
            out = out + inp
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out


class RUP(nn.Module):
    """2x nearest upsample, concat skip, 1x1 shortcut if needed, then two LR + identity."""

    def __init__(self, in_size, out_size, slope=0.1, ksize=3):
        super(RUP, self).__init__()
        self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
        self.conv_2 = LR(out_size, out_size, ksize=ksize, slope=slope)
        if in_size != out_size:
            self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
        else:
            self.short_cut = nn.Sequential()  # identity

    def up(self, x):
        return _upsample_nearest2x(x)

    def forward(self, x, pool):
        x = self.up(x)
        x = torch.cat([x, pool], 1)
        x = self.short_cut(x)
        z = self.conv_1(x)
        z = self.conv_2(z)
        z += x
        return z


class Res(nn.Module):
    """Residual block: optional 1x1 shortcut projection, two LR layers, identity add."""

    def __init__(self, in_size, out_size, slope=0.1, ksize=3):
        super().__init__()
        self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
        self.conv_2 = LR(out_size, out_size, ksize=ksize, slope=slope)
        if in_size != out_size:
            self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
        else:
            self.short_cut = nn.Sequential()  # identity

    def forward(self, x):
        x = self.short_cut(x)
        z = self.conv_1(x)
        z = self.conv_2(z)
        z += x
        return z


def conv1x1(in_chn, out_chn, bias=True):
    """Return a 1x1, stride-1, unpadded ``nn.Conv2d``."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=1, stride=1, padding=0, bias=bias)


class GuidedSelfUnet(nn.Module):
    def __init__(self, args):
        """Residual U-Net whose GLR layers are modulated by a guidance tensor ``t``.

        ``args`` keys: 'in_nc', 'out_nc' (required); 'depth' (default 5), 'nf'
        (default 32), 'slope' (default 0.1), 'norm' (default False), 'res'
        (default False, adds the network input to the output).
        """
        super().__init__()
        in_channels = args['in_nc']
        out_channels = args['out_nc']
        depth = args['depth'] if 'depth' in args else 5
        nf = args['nf'] if 'nf' in args else 32
        slope = args['slope'] if 'slope' in args else 0.1
        self.norm = args['norm'] if 'norm' in args else False
        self.res = args['res'] if 'res' in args else False

        self.depth = depth
        self.head = GRes(in_channels, nf, slope)
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(GLR(nf, nf, 3, slope))

        self.up_path = nn.ModuleList()
        for i in range(depth):
            if i != depth - 1:
                self.up_path.append(GUP(nf * 2 if i == 0 else nf * 3, nf * 2, slope))
            else:
                self.up_path.append(GUP(nf * 2 + in_channels, nf * 2, slope))

        self.last = GRes(2 * nf, 2 * nf, slope, ksize=1)
        self.out = conv1x1(2 * nf, out_channels, bias=True)

    def forward(self, x, t):
        # NOTE(review): t feeds a conv1x1(1, ...) so it presumably has 1 channel
        # and broadcasts over H, W - TODO confirm against callers.
        if self.norm:
            x, lb, ub = data_normalize(x)
            t = t / (ub - lb)
        inp = x
        blocks = [x]
        x = self.head(x, t)
        last_down = len(self.down_path) - 1
        for i, down in enumerate(self.down_path):
            x = F.max_pool2d(x, 2)
            if i != last_down:
                blocks.append(x)
            x = down(x, t)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1], t)

        out = self.last(x, t)
        out = self.out(out)
        if self.res:
            # residual w.r.t. the (normalised) input, not the decoder features in x
            out = out + inp
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out


class GLR(nn.Module):
    """Conv2d whose output is scaled by ``gamma(t)`` and shifted by ``beta(gamma(t))``, then LeakyReLU."""

    def __init__(self, in_size, out_size, ksize=3, slope=0.1):
        super(GLR, self).__init__()
        self.block = nn.Conv2d(in_size, out_size, kernel_size=ksize, padding=ksize // 2, bias=True)
        self.act = nn.LeakyReLU(slope, inplace=False)
        self.gamma = nn.Sequential(
            conv1x1(1, out_size),
            nn.SiLU(),
            conv1x1(out_size, out_size),
        )
        self.beta = nn.Sequential(
            nn.SiLU(),
            conv1x1(out_size, out_size),
        )

    def forward(self, x, t):
        z = self.block(x)
        tk = self.gamma(t)
        tb = self.beta(tk)
        z = z * tk + tb
        return self.act(z)


class GRes(nn.Module):
    """Residual block: optional 1x1 shortcut, LR then guided GLR, identity add."""

    def __init__(self, in_size, out_size, slope=0.1, ksize=3):
        super(GRes, self).__init__()
        # slope is now forwarded (it was previously ignored, always using 0.1)
        self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
        self.conv_2 = GLR(out_size, out_size, ksize=ksize, slope=slope)
        if in_size != out_size:
            self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
        else:
            self.short_cut = nn.Sequential()  # identity

    def forward(self, x, t):
        x = self.short_cut(x)
        z = self.conv_1(x)
        z = self.conv_2(z, t)
        z += x
        return z


class GUP(nn.Module):
    """2x nearest upsample, concat skip, 1x1 shortcut if needed, LR + guided GLR + identity."""

    def __init__(self, in_size, out_size, slope=0.1):
        super(GUP, self).__init__()
        self.conv_1 = LR(out_size, out_size, slope=slope)
        self.conv_2 = GLR(out_size, out_size, slope=slope)
        if in_size != out_size:
            self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
        else:
            self.short_cut = nn.Sequential()  # identity

    def up(self, x):
        return _upsample_nearest2x(x)

    def forward(self, x, pool, t):
        x = self.up(x)
        x = torch.cat([x, pool], 1)
        x = self.short_cut(x)
        z = self.conv_1(x)
        z = self.conv_2(z, t)
        z += x
        return z


class N2NF_Unet(nn.Module):
    """Noise2Noise-style U-Net: five max-pool stages, transpose-conv decoder, LeakyReLU output.

    Input H and W should be multiples of 32 so the skip concatenations line up.
    ``args`` keys: 'in_nc', 'out_nc' (required); 'norm' (default False).
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        in_nc = args['in_nc']
        out_nc = args['out_nc']
        self.norm = args['norm'] if 'norm' in args else False

        # Layers: enc_conv0, enc_conv1, pool1
        self._block1 = nn.Sequential(
            nn.Conv2d(in_nc, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv(i), pool(i); i=2..5 (one module reused four times)
        self._block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv6, upsample5
        self._block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))

        # Layers: dec_conv5a, dec_conv5b, upsample4
        self._block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))

        # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2 (reused three times)
        self._block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))

        # Layers: dec_conv1a, dec_conv1b, dec_conv1c
        self._block6 = nn.Sequential(
            nn.Conv2d(96 + in_nc, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_nc, 3, stride=1, padding=1),
            nn.LeakyReLU(0.1))

        self._init_weights()

    def _init_weights(self):
        """Initializes weights using He et al. (2015)."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        if self.norm:
            x, lb, ub = data_normalize(x)

        # Encoder
        pool1 = self._block1(x)
        pool2 = self._block2(pool1)
        pool3 = self._block2(pool2)
        pool4 = self._block2(pool3)
        pool5 = self._block2(pool4)

        # Decoder
        upsample5 = self._block3(pool5)
        concat5 = torch.cat((upsample5, pool4), dim=1)
        upsample4 = self._block4(concat5)
        concat4 = torch.cat((upsample4, pool3), dim=1)
        upsample3 = self._block5(concat4)
        concat3 = torch.cat((upsample3, pool2), dim=1)
        upsample2 = self._block5(concat3)
        concat2 = torch.cat((upsample2, pool1), dim=1)
        upsample1 = self._block5(concat2)
        concat1 = torch.cat((upsample1, x), dim=1)

        out = self._block6(concat1)
        if self.norm:
            out = data_inv_normalize(out, lb, ub)
        return out