# YOND/archs/comp.py
from .modules import *  # expected to provide data_normalize / data_inv_normalize used below
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class DnCNN(nn.Module):
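    """DnCNN-style plain convolutional denoiser.

    Reads in_nc, out_nc, nf, depth, use_bn and res from the args dict;
    see the usage sketch after the class definition.
    """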
def __init__(self, args=None):
super().__init__()
self.args = args
self.res = args['res']
        self.raw2rgb = args['in_nc'] == 4 and args['out_nc'] == 3
nf = args['nf']
in_nc = args['in_nc']
out_nc = args['out_nc']
depth = args['depth']
use_bn = args['use_bn']
layers = []
layers.append(nn.Conv2d(in_channels=in_nc, out_channels=nf, kernel_size=3, padding=1, bias=True))
layers.append(nn.ReLU(inplace=True))
for _ in range(depth-2):
layers.append(nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, padding=1, bias=False))
if use_bn:
layers.append(nn.BatchNorm2d(nf, eps=0.0001, momentum=0.95))
layers.append(nn.ReLU(inplace=True))
        # NOTE: pixel_shuffle(out, 2) in forward() needs a channel count divisible
        # by 4, so predict out_nc * 4 channels when raw2rgb is set (assumed intent
        # for the in_nc==4, out_nc==3 configuration).
        final_nc = out_nc * 4 if self.raw2rgb else out_nc
        layers.append(nn.Conv2d(in_channels=nf, out_channels=final_nc, kernel_size=3, padding=1, bias=False))
self.dncnn = nn.Sequential(*layers)
def forward(self, x):
out = self.dncnn(x)
        if self.raw2rgb:
            # unpack the packed-raw prediction to a 2x-resolution RGB output
            out = F.pixel_shuffle(out, 2)
        elif self.res:
            out = x - out  # the network predicts the noise residual
return out
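# Usage sketch (hypothetical config; keys match the args read in __init__):
#   net = DnCNN({'in_nc': 1, 'out_nc': 1, 'nf': 64, 'depth': 17,
#                'use_bn': True, 'res': True})
#   denoised = net(torch.randn(1, 1, 64, 64))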
def conv33(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
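    """3x3 convolution wrapper (kernel_size fixed to 3)."""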
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
if mode == 'transpose':
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
# out_channels is always going to be the same
# as in_channels
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
class DownConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 MaxPool.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels, pooling=True):
super(DownConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.pooling = pooling
self.conv1 = conv33(self.in_channels, self.out_channels)
self.conv2 = conv33(self.out_channels, self.out_channels)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
before_pool = x
if self.pooling:
x = self.pool(x)
return x, before_pool
class UpConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 UpConvolution.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels,
merge_mode='concat', up_mode='transpose'):
super(UpConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.merge_mode = merge_mode
self.up_mode = up_mode
self.upconv = upconv2x2(self.in_channels, self.out_channels,
mode=self.up_mode)
if self.merge_mode == 'concat':
self.conv1 = conv33(
2*self.out_channels, self.out_channels)
else:
            # additive merging keeps the channel count unchanged
            self.conv1 = conv33(self.out_channels, self.out_channels)
self.conv2 = conv33(self.out_channels, self.out_channels)
def forward(self, from_down, from_up):
""" Forward pass
Arguments:
from_down: tensor from the encoder pathway
from_up: upconv'd tensor from the decoder pathway
"""
from_up = self.upconv(from_up)
if self.merge_mode == 'concat':
x = torch.cat((from_up, from_down), 1)
else:
x = from_up + from_down
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
class est_UNet(nn.Module):
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding,
expansive pathway) about an input tensor is merged with
information representing the localization of details
(from the encoding, compressive pathway).
Modifications to the original paper:
(1) padding is used in 3x3 convolutions to prevent loss
of border pixels
(2) merging outputs does not require cropping due to (1)
(3) residual connections can be used by specifying
UNet(merge_mode='add')
(4) if non-parametric upsampling is used in the decoder
pathway (specified by upmode='upsample'), then an
additional 1x1 2d convolution occurs after upsampling
to reduce channel dimensionality by a factor of 2.
    With upmode='transpose', this channel halving is instead performed
    by the transpose convolution itself.
"""
def __init__(self, args):
"""
Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for nearest neighbour
upsampling.
"""
super(est_UNet, self).__init__()
num_classes = args['out_nc']
in_channels = args['in_nc']
depth = args['depth']
start_filts = args['nf']
        # fixed settings for this estimator
        up_mode = 'transpose'
        merge_mode = 'add'
        use_type = 'optimize_gat'
        self.use_type = use_type
if up_mode in ('transpose', 'upsample'):
self.up_mode = up_mode
else:
raise ValueError("\"{}\" is not a valid mode for "
"upsampling. Only \"transpose\" and "
"\"upsample\" are allowed.".format(up_mode))
if merge_mode in ('concat', 'add'):
self.merge_mode = merge_mode
else:
raise ValueError("\"{}\" is not a valid mode for"
"merging up and down paths. "
"Only \"concat\" and "
"\"add\" are allowed.".format(up_mode))
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
if self.up_mode == 'upsample' and self.merge_mode == 'add':
raise ValueError("up_mode \"upsample\" is incompatible "
"with merge_mode \"add\" at the moment "
"because it doesn't make sense to use "
"nearest neighbour to reduce "
"depth channels (by half).")
self.num_classes = num_classes
self.in_channels = in_channels
self.start_filts = start_filts
self.depth = depth
self.down_convs = []
self.up_convs = []
self.noiseSTD = nn.Parameter(data=torch.log(torch.tensor(0.5)))
# create the encoder pathway and add to a list
for i in range(depth):
ins = self.in_channels if i == 0 else outs
outs = self.start_filts*(2**i)
            pooling = i < depth - 1
down_conv = DownConv(ins, outs, pooling=pooling)
self.down_convs.append(down_conv)
# create the decoder pathway and add to a list
# - careful! decoding only requires depth-1 blocks
for i in range(depth-1):
ins = outs
outs = ins // 2
up_conv = UpConv(ins, outs, up_mode=up_mode,
merge_mode=merge_mode)
self.up_convs.append(up_conv)
self.conv_final = conv1x1(outs, self.num_classes)
        self.sigmoid = nn.Sigmoid()
# add the list of modules to current module
self.down_convs = nn.ModuleList(self.down_convs)
self.up_convs = nn.ModuleList(self.up_convs)
self.reset_params()
@staticmethod
def weight_init(m):
if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def reset_params(self):
for i, m in enumerate(self.modules()):
self.weight_init(m)
def forward(self, x):
encoder_outs = []
# encoder pathway, save outputs for merging
for i, module in enumerate(self.down_convs):
x, before_pool = module(x)
encoder_outs.append(before_pool)
for i, module in enumerate(self.up_convs):
before_pool = encoder_outs[-(i+2)]
x = module(before_pool, x)
        before_x = self.conv_final(x)
        if self.use_type == 'optimize_gat':
            x = before_x
        else:
            x = before_x ** 2
        return torch.mean(x, dim=(2, 3)).squeeze()
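# Usage sketch (hypothetical config): est_UNet spatially averages the final
# 1x1-conv features into one value per output channel:
#   est = est_UNet({'in_nc': 4, 'out_nc': 2, 'depth': 3, 'nf': 32})
#   ab = est(torch.randn(8, 4, 64, 64))  # shape (8, 2)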
class New1(nn.Module):
def __init__(self, in_ch, out_ch):
super(New1, self).__init__()
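        # 3x3 kernel mask with a zero at the centre: the convolution is blind
        # to the pixel it predicts (masked/blind-spot convolution).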
self.mask = torch.from_numpy(np.array([[1,1,1],[1,0,1],[1,1,1]], dtype=np.float32)).cuda()
self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, padding = 1, kernel_size = 3)
def forward(self, x):
        # re-apply the mask so that masked weights stay zero after updates
        self.conv1.weight.data = self.conv1.weight.data * self.mask
x = self.conv1(x)
return x
class New2(nn.Module):
def __init__(self, in_ch, out_ch):
super(New2, self).__init__()
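        # sparse 5x5 mask: the centre tap plus eight knight-move offsets,
        # widening the receptive field without a dense 5x5 kernel.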
self.mask = torch.from_numpy(np.array([[0,1,0,1,0],[1,0,0,0,1],[0,0,1,0,0],[1,0,0,0,1],[0,1,0,1,0]], dtype=np.float32)).cuda()
self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, padding = 2, kernel_size = 5)
def forward(self, x):
        # re-apply the mask so that masked weights stay zero after updates
        self.conv1.weight.data = self.conv1.weight.data * self.mask
x = self.conv1(x)
return x
class New3(nn.Module):
def __init__(self, in_ch, out_ch, dilated_value):
super(New3, self).__init__()
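        # dilated 3x3 mask keeping the centre and the four diagonal taps;
        # with dilation d the diagonal taps sit d pixels away.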
self.mask = torch.from_numpy(np.array([[1,0,1],[0,1,0],[1,0,1]], dtype=np.float32)).cuda()
self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size = 3, padding=dilated_value, dilation=dilated_value)
def forward(self, x):
        # re-apply the mask so that masked weights stay zero after updates
        self.conv1.weight.data = self.conv1.weight.data * self.mask
x = self.conv1(x)
return x
class Residual_module(nn.Module):
def __init__(self, in_ch, mul = 1):
super(Residual_module, self).__init__()
self.activation1 = nn.PReLU(in_ch*mul,0).cuda()
self.activation2 = nn.PReLU(in_ch,0).cuda()
self.conv1_1by1 = nn.Conv2d(in_channels=in_ch, out_channels=in_ch*mul, kernel_size = 1)
self.conv2_1by1 = nn.Conv2d(in_channels=in_ch*mul, out_channels=in_ch, kernel_size = 1)
def forward(self, input):
output_residual = self.conv1_1by1(input)
output_residual = self.activation1(output_residual)
output_residual = self.conv2_1by1(output_residual)
output = (input + output_residual) / 2.
output = self.activation2(output)
return output
class Gaussian(nn.Module):
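    """Gaussian activation: exp(-x*x)."""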
def forward(self,input):
return torch.exp(-torch.mul(input,input))
class Receptive_attention(nn.Module):
def __init__(self, in_ch, at_type = 'softmax'):
super(Receptive_attention, self).__init__()
self.activation1 = nn.ReLU().cuda()
self.activation2 = nn.ReLU().cuda()
self.activation3 = nn.PReLU(in_ch,0).cuda()
self.conv1_1by1 = nn.Conv2d(in_channels=in_ch, out_channels=in_ch*4, kernel_size = 1)
self.conv2_1by1 = nn.Conv2d(in_channels=in_ch*4, out_channels=in_ch*4, kernel_size = 1)
self.conv3_1by1 = nn.Conv2d(in_channels=in_ch*4, out_channels=9, kernel_size = 1)
self.at_type = at_type
if at_type == 'softmax':
            self.softmax = nn.Softmax(dim=1)  # normalise the 9 receptive-field weights
else:
self.gaussian = Gaussian()
self.sigmoid = nn.Sigmoid()
    def forward(self, input, receptive):
        # predict one scalar weight per receptive-field branch (9 in total)
        output_residual = self.conv1_1by1(input)
        output_residual = self.activation1(output_residual)
        output_residual = self.conv2_1by1(output_residual)
        output_residual = self.activation2(output_residual)
        output_residual = self.conv3_1by1(output_residual)
        output_residual = F.adaptive_avg_pool2d(output_residual, (1, 1))
        if self.at_type == 'softmax':
            output_residual = self.softmax(output_residual).permute((1, 0, 2, 3)).unsqueeze(-1)
        else:
            output_residual = self.gaussian(output_residual)
            output_residual = self.sigmoid(output_residual).permute((1, 0, 2, 3)).unsqueeze(-1)
        # weighted sum over the stacked receptive-field outputs
        output = torch.sum(receptive * output_residual, dim=0)
        output = self.activation3(output)
        return output
class New1_layer(nn.Module):
def __init__(self, in_ch, out_ch, case = 'FBI_Net', mul = 1):
super(New1_layer, self).__init__()
self.case = case
self.new1 = New1(in_ch,out_ch).cuda()
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
self.residual_module = Residual_module(out_ch, mul)
        # note: this PReLU is applied to new1's out_ch-channel output, so in_ch
        # must be 1 (single shared parameter) or equal to out_ch
        self.activation_new1 = nn.PReLU(in_ch, 0).cuda()
def forward(self, x):
        if self.case in ('case1', 'case2', 'case7', 'FBI_Net'):  # variants with the residual module
output_new1 = self.new1(x)
output_new1 = self.activation_new1(output_new1)
output = self.residual_module(output_new1)
return output, output_new1
        else:  # variants without the residual module
output_new1 = self.new1(x)
output = self.activation_new1(output_new1)
return output, output_new1
class New2_layer(nn.Module):
def __init__(self, in_ch, out_ch, case = 'FBI_Net', mul = 1):
super(New2_layer, self).__init__()
self.case = case
self.new2 = New2(in_ch,out_ch).cuda()
self.activation_new1 = nn.PReLU(in_ch,0).cuda()
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
            self.residual_module = Residual_module(out_ch, mul)
        if case in ('case1', 'case3', 'case6', 'FBI_Net'):
            self.activation_new2 = nn.PReLU(in_ch, 0).cuda()
    def forward(self, x, output_new):
        if self.case == 'case1':
            output_new2 = self.new2(output_new)
            output_new2 = self.activation_new1(output_new2)
            output = (output_new2 + x) / 2.
            output = self.activation_new2(output)
            output = self.residual_module(output)
        elif self.case in ('case2', 'case7'):
            output_new2 = self.new2(x)
            output_new2 = self.activation_new1(output_new2)
            output = self.residual_module(output_new2)
        elif self.case in ('case3', 'case6'):
            output_new2 = self.new2(output_new)
            output_new2 = self.activation_new1(output_new2)
            output = (output_new2 + x) / 2.
            output = self.activation_new2(output)
        elif self.case in ('case4', 'case5'):
            output_new2 = self.new2(x)
            output_new2 = self.activation_new1(output_new2)
            output = output_new2
        else:  # 'FBI_Net'
            output_new2 = self.new2(output_new)
            output_new2 = self.activation_new1(output_new2)
            output = (output_new2 + x) / 2.
            output = self.activation_new2(output)
            output = self.residual_module(output)
        return output, output_new2
class New3_layer(nn.Module):
def __init__(self, in_ch, out_ch, dilated_value=3, case = 'FBI_Net', mul = 1):
super(New3_layer, self).__init__()
self.case = case
self.new3 = New3(in_ch,out_ch,dilated_value).cuda()
self.activation_new1 = nn.PReLU(in_ch,0).cuda()
        if case in ('case1', 'case2', 'case7', 'FBI_Net'):
            self.residual_module = Residual_module(out_ch, mul)
        if case in ('case1', 'case3', 'case6', 'FBI_Net'):
            self.activation_new2 = nn.PReLU(in_ch, 0).cuda()
    def forward(self, x, output_new):
        if self.case == 'case1':
            output_new3 = self.new3(output_new)
            output_new3 = self.activation_new1(output_new3)
            output = (output_new3 + x) / 2.
            output = self.activation_new2(output)
            output = self.residual_module(output)
        elif self.case in ('case2', 'case7'):
            output_new3 = self.new3(x)
            output_new3 = self.activation_new1(output_new3)
            output = self.residual_module(output_new3)
        elif self.case in ('case3', 'case6'):
            output_new3 = self.new3(output_new)
            output_new3 = self.activation_new1(output_new3)
            output = (output_new3 + x) / 2.
            output = self.activation_new2(output)
        elif self.case in ('case4', 'case5'):
            output_new3 = self.new3(x)
            output_new3 = self.activation_new1(output_new3)
            output = output_new3
        else:  # 'FBI_Net'
            output_new3 = self.new3(output_new)
            output_new3 = self.activation_new1(output_new3)
            output = (output_new3 + x) / 2.
            output = self.activation_new2(output)
            output = self.residual_module(output)
        return output, output_new3
class AttrProxy(object):
"""Translates index lookups into attribute lookups."""
def __init__(self, module, prefix):
self.module = module
self.prefix = prefix
def __getitem__(self, i):
return getattr(self.module, self.prefix + str(i))
class FBI_Net(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
channel = args['channel']
output_channel = args['output_channel']
filters = args['nf']
mul = args['mul']
num_of_layers = args['num_of_layers']
case = args['case']
output_type = args['output_type']
sigmoid_value = args['sigmoid_value']
self.res = args['res']
self.case = case
self.new1 = New1_layer(channel, filters, mul = mul, case = case).cuda()
self.new2 = New2_layer(filters, filters, mul = mul, case = case).cuda()
self.num_layers = num_of_layers
self.output_type = output_type
self.sigmoid_value = sigmoid_value
dilated_value = 3
        for layer in range(num_of_layers - 2):
self.add_module('new_' + str(layer), New3_layer(filters, filters, dilated_value, mul = mul, case = case).cuda())
self.residual_module = Residual_module(filters, mul)
self.activation = nn.PReLU(filters,0).cuda()
self.output_layer = nn.Conv2d(in_channels=filters, out_channels=output_channel, kernel_size = 1).cuda()
if self.output_type == 'sigmoid':
            self.sigmoid = nn.Sigmoid()
self.new = AttrProxy(self, 'new_')
def forward(self, x):
        if self.case in ('FBI_Net', 'case2', 'case3', 'case4'):
output, output_new = self.new1(x)
output_sum = output
output, output_new = self.new2(output, output_new)
output_sum = output + output_sum
            for i, new_layer in enumerate(self.new):
output, output_new = new_layer(output, output_new)
output_sum = output + output_sum
if i == self.num_layers - 3:
break
final_output = self.activation(output_sum/self.num_layers)
final_output = self.residual_module(final_output)
final_output = self.output_layer(final_output)
else:
output, output_new = self.new1(x)
output, output_new = self.new2(output, output_new)
            for i, new_layer in enumerate(self.new):
output, output_new = new_layer(output, output_new)
if i == self.num_layers - 3:
break
final_output = self.activation(output)
final_output = self.residual_module(final_output)
final_output = self.output_layer(final_output)
        if self.output_type == 'sigmoid':
            final_output[:, 0] = self.sigmoid_value * self.sigmoid(final_output[:, 0])
        if self.res:
            final_output = final_output[:, :1] * x + final_output[:, 1:]
return final_output
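# Usage sketch (hypothetical config; submodules are created on the GPU via
# .cuda(), so inputs must live there too):
#   net = FBI_Net({'channel': 1, 'output_channel': 2, 'nf': 64, 'mul': 1,
#                  'num_of_layers': 8, 'case': 'FBI_Net', 'output_type': 'linear',
#                  'sigmoid_value': 0.1, 'res': True})
#   out = net(torch.randn(1, 1, 64, 64).cuda())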
class SelfSupUNet(nn.Module):
def __init__(self, args):
"""
Args:
in_channels (int): number of input channels, Default 4
depth (int): depth of the network, Default 5
nf (int): number of filters in the first layer, Default 32
"""
super().__init__()
in_channels = args['in_nc']
out_channels = args['out_nc']
        depth = args.get('depth', 5)
        nf = args.get('nf', 32)
        slope = args.get('slope', 0.1)
        self.norm = args.get('norm', False)
        self.res = args.get('res', False)
self.depth = depth
self.head = nn.Sequential(
LR(in_channels, nf, 3, slope), LR(nf, nf, 3, slope))
self.down_path = nn.ModuleList()
for i in range(depth):
self.down_path.append(LR(nf, nf, 3, slope))
self.up_path = nn.ModuleList()
for i in range(depth):
if i != depth-1:
self.up_path.append(UP(nf*2 if i==0 else nf*3, nf*2, slope))
else:
self.up_path.append(UP(nf*2+in_channels, nf*2, slope))
self.last = nn.Sequential(LR(2*nf, 2*nf, 1, slope),
LR(2*nf, 2*nf, 1, slope), conv1x1(2*nf, out_channels, bias=True))
def forward(self, x):
if self.norm:
            x, lb, ub = data_normalize(x)
        inp = x  # keep the (normalized) input for the residual connection
blocks = []
blocks.append(x)
x = self.head(x)
for i, down in enumerate(self.down_path):
x = F.max_pool2d(x, 2)
if i != len(self.down_path) - 1:
blocks.append(x)
x = down(x)
for i, up in enumerate(self.up_path):
x = up(x, blocks[-i-1])
out = self.last(x)
        if self.res:
            out = out + inp  # residual against the input, not the decoder features
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
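# Usage sketch (hypothetical config; H and W must be divisible by 2**depth):
#   net = SelfSupUNet({'in_nc': 4, 'out_nc': 4, 'depth': 5, 'nf': 32})
#   out = net(torch.randn(1, 4, 64, 64))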
class LR(nn.Module):
def __init__(self, in_size, out_size, ksize=3, slope=0.1):
super(LR, self).__init__()
block = []
block.append(nn.Conv2d(in_size, out_size,
kernel_size=ksize, padding=ksize//2, bias=True))
block.append(nn.LeakyReLU(slope, inplace=False))
self.block = nn.Sequential(*block)
def forward(self, x):
out = self.block(x)
return out
class UP(nn.Module):
def __init__(self, in_size, out_size, slope=0.1):
super(UP, self).__init__()
self.conv_1 = LR(in_size, out_size)
self.conv_2 = LR(out_size, out_size)
def up(self, x):
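        # nearest-neighbour 2x upsampling via reshape/repeat, equivalent to
        # F.interpolate(x, scale_factor=2, mode='nearest')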
s = x.shape
x = x.reshape(s[0], s[1], s[2], 1, s[3], 1)
x = x.repeat(1, 1, 1, 2, 1, 2)
x = x.reshape(s[0], s[1], s[2]*2, s[3]*2)
return x
def forward(self, x, pool):
x = self.up(x)
x = torch.cat([x, pool], 1)
x = self.conv_1(x)
x = self.conv_2(x)
return x
class SelfResUNet(nn.Module):
def __init__(self, args):
"""
Args:
in_channels (int): number of input channels, Default 4
depth (int): depth of the network, Default 5
nf (int): number of filters in the first layer, Default 32
"""
super().__init__()
in_channels = args['in_nc']
out_channels = args['out_nc']
        depth = args.get('depth', 5)
        nf = args.get('nf', 32)
        slope = args.get('slope', 0.1)
        self.norm = args.get('norm', False)
        self.res = args.get('res', False)
self.depth = depth
self.head = Res(in_channels, nf, slope)
self.down_path = nn.ModuleList()
for i in range(depth):
self.down_path.append(Res(nf, nf, slope, ksize=3))
self.up_path = nn.ModuleList()
for i in range(depth):
if i != depth-1:
self.up_path.append(RUP(nf*2 if i==0 else nf*3, nf*2, slope))
else:
self.up_path.append(RUP(nf*2+in_channels, nf*2, slope))
self.last = Res(2*nf, 2*nf, slope, ksize=1)
self.out = conv1x1(2*nf, out_channels, bias=True)
def forward(self, x):
if self.norm:
x, lb, ub = data_normalize(x)
inp = x
blocks = []
blocks.append(x)
x = self.head(x)
for i, down in enumerate(self.down_path):
x = F.max_pool2d(x, 2)
if i != len(self.down_path) - 1:
blocks.append(x)
x = down(x)
for i, up in enumerate(self.up_path):
x = up(x, blocks[-i-1])
out = self.last(x)
out = self.out(out)
if self.res:
out = out + inp
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
class RUP(nn.Module):
def __init__(self, in_size, out_size, slope=0.1, ksize=3):
super(RUP, self).__init__()
self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
self.conv_2 = LR(out_size, out_size, ksize=ksize, slope=slope)
if in_size != out_size:
self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
else:
            self.short_cut = nn.Identity()
def up(self, x):
s = x.shape
x = x.reshape(s[0], s[1], s[2], 1, s[3], 1)
x = x.repeat(1, 1, 1, 2, 1, 2)
x = x.reshape(s[0], s[1], s[2]*2, s[3]*2)
return x
def forward(self, x, pool):
x = self.up(x)
x = torch.cat([x, pool], 1)
x = self.short_cut(x)
z = self.conv_1(x)
z = self.conv_2(z)
z += x
return z
class Res(nn.Module):
def __init__(self, in_size, out_size, slope=0.1, ksize=3):
super().__init__()
self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
self.conv_2 = LR(out_size, out_size, ksize=ksize, slope=slope)
if in_size != out_size:
self.short_cut = nn.Sequential(conv1x1(in_size, out_size))
else:
            self.short_cut = nn.Identity()
def forward(self, x):
x = self.short_cut(x)
z = self.conv_1(x)
z = self.conv_2(z)
z += x
return z
def conv1x1(in_chn, out_chn, bias=True):
layer = nn.Conv2d(in_chn, out_chn, kernel_size=1,
stride=1, padding=0, bias=bias)
return layer
class GuidedSelfUnet(nn.Module):
def __init__(self, args):
"""
Args:
in_channels (int): number of input channels, Default 4
depth (int): depth of the network, Default 5
nf (int): number of filters in the first layer, Default 32
"""
super().__init__()
in_channels = args['in_nc']
out_channels = args['out_nc']
        depth = args.get('depth', 5)
        nf = args.get('nf', 32)
        slope = args.get('slope', 0.1)
        self.norm = args.get('norm', False)
        self.res = args.get('res', False)
self.depth = depth
self.head = GRes(in_channels, nf, slope)
self.down_path = nn.ModuleList()
for i in range(depth):
self.down_path.append(GLR(nf, nf, 3, slope))
self.up_path = nn.ModuleList()
for i in range(depth):
if i != depth-1:
self.up_path.append(GUP(nf*2 if i==0 else nf*3, nf*2, slope))
else:
self.up_path.append(GUP(nf*2+in_channels, nf*2, slope))
self.last = GRes(2*nf, 2*nf, slope, ksize=1)
self.out = conv1x1(2*nf, out_channels, bias=True)
def forward(self, x, t):
if self.norm:
x, lb, ub = data_normalize(x)
            t = t / (ub - lb)
        inp = x  # keep the (normalized) input for the residual connection
blocks = []
blocks.append(x)
x = self.head(x, t)
for i, down in enumerate(self.down_path):
x = F.max_pool2d(x, 2)
if i != len(self.down_path) - 1:
blocks.append(x)
x = down(x, t)
for i, up in enumerate(self.up_path):
x = up(x, blocks[-i-1], t)
out = self.last(x, t)
out = self.out(out)
if self.res:
            out = out + inp  # residual against the input, not the decoder features
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
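# Usage sketch (hypothetical): t is a per-image conditioning scalar (e.g. an
# estimated noise level), shaped (N, 1, 1, 1) so it broadcasts in every GLR:
#   net = GuidedSelfUnet({'in_nc': 4, 'out_nc': 4})
#   out = net(torch.randn(1, 4, 64, 64), torch.full((1, 1, 1, 1), 0.05))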
class GLR(nn.Module):
def __init__(self, in_size, out_size, ksize=3, slope=0.1):
super(GLR, self).__init__()
self.block = nn.Conv2d(in_size, out_size,
kernel_size=ksize, padding=ksize//2, bias=True)
self.act = nn.LeakyReLU(slope, inplace=False)
self.gamma = nn.Sequential(
conv1x1(1, out_size),
nn.SiLU(),
conv1x1(out_size, out_size),
)
self.beta = nn.Sequential(
nn.SiLU(),
conv1x1(out_size, out_size),
)
def forward(self, x, t):
z = self.block(x)
tk = self.gamma(t)
tb = self.beta(tk)
z = z * tk + tb
out = self.act(z)
return out
class GRes(nn.Module):
def __init__(self, in_size, out_size, slope=0.1, ksize=3):
super(GRes, self).__init__()
        self.conv_1 = LR(out_size, out_size, ksize=ksize, slope=slope)
        self.conv_2 = GLR(out_size, out_size, ksize=ksize, slope=slope)
if in_size != out_size:
self.short_cut = nn.Sequential(
conv1x1(in_size, out_size)
)
else:
            self.short_cut = nn.Identity()
def forward(self, x, t):
x = self.short_cut(x)
z = self.conv_1(x)
z = self.conv_2(z, t)
z += x
return z
class GUP(nn.Module):
def __init__(self, in_size, out_size, slope=0.1):
super(GUP, self).__init__()
        self.conv_1 = LR(out_size, out_size, slope=slope)
        self.conv_2 = GLR(out_size, out_size, slope=slope)
if in_size != out_size:
self.short_cut = nn.Sequential(
conv1x1(in_size, out_size)
)
else:
            self.short_cut = nn.Identity()
def up(self, x):
s = x.shape
x = x.reshape(s[0], s[1], s[2], 1, s[3], 1)
x = x.repeat(1, 1, 1, 2, 1, 2)
x = x.reshape(s[0], s[1], s[2]*2, s[3]*2)
return x
def forward(self, x, pool, t):
x = self.up(x)
x = torch.cat([x, pool], 1)
x = self.short_cut(x)
z = self.conv_1(x)
z = self.conv_2(z, t)
z += x
return z
class N2NF_Unet(nn.Module):
def __init__(self, args=None):
super().__init__()
self.args = args
in_nc = args['in_nc']
out_nc = args['out_nc']
        self.norm = args.get('norm', False)
# Layers: enc_conv0, enc_conv1, pool1
self._block1 = nn.Sequential(
nn.Conv2d(in_nc, 48, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(48, 48, 3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2))
# Layers: enc_conv(i), pool(i); i=2..5
self._block2 = nn.Sequential(
nn.Conv2d(48, 48, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2))
# Layers: enc_conv6, upsample5
self._block3 = nn.Sequential(
nn.Conv2d(48, 48, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))
#nn.Upsample(scale_factor=2, mode='nearest'))
# Layers: dec_conv5a, dec_conv5b, upsample4
self._block4 = nn.Sequential(
nn.Conv2d(96, 96, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(96, 96, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
#nn.Upsample(scale_factor=2, mode='nearest'))
# Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2
self._block5 = nn.Sequential(
nn.Conv2d(144, 96, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(96, 96, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
#nn.Upsample(scale_factor=2, mode='nearest'))
# Layers: dec_conv1a, dec_conv1b, dec_conv1c,
self._block6 = nn.Sequential(
nn.Conv2d(96 + in_nc, 64, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, 3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, out_nc, 3, stride=1, padding=1),
nn.LeakyReLU(0.1))
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initializes weights using He et al. (2015)."""
for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()
def forward(self, x):
if self.norm:
x, lb, ub = data_normalize(x)
# Encoder
pool1 = self._block1(x)
pool2 = self._block2(pool1)
pool3 = self._block2(pool2)
pool4 = self._block2(pool3)
pool5 = self._block2(pool4)
# Decoder
upsample5 = self._block3(pool5)
concat5 = torch.cat((upsample5, pool4), dim=1)
upsample4 = self._block4(concat5)
concat4 = torch.cat((upsample4, pool3), dim=1)
upsample3 = self._block5(concat4)
concat3 = torch.cat((upsample3, pool2), dim=1)
upsample2 = self._block5(concat3)
concat2 = torch.cat((upsample2, pool1), dim=1)
upsample1 = self._block5(concat2)
concat1 = torch.cat((upsample1, x), dim=1)
# Final activation
out = self._block6(concat1)
if self.norm:
out = data_inv_normalize(out, lb, ub)
return out
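# Usage sketch (hypothetical config; five 2x poolings, so H and W must be
# divisible by 32):
#   net = N2NF_Unet({'in_nc': 4, 'out_nc': 4})
#   out = net(torch.randn(1, 4, 64, 64))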