# YOND/archs/modules.py
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
# Wavelet decomposition / reconstruction utilities
from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT)
class WaveletDecompose(nn.Module):
def __init__(self, mode='haar'):
super().__init__()
self.xfm = DWTForward(J=1, wave=mode, mode='reflect')
def forward(self, x):
"""
将一层小波分解结果转换为通道拼接格式
Args:
x: 输入张量,形状为 (B, C, H, W)
Returns:
output: 拼接后的张量,形状为 (B, 4*C, H//2, W//2)
通道顺序: [LL, HL, LH, HH]
"""
yl, yh = self.xfm(x)
        # yl: (B, C, H//2, W//2) - LL sub-band
        # yh[0]: (B, C, 3, H//2, W//2) - high-frequency coefficients
        # Extract the three directional high-frequency sub-bands
        hl = yh[0][:, :, 0, :, :]  # HL: horizontal detail
        lh = yh[0][:, :, 1, :, :]  # LH: vertical detail
        hh = yh[0][:, :, 2, :, :]  # HH: diagonal detail
        # Concatenate along the channel dimension
output = torch.cat([yl, hl, lh, hh], dim=1)
return output
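# Usage sketch for WaveletDecompose (illustrative only; the batch size, channel count
# and spatial size below are assumptions, not taken from the original code):
#   dwt = WaveletDecompose(mode='haar')
#   x = torch.randn(2, 4, 64, 64)
#   y = dwt(x)   # -> (2, 16, 32, 32), channels ordered [LL, HL, LH, HH]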
class WaveletReconstruct(nn.Module):
def __init__(self, mode='haar'):
super().__init__()
self.ifm = DWTInverse(wave=mode, mode='reflect')
def forward(self, x):
"""
将通道拼接的小波系数还原为原始图像
Args:
x: 输入张量,形状为 (B, 4*C, H, W)
Returns:
重构后的图像,形状为 (B, C, 2*H, 2*W)
"""
batch_size, total_channels, height, width = x.shape
channels = total_channels // 4
        # Split the channels
        yl = x[:, :channels, :, :]                # LL
        hl = x[:, channels:2*channels, :, :]      # HL
        lh = x[:, 2*channels:3*channels, :, :]    # LH
        hh = x[:, 3*channels:4*channels, :, :]    # HH
        # Reorganize into the format expected by pytorch_wavelets:
        # a yh list whose first element has shape (B, C, 3, H, W)
        yh_coeff = torch.stack([hl, lh, hh], dim=2)  # stack along dim=2
        yh = [yh_coeff]  # must be wrapped in a list
        # Run the inverse transform
reconstructed = self.ifm((yl, yh))
return reconstructed
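# Round-trip sketch (continuing the example above; shapes are assumptions): for an
# orthogonal wavelet such as 'haar' and even H, W, WaveletReconstruct should invert
# WaveletDecompose up to numerical precision:
#   idwt = WaveletReconstruct(mode='haar')
#   x_rec = idwt(y)   # -> (2, 4, 64, 64), close to x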
def make_layer(block, n_layers):
    # NOTE: the same module instance is appended n_layers times, so the resulting
    # nn.Sequential applies one set of shared weights repeatedly (see nResBlocks below).
    layers = []
    for _ in range(n_layers):
        layers.append(block)
    return nn.Sequential(*layers)
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
class Module_with_Init(nn.Module):
def __init__(self,):
super().__init__()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
if m.bias is not None:
m.bias.data.normal_(0.0, 0.02)
if isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0.0, 0.02)
def lrelu(self, x):
outt = torch.max(0.2*x, x)
return outt
class ResConvBlock_CBAM(nn.Module):
def __init__(self, in_nc, nf=64, res_scale=1):
super().__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.cbam = CBAM(nf)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.conv1(x))
out = self.res_scale * self.cbam(self.relu(self.conv2(x))) + x
return x + out * self.res_scale
class ResidualBlockNoBN(nn.Module):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
nf (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
"""
def __init__(self, nf=64, res_scale=1):
super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.relu = nn.ReLU()
def forward(self, x):
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
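# Shape-preserving usage sketch for ResidualBlockNoBN (the feature size is an assumption):
#   block = ResidualBlockNoBN(nf=64, res_scale=0.2)
#   feat = torch.randn(1, 64, 32, 32)
#   out = block(feat)   # -> (1, 64, 32, 32); out = feat + 0.2 * conv2(relu(conv1(feat)))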
def conv1x1(in_nc, out_nc, groups=1):
    return nn.Conv2d(in_nc, out_nc, kernel_size=1, groups=groups, stride=1)
class Identity(nn.Identity):
def __init__(self, args):
super().__init__()
class ResidualBlock3D(nn.Module):
def __init__(self, in_c, out_c, is_activate=True):
super().__init__()
self.activation = nn.ReLU(inplace=True) if is_activate else nn.Sequential()
self.block = nn.Sequential(
nn.Conv3d(in_c, out_c, kernel_size=3, padding=1, stride=1),
self.activation,
nn.Conv3d(out_c, out_c, kernel_size=3, padding=1, stride=1)
)
if in_c != out_c:
self.short_cut = nn.Sequential(
nn.Conv3d(in_c, out_c, kernel_size=1, padding=0, stride=1)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
def forward(self, x):
output = self.block(x)
output += self.short_cut(x)
output = self.activation(output)
return output
class conv3x3(nn.Module):
    def __init__(self, in_nc, out_nc, stride=2, is_activate=True):
        super().__init__()
        self.conv = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1, stride=stride)
        # The activation must be applied explicitly in forward(); attaching it to the
        # Conv2d via add_module registers it but never runs it.
        self.relu = nn.ReLU(inplace=True) if is_activate else nn.Identity()
    def forward(self, x):
        return self.relu(self.conv(x))
class convWithBN(nn.Module):
def __init__(self, in_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=True, is_bn=True):
super(convWithBN, self).__init__()
self.conv = nn.Sequential(OrderedDict([
("conv", nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding,
stride=stride, bias=False)),
]))
if is_bn:
self.conv.add_module("BN", nn.BatchNorm2d(out_c))
if is_activate:
self.conv.add_module("relu", nn.ReLU(inplace=True))
def forward(self, x):
return self.conv(x)
class DoubleCvBlock(nn.Module):
def __init__(self, in_c, out_c):
super(DoubleCvBlock, self).__init__()
self.block = nn.Sequential(
convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False),
convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False)
)
def forward(self, x):
output = self.block(x)
return output
class nResBlocks(nn.Module):
def __init__(self, nf, nlayers=2):
super().__init__()
self.blocks = make_layer(ResidualBlock(nf, nf), n_layers=nlayers)
def forward(self, x):
return self.blocks(x)
class GuidedResidualBlock(nn.Module):
def __init__(self, in_c, out_c, is_activate=False):
super().__init__()
# self.norm = nn.LayerNorm(out_c)
self.act = nn.SiLU()
self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.gamma = nn.Sequential(
conv1x1(1, out_c),
nn.SiLU(),
conv1x1(out_c, out_c),
)
self.beta = nn.Sequential(
nn.SiLU(),
conv1x1(out_c, out_c),
)
if in_c != out_c:
self.short_cut = nn.Sequential(
conv1x1(in_c, out_c)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
def forward(self, x, t):
if len(t.shape) > 0 and t.shape[-1] != 1:
t = F.interpolate(t, size=x.shape[2:], mode='bilinear', align_corners=False)
x = self.short_cut(x)
z = self.act(x)
z = self.conv1(z)
tk = self.gamma(t)
tb = self.beta(tk)
z = z * tk + tb
z = self.act(z)
z = self.conv2(z)
z += x
return z
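# Guidance convention sketch (an assumption inferred from the 1-channel conv1x1 in
# gamma, not documented in the original): t is expected to be a single-channel map of
# shape (B, 1, h, w), e.g. an SNR or guidance map. GuidedResidualBlock bilinearly
# resizes it to x's spatial size when needed and then modulates the features
# FiLM-style (z * tk + tb). GuidedConvBlock and SNR_Block below use the same (x, t)
# calling convention.
#   block = GuidedResidualBlock(32, 64)
#   x = torch.randn(1, 32, 64, 64)
#   t = torch.rand(1, 1, 16, 16)
#   y = block(x, t)   # -> (1, 64, 64, 64)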
class GuidedConvBlock(nn.Module):
def __init__(self, in_c, out_c, is_activate=False):
super().__init__()
# self.norm = nn.LayerNorm(out_c)
self.act = nn.SiLU()
self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.gamma = nn.Sequential(
conv1x1(1, out_c),
nn.SiLU(),
conv1x1(out_c, out_c),
)
self.beta = nn.Sequential(
nn.SiLU(),
conv1x1(out_c, out_c),
)
if in_c != out_c:
self.short_cut = nn.Sequential(
conv1x1(in_c, out_c)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
def forward(self, x, t):
x = self.short_cut(x)
z = self.act(x)
z = self.conv1(z)
tk = self.gamma(t)
tb = self.beta(tk)
z = z * tk + tb
z = self.act(z)
z = self.conv2(z)
return z
class SNR_Block(nn.Module):
def __init__(self, in_c, out_c, is_activate=False):
super().__init__()
# self.norm = nn.LayerNorm(out_c)
self.act = nn.SiLU()
self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.sfm1 = nn.Sequential(
conv1x1(1, out_c),
nn.SiLU(),
conv1x1(out_c, out_c),
)
self.sfm2 = nn.Sequential(
conv1x1(1, out_c),
nn.SiLU(),
conv1x1(out_c, out_c),
)
if in_c != out_c:
self.short_cut = nn.Sequential(
conv1x1(in_c, out_c)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
def forward(self, x, t):
x = self.short_cut(x)
z = self.act(x)
z = self.conv1(z)
a1 = self.sfm1(t)
z *= a1
z = self.act(z)
z = self.conv2(z)
a2 = self.sfm2(t)
z *= a2
z += x
return z
class ResBlock(nn.Module):
def __init__(self, in_c, out_c, is_activate=False):
super().__init__()
# self.norm = nn.LayerNorm(out_c)
self.act = nn.LeakyReLU(0.2) if is_activate else nn.SiLU()
self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        # NOTE: gamma and beta are defined here but not used in forward() below.
        self.gamma = nn.Sequential(
conv1x1(1, out_c),
self.act,
conv1x1(out_c, out_c),
)
self.beta = nn.Sequential(
self.act,
conv1x1(out_c, out_c),
)
if in_c != out_c:
self.short_cut = nn.Sequential(
conv1x1(in_c, out_c)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
def forward(self, x):
x = self.short_cut(x)
z = self.act(x)
z = self.conv1(z)
z = self.act(z)
z = self.conv2(z)
z += x
return z
class ResidualBlock(nn.Module):
def __init__(self, in_c, out_c, is_activate=True):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False),
convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=False, is_bn=False)
)
if in_c != out_c:
self.short_cut = nn.Sequential(
convWithBN(in_c, out_c, kernel_size=1, padding=0, stride=1, is_activate=False, is_bn=False)
)
else:
self.short_cut = nn.Sequential(OrderedDict([]))
self.activation = nn.LeakyReLU(0.2, inplace=False) if is_activate else nn.Sequential()
def forward(self, x):
output = self.block(x)
output = self.activation(output)
output += self.short_cut(x)
return output
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super().__init__()
self.in_nc = in_planes
self.ratio = ratio
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.sharedMLP = nn.Sequential(
nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avgout = self.sharedMLP(self.avg_pool(x))
maxout = self.sharedMLP(self.max_pool(x))
return self.sigmoid(avgout + maxout)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
self.concat = Concat()
self.mean = torch.mean
self.max = torch.max
def forward(self, x):
avgout = self.mean(x, 1, True)
maxout, _ = self.max(x, 1, True)
x = self.concat([avgout, maxout], 1)
x = self.conv(x)
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, planes):
super().__init__()
self.ca = ChannelAttention(planes)
self.sa = SpatialAttention()
def forward(self, x):
x = self.ca(x) * x
out = self.sa(x) * x
return out
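# CBAM applies channel attention followed by spatial attention and preserves the input
# shape. Quick sketch (the feature size is an assumption):
#   attn = CBAM(planes=64)
#   feat = torch.randn(1, 64, 32, 32)
#   out = attn(feat)   # -> (1, 64, 32, 32), reweighted by the two attention maps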
class MaskMul(nn.Module):
def __init__(self, scale_factor=1):
super().__init__()
self.scale_factor = scale_factor
def forward(self, x, mask):
if mask.shape[1] != x.shape[1]:
mask = torch.mean(mask, dim=1, keepdim=True)
pooled_mask = F.avg_pool2d(mask, self.scale_factor)
out = torch.mul(x, pooled_mask)
return out
class UpsampleBLock(nn.Module):
def __init__(self, in_channels, out_channels=None, up_scale=2, mode='bilinear'):
super(UpsampleBLock, self).__init__()
if mode == 'pixel_shuffle':
self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
self.up = nn.PixelShuffle(up_scale)
elif mode=='bilinear':
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.up = nn.UpsamplingBilinear2d(scale_factor=up_scale)
        else:
            raise NotImplementedError(f"Unknown upsample mode: '{mode}'")
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.up(x)
x = self.relu(x)
return x
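# Both modes enlarge the spatial size by up_scale; 'pixel_shuffle' keeps in_channels
# (out_channels is ignored), while 'bilinear' maps in_channels to out_channels.
# Sketch (sizes are assumptions):
#   up_ps = UpsampleBLock(64, up_scale=2, mode='pixel_shuffle')
#   up_bi = UpsampleBLock(64, 32, up_scale=2, mode='bilinear')
#   feat = torch.randn(1, 64, 16, 16)
#   up_ps(feat).shape   # -> (1, 64, 32, 32)
#   up_bi(feat).shape   # -> (1, 32, 32, 32)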
def pixel_unshuffle(input, downscale_factor):
    '''
    input: tensor of shape (batchSize, c, k*h, k*w)
    downscale_factor: k
    (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
    '''
c = input.shape[1]
kernel = torch.zeros(size=[downscale_factor * downscale_factor * c,
1, downscale_factor, downscale_factor],
device=input.device)
for y in range(downscale_factor):
for x in range(downscale_factor):
kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1
return F.conv2d(input, kernel, stride=downscale_factor, groups=c)
class PixelUnshuffle(nn.Module):
def __init__(self, downscale_factor):
super(PixelUnshuffle, self).__init__()
self.downscale_factor = downscale_factor
def forward(self, input):
        '''
        input: tensor of shape (batchSize, c, k*h, k*w)
        downscale_factor: k
        (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
        '''
return pixel_unshuffle(input, self.downscale_factor)
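# pixel_unshuffle / PixelUnshuffle invert nn.PixelShuffle by packing each
# downscale_factor x downscale_factor neighborhood into channels, using a grouped
# convolution with a fixed 0/1 kernel. Sketch (sizes are assumptions; recent PyTorch
# versions also provide an nn.PixelUnshuffle module):
#   unshuffle = PixelUnshuffle(downscale_factor=2)
#   x = torch.randn(1, 3, 8, 8)
#   y = unshuffle(x)                 # -> (1, 12, 4, 4)
#   x_back = F.pixel_shuffle(y, 2)   # -> (1, 3, 8, 8), equal to x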
class Concat(nn.Module):
def __init__(self, dim=1):
super().__init__()
        self.dim = dim
self.concat = torch.cat
def padding(self, tensors):
if len(tensors) > 2:
return tensors
x , y = tensors
xb, xc, xh, xw = x.size()
yb, yc, yh, yw = y.size()
diffY = xh - yh
diffX = xw - yw
y = F.pad(y, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
return (x, y)
def forward(self, x, dim=None):
x = self.padding(x)
return self.concat(x, dim if dim is not None else self.dim)
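# Concat center-pads the second tensor to the first tensor's spatial size before
# concatenating, which is convenient for U-Net-style skip connections with odd sizes.
# Sketch (sizes are assumptions):
#   cat = Concat(dim=1)
#   a = torch.randn(1, 16, 33, 33)
#   b = torch.randn(1, 16, 32, 32)
#   cat([a, b]).shape   # -> (1, 32, 33, 33)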
# if __name__ == '__main__':
# from torchsummary import summary
# x = torch.randn((1,32,16,16))
# for k in range(1,3):
# # up = upsample(32, 2**k)
# # down = downsample(32//(2**k), 2**k)
# # x_up = up(x)
# # x_down = down(x_up)
# # s_up = (32,16,16)
# # summary(up,s,device='cpu')
# # summary(down,s,device='cpu')
# print(k)