import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

# Wavelet decomposition utilities
from pytorch_wavelets import DWTForward, DWTInverse  # (or import DWT, IDWT)

class WaveletDecompose(nn.Module):
    def __init__(self, mode='haar'):
        super().__init__()
        self.xfm = DWTForward(J=1, wave=mode, mode='reflect')

    def forward(self, x):
        """
        Convert a one-level wavelet decomposition into a channel-concatenated format.
        Args:
            x: input tensor of shape (B, C, H, W)
        Returns:
            output: concatenated tensor of shape (B, 4*C, H//2, W//2)
                    channel order: [LL, HL, LH, HH]
        """
        yl, yh = self.xfm(x)
        # yl: (B, C, H//2, W//2) - LL subband
        # yh[0]: (B, C, 3, H//2, W//2) - high-frequency coefficients
        # Extract the three directional high-frequency subbands
        hl = yh[0][:, :, 0, :, :]  # HL: horizontal detail
        lh = yh[0][:, :, 1, :, :]  # LH: vertical detail
        hh = yh[0][:, :, 2, :, :]  # HH: diagonal detail
        # Concatenate along the channel dimension
        output = torch.cat([yl, hl, lh, hh], dim=1)
        return output

class WaveletReconstruct(nn.Module):
    def __init__(self, mode='haar'):
        super().__init__()
        self.ifm = DWTInverse(wave=mode, mode='reflect')

    def forward(self, x):
        """
        Reconstruct the original image from channel-concatenated wavelet coefficients.
        Args:
            x: input tensor of shape (B, 4*C, H, W)
        Returns:
            reconstructed image of shape (B, C, 2*H, 2*W)
        """
        batch_size, total_channels, height, width = x.shape
        channels = total_channels // 4
        # Split the channels back into subbands
        yl = x[:, :channels, :, :]               # LL
        hl = x[:, channels:2*channels, :, :]     # HL
        lh = x[:, 2*channels:3*channels, :, :]   # LH
        hh = x[:, 3*channels:4*channels, :, :]   # HH
        # Reassemble into the format expected by pytorch_wavelets:
        # yh is a list whose first element has shape (B, C, 3, H, W)
        yh_coeff = torch.stack([hl, lh, hh], dim=2)  # stack along dim=2
        yh = [yh_coeff]  # must be wrapped in a list
        # Run the inverse transform
        reconstructed = self.ifm((yl, yh))
        return reconstructed

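# Usage sketch (illustrative, not part of the original module): a minimal round-trip
# through WaveletDecompose and WaveletReconstruct, assuming pytorch_wavelets is
# installed and the input has even spatial dimensions. The helper name
# _wavelet_roundtrip_example is hypothetical. With matching 'haar'/'reflect'
# settings, the reconstruction error should be near zero up to boundary handling
# and floating-point effects.
def _wavelet_roundtrip_example():
    x = torch.randn(2, 3, 64, 64)        # (B, C, H, W)
    dec = WaveletDecompose(mode='haar')
    rec = WaveletReconstruct(mode='haar')
    coeffs = dec(x)                      # expected (2, 12, 32, 32): [LL, HL, LH, HH] on channels
    y = rec(coeffs)                      # expected (2, 3, 64, 64)
    print(coeffs.shape, y.shape, (x - y).abs().max())
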
def make_layer(block, n_layers):
    # Note: the same block instance is appended n_layers times,
    # so its parameters are shared across all repetitions.
    layers = []
    for _ in range(n_layers):
        layers.append(block)
    return nn.Sequential(*layers)

class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

class Module_with_Init(nn.Module):
    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)

    def lrelu(self, x):
        outt = torch.max(0.2 * x, x)
        return outt

class ResConvBlock_CBAM(nn.Module):
    def __init__(self, in_nc, nf=64, res_scale=1):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cbam = CBAM(nf)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        out = self.res_scale * self.cbam(self.relu(self.conv2(x))) + x
        return x + out * self.res_scale

class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        nf (int): Channel number of intermediate features. Default: 64.
        res_scale (float): Residual scale. Default: 1.
    """
    def __init__(self, nf=64, res_scale=1):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale

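# Illustrative sketch (an addition, not from the original file): ResidualBlockNoBN
# preserves the input's shape, and res_scale only scales the residual branch.
# The helper name below is hypothetical.
def _residual_block_nobn_example():
    block = ResidualBlockNoBN(nf=64, res_scale=0.2)
    feat = torch.randn(1, 64, 32, 32)
    print(block(feat).shape)  # torch.Size([1, 64, 32, 32])
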
def conv1x1(in_nc, out_nc, groups=1):
    return nn.Conv2d(in_nc, out_nc, kernel_size=1, groups=groups, stride=1)

class Identity(nn.Identity):
    def __init__(self, args):
        # Accepts (and ignores) a positional argument for interface compatibility.
        super().__init__()

class ResidualBlock3D(nn.Module):
    def __init__(self, in_c, out_c, is_activate=True):
        super().__init__()
        self.activation = nn.ReLU(inplace=True) if is_activate else nn.Sequential()
        self.block = nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=3, padding=1, stride=1),
            self.activation,
            nn.Conv3d(out_c, out_c, kernel_size=3, padding=1, stride=1)
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1, padding=0, stride=1)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))

    def forward(self, x):
        output = self.block(x)
        output += self.short_cut(x)
        output = self.activation(output)
        return output

class conv3x3(nn.Module):
    def __init__(self, in_nc, out_nc, stride=2, is_activate=True):
        super().__init__()
        self.conv = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1, stride=stride)
        # Keep the activation as a separate module so it is actually applied in
        # forward (adding a child module to a bare nn.Conv2d would never run it).
        self.act = nn.ReLU(inplace=True) if is_activate else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

class convWithBN(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=True, is_bn=True):
        super(convWithBN, self).__init__()
        self.conv = nn.Sequential(OrderedDict([
            ("conv", nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding,
                               stride=stride, bias=False)),
        ]))
        if is_bn:
            self.conv.add_module("BN", nn.BatchNorm2d(out_c))
        if is_activate:
            self.conv.add_module("relu", nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)

class DoubleCvBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(DoubleCvBlock, self).__init__()
        self.block = nn.Sequential(
            convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False),
            convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False)
        )

    def forward(self, x):
        output = self.block(x)
        return output

class nResBlocks(nn.Module):
    def __init__(self, nf, nlayers=2):
        super().__init__()
        self.blocks = make_layer(ResidualBlock(nf, nf), n_layers=nlayers)

    def forward(self, x):
        return self.blocks(x)

class GuidedResidualBlock(nn.Module):
    def __init__(self, in_c, out_c, is_activate=False):
        super().__init__()
        # self.norm = nn.LayerNorm(out_c)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.gamma = nn.Sequential(
            conv1x1(1, out_c),
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        self.beta = nn.Sequential(
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                conv1x1(in_c, out_c)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))

    def forward(self, x, t):
        # Resize the guidance map t to the feature map's spatial size when needed.
        if len(t.shape) > 0 and t.shape[-1] != 1:
            t = F.interpolate(t, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = self.short_cut(x)
        z = self.act(x)
        z = self.conv1(z)
        tk = self.gamma(t)
        tb = self.beta(tk)
        z = z * tk + tb
        z = self.act(z)
        z = self.conv2(z)
        z += x
        return z

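# Usage sketch (hypothetical, inferred from the layer definitions above): the guidance
# map t is expected to have a single channel, since self.gamma starts with
# conv1x1(1, out_c); forward resizes t to the feature map's spatial size.
def _guided_residual_block_example():
    block = GuidedResidualBlock(in_c=32, out_c=64)
    feat = torch.randn(1, 32, 48, 48)
    guide = torch.randn(1, 1, 96, 96)    # e.g. a full-resolution single-channel guidance map
    print(block(feat, guide).shape)      # torch.Size([1, 64, 48, 48])
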
class GuidedConvBlock(nn.Module):
    def __init__(self, in_c, out_c, is_activate=False):
        super().__init__()
        # self.norm = nn.LayerNorm(out_c)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.gamma = nn.Sequential(
            conv1x1(1, out_c),
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        self.beta = nn.Sequential(
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                conv1x1(in_c, out_c)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))

    def forward(self, x, t):
        x = self.short_cut(x)
        z = self.act(x)
        z = self.conv1(z)
        tk = self.gamma(t)
        tb = self.beta(tk)
        z = z * tk + tb
        z = self.act(z)
        z = self.conv2(z)
        return z

class SNR_Block(nn.Module):
    def __init__(self, in_c, out_c, is_activate=False):
        super().__init__()
        # self.norm = nn.LayerNorm(out_c)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.sfm1 = nn.Sequential(
            conv1x1(1, out_c),
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        self.sfm2 = nn.Sequential(
            conv1x1(1, out_c),
            nn.SiLU(),
            conv1x1(out_c, out_c),
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                conv1x1(in_c, out_c)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))

    def forward(self, x, t):
        x = self.short_cut(x)
        z = self.act(x)
        z = self.conv1(z)
        a1 = self.sfm1(t)
        z *= a1
        z = self.act(z)
        z = self.conv2(z)
        a2 = self.sfm2(t)
        z *= a2
        z += x
        return z

class ResBlock(nn.Module):
    def __init__(self, in_c, out_c, is_activate=False):
        super().__init__()
        # self.norm = nn.LayerNorm(out_c)
        self.act = nn.LeakyReLU(0.2) if is_activate else nn.SiLU()
        self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True)
        # Note: gamma and beta are defined here but not used in forward.
        self.gamma = nn.Sequential(
            conv1x1(1, out_c),
            self.act,
            conv1x1(out_c, out_c),
        )
        self.beta = nn.Sequential(
            self.act,
            conv1x1(out_c, out_c),
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                conv1x1(in_c, out_c)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))

    def forward(self, x):
        x = self.short_cut(x)
        z = self.act(x)
        z = self.conv1(z)
        z = self.act(z)
        z = self.conv2(z)
        z += x
        return z

class ResidualBlock(nn.Module):
    def __init__(self, in_c, out_c, is_activate=True):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False),
            convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=False, is_bn=False)
        )
        if in_c != out_c:
            self.short_cut = nn.Sequential(
                convWithBN(in_c, out_c, kernel_size=1, padding=0, stride=1, is_activate=False, is_bn=False)
            )
        else:
            self.short_cut = nn.Sequential(OrderedDict([]))
        self.activation = nn.LeakyReLU(0.2, inplace=False) if is_activate else nn.Sequential()

    def forward(self, x):
        output = self.block(x)
        output = self.activation(output)
        output += self.short_cut(x)
        return output

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.in_nc = in_planes
        self.ratio = ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        # padding = kernel_size // 2 keeps the spatial size for any odd kernel_size
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.concat = Concat()
        self.mean = torch.mean
        self.max = torch.max

    def forward(self, x):
        avgout = self.mean(x, 1, True)
        maxout, _ = self.max(x, 1, True)
        x = self.concat([avgout, maxout], 1)
        x = self.conv(x)
        return self.sigmoid(x)

class CBAM(nn.Module):
    def __init__(self, planes):
        super().__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x
        out = self.sa(x) * x
        return out

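# Usage sketch (not from the original file): CBAM applies channel attention followed
# by spatial attention and preserves the input shape. The helper name is hypothetical.
def _cbam_example():
    cbam = CBAM(planes=64)
    feat = torch.randn(2, 64, 32, 32)
    print(cbam(feat).shape)  # torch.Size([2, 64, 32, 32])
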
class MaskMul(nn.Module):
    def __init__(self, scale_factor=1):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x, mask):
        if mask.shape[1] != x.shape[1]:
            mask = torch.mean(mask, dim=1, keepdim=True)
        pooled_mask = F.avg_pool2d(mask, self.scale_factor)
        out = torch.mul(x, pooled_mask)
        return out

class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, out_channels=None, up_scale=2, mode='bilinear'):
        super(UpsampleBLock, self).__init__()
        if mode == 'pixel_shuffle':
            self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
            self.up = nn.PixelShuffle(up_scale)
        elif mode == 'bilinear':
            if out_channels is None:
                out_channels = in_channels
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            self.up = nn.UpsamplingBilinear2d(scale_factor=up_scale)
        else:
            raise NotImplementedError(f"Unknown upsample mode: '{mode}'")
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.up(x)
        x = self.relu(x)
        return x

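# Usage sketch (an assumption about intended usage, not from the original file):
# 'pixel_shuffle' mode keeps the channel count, while 'bilinear' mode maps
# in_channels -> out_channels; both double the spatial size for up_scale=2.
def _upsample_block_example():
    feat = torch.randn(1, 64, 16, 16)
    up_ps = UpsampleBLock(64, up_scale=2, mode='pixel_shuffle')
    up_bi = UpsampleBLock(64, out_channels=32, up_scale=2, mode='bilinear')
    print(up_ps(feat).shape)  # torch.Size([1, 64, 32, 32])
    print(up_bi(feat).shape)  # torch.Size([1, 32, 32, 32])
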
def pixel_unshuffle(input, downscale_factor):
    '''
    input: tensor of shape (batchSize, c, k*h, k*w)
    downscale_factor: k
    (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
    '''
    c = input.shape[1]
    # Build a fixed one-hot kernel and use a grouped strided convolution to
    # rearrange each k x k pixel block into k*k channels.
    kernel = torch.zeros(size=[downscale_factor * downscale_factor * c,
                               1, downscale_factor, downscale_factor],
                         device=input.device, dtype=input.dtype)
    for y in range(downscale_factor):
        for x in range(downscale_factor):
            kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1
    return F.conv2d(input, kernel, stride=downscale_factor, groups=c)


class PixelUnshuffle(nn.Module):
    def __init__(self, downscale_factor):
        super(PixelUnshuffle, self).__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        '''
        input: tensor of shape (batchSize, c, k*h, k*w)
        downscale_factor: k
        (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
        '''
        return pixel_unshuffle(input, self.downscale_factor)

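# Sanity-check sketch (an addition, not from the original file): under our reading of
# the kernel construction above, this grouped-conv pixel_unshuffle should match the
# built-in F.pixel_unshuffle (available in PyTorch >= 1.8), including channel order.
def _pixel_unshuffle_check():
    x = torch.randn(1, 3, 8, 8)
    custom = pixel_unshuffle(x, 2)      # (1, 12, 4, 4)
    builtin = F.pixel_unshuffle(x, 2)
    print(custom.shape, torch.allclose(custom, builtin, atol=1e-6))
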
class Concat(nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim
        self.concat = torch.cat

    def padding(self, tensors):
        # Only a pair of tensors is padded; longer sequences are passed through unchanged.
        if len(tensors) > 2:
            return tensors
        x, y = tensors
        xb, xc, xh, xw = x.size()
        yb, yc, yh, yw = y.size()
        diffY = xh - yh
        diffX = xw - yw
        y = F.pad(y, (diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2))
        return (x, y)

    def forward(self, x, dim=None):
        x = self.padding(x)
        return self.concat(x, dim if dim is not None else self.dim)

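# Usage sketch (hypothetical): Concat pads the second tensor to the first tensor's
# spatial size before concatenating along the channel dimension, which suits
# U-Net style skip connections with slightly mismatched shapes.
def _concat_example():
    cat = Concat(dim=1)
    a = torch.randn(1, 16, 33, 33)
    b = torch.randn(1, 16, 32, 32)
    print(cat([a, b]).shape)  # torch.Size([1, 32, 33, 33])
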
# if __name__ == '__main__':
#     from torchsummary import summary
#     x = torch.randn((1,32,16,16))
#     for k in range(1,3):
#         # up = upsample(32, 2**k)
#         # down = downsample(32//(2**k), 2**k)
#         # x_up = up(x)
#         # x_down = down(x_up)
#         # s_up = (32,16,16)
#         # summary(up,s,device='cpu')
#         # summary(down,s,device='cpu')
#         print(k)