| """ | |
| This file only for testing mask regularzation. | |
| If it works, it will be merged with `layers.py`. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |

class AADLayer(nn.Module):
    def __init__(self, c_x, attr_c, c_id=256):
        super(AADLayer, self).__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # 1x1 convs produce per-pixel (gamma, beta) from the attribute map.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        # FC layers produce per-channel (gamma, beta) from the identity code.
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)
        # 1x1 conv predicts the per-pixel blending mask M.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)
    def forward(self, h_in, z_attr, z_id):
        # h_in:   (B, c_x, H, W) activations from the previous layer
        # z_attr: (B, attr_c, H, W) attribute feature map
        # z_id:   (B, c_id) identity embedding
        h = self.norm(h_in)
        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)
        gamma_id = self.fc1(z_id)
        beta_id = self.fc2(z_id)

        # Attribute-conditioned activation.
        A = gamma_attr * h + beta_attr
        # Identity-conditioned activation (per-channel modulation).
        gamma_id = gamma_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        beta_id = beta_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        I = gamma_id * h + beta_id

        # Per-pixel mask M weights the identity branch; (1 - M) weights attributes.
        M = torch.sigmoid(self.conv_h(h))
        out = (1 - M) * A + M * I
        # Also return the per-sample mean attribute weight for mask regularization.
        return out, torch.mean(1 - M, dim=[1, 2, 3])
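
# The docstring above says this file tests mask regularization; AADLayer's
# second return value is the per-sample mean of (1 - M). A minimal sketch of
# how such a term might enter the training loss -- the function name, the
# averaging scheme, and `weight` are assumptions, not part of the original
# training code:
def mask_reg_loss(mask_means, weight=0.01):
    # mask_means: list of (B,) tensors collected from the AAD layers.
    # Returns a scalar penalty proportional to the average attribute usage.
    return weight * torch.stack(mask_means, dim=0).mean()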

class AAD_ResBlk(nn.Module):
    def __init__(self, cin, cout, c_attr, c_id=256):
        super(AAD_ResBlk, self).__init__()
        self.cin = cin
        self.cout = cout
        self.AAD1 = AADLayer(cin, c_attr, c_id)
        self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.AAD2 = AADLayer(cin, c_attr, c_id)
        self.conv2 = nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        # Extra AAD branch on the shortcut when the channel count changes.
        if cin != cout:
            self.AAD3 = AADLayer(cin, c_attr, c_id)
            self.conv3 = nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False)
            self.relu3 = nn.ReLU(inplace=True)
    def forward(self, h, z_attr, z_id):
        x, m1_ = self.AAD1(h, z_attr, z_id)
        x = self.relu1(x)
        x = self.conv1(x)
        x, m2_ = self.AAD2(x, z_attr, z_id)
        x = self.relu2(x)
        x = self.conv2(x)
        # Accumulate the per-sample mask means from each AAD layer.
        m = m1_ + m2_
        if self.cin != self.cout:
            h, m3_ = self.AAD3(h, z_attr, z_id)
            h = self.relu3(h)
            h = self.conv3(h)
            m += m3_
        x = x + h
        return x, m

def weight_init(m):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        m.bias.data.zero_()
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(m.weight.data)

def conv4x4(in_c, out_c, norm=nn.BatchNorm2d):
    # 4x4 stride-2 conv: halves the spatial resolution.
    return nn.Sequential(
        nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False),
        norm(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )

class deconv4x4(nn.Module):
    # 4x4 stride-2 transposed conv: doubles the spatial resolution, then
    # concatenates the skip connection along channels (U-Net style).
    def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
        super(deconv4x4, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn = norm(out_c)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x, skip):
        x = self.deconv(x)
        x = self.bn(x)
        x = self.lrelu(x)
        return torch.cat((x, skip), dim=1)
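
# Because deconv4x4 concatenates the skip tensor, its output carries
# out_c + skip-channel channels; this is why AADGenerator's second block below
# takes c_attr=2048 (1024 upsampled + 1024 skip). A small illustrative sanity
# check (the helper name and sizes are ours, not from the original file):
def _check_deconv_cat():
    up = deconv4x4(1024, 1024)
    out = up(torch.randn(1, 1024, 2, 2), torch.randn(1, 1024, 4, 4))
    assert out.shape == (1, 2048, 4, 4)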

class MLAttrEncoder(nn.Module):
    def __init__(self, finetune=False, downup=False):
        super(MLAttrEncoder, self).__init__()
        self.downup = downup
        if self.downup:
            # Extra down/up stages for 512x512 inputs.
            self.conv00 = conv4x4(3, 16)
            self.conv01 = conv4x4(16, 32)
            self.deconv7 = deconv4x4(64, 16)
        self.conv1 = conv4x4(3, 32)
        self.conv2 = conv4x4(32, 64)
        self.conv3 = conv4x4(64, 128)
        self.conv4 = conv4x4(128, 256)
        self.conv5 = conv4x4(256, 512)
        self.conv6 = conv4x4(512, 1024)
        self.conv7 = conv4x4(1024, 1024)
        self.deconv1 = deconv4x4(1024, 1024)
        self.deconv2 = deconv4x4(2048, 512)
        self.deconv3 = deconv4x4(1024, 256)
        self.deconv4 = deconv4x4(512, 128)
        self.deconv5 = deconv4x4(256, 64)
        self.deconv6 = deconv4x4(128, 32)
        self.apply(weight_init)

        self.finetune = finetune
        if finetune:
            # Freeze everything except the extra down/up stages.
            for param in self.parameters():
                param.requires_grad = False
            if self.downup:
                self.conv00.requires_grad_(True)
                self.conv01.requires_grad_(True)
                self.deconv7.requires_grad_(True)
    def forward(self, Xt):
        if self.downup:
            feat0 = self.conv00(Xt)     # (16, 256, 256)
            feat1 = self.conv01(feat0)  # (32, 128, 128)
        else:
            feat0 = None
            feat1 = self.conv1(Xt)      # 32x128x128
        feat2 = self.conv2(feat1)       # 64x64x64
        feat3 = self.conv3(feat2)       # 128x32x32
        feat4 = self.conv4(feat3)       # 256x16x16
        feat5 = self.conv5(feat4)       # 512x8x8
        feat6 = self.conv6(feat5)       # 1024x4x4
        if self.downup:
            z_attr1 = self.conv7(feat6)  # 1024x2x2
            z_attr2 = self.deconv1(z_attr1, feat6)
            z_attr3 = self.deconv2(z_attr2, feat5)
            z_attr4 = self.deconv3(z_attr3, feat4)
            z_attr5 = self.deconv4(z_attr4, feat3)
            z_attr6 = self.deconv5(z_attr5, feat2)
            z_attr7 = self.deconv6(z_attr6, feat1)  # (128,64,64)+(32,128,128) -> (64,128,128)
            z_attr8 = self.deconv7(z_attr7, feat0)  # (64,128,128)+(16,256,256) -> (32,256,256)
            z_attr9 = F.interpolate(
                z_attr8, scale_factor=2, mode="bilinear", align_corners=True
            )  # (32,512,512)
            return (
                z_attr1,
                z_attr2,
                z_attr3,
                z_attr4,
                z_attr5,
                z_attr6,
                z_attr7,
                z_attr8,
                z_attr9,
            )
        else:
            z_attr1 = self.conv7(feat6)  # 1024x2x2
            z_attr2 = self.deconv1(z_attr1, feat6)
            z_attr3 = self.deconv2(z_attr2, feat5)
            z_attr4 = self.deconv3(z_attr3, feat4)
            z_attr5 = self.deconv4(z_attr4, feat3)
            z_attr6 = self.deconv5(z_attr5, feat2)
            z_attr7 = self.deconv6(z_attr6, feat1)
            z_attr8 = F.interpolate(
                z_attr7, scale_factor=2, mode="bilinear", align_corners=True
            )
            return (
                z_attr1,
                z_attr2,
                z_attr3,
                z_attr4,
                z_attr5,
                z_attr6,
                z_attr7,
                z_attr8,
            )
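
# Illustrative sanity check of the attribute pyramid for a 256x256 input with
# downup=False (the helper name is ours; shapes follow from the per-layer
# comments above):
def _check_encoder_shapes():
    enc = MLAttrEncoder().eval()
    with torch.no_grad():
        feats = enc(torch.randn(1, 3, 256, 256))
    assert feats[0].shape == (1, 1024, 2, 2)     # coarsest level
    assert feats[-1].shape == (1, 64, 256, 256)  # finest (upsampled) level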

class AADGenerator(nn.Module):
    def __init__(self, c_id=256, finetune=False, downup=False):
        super(AADGenerator, self).__init__()
        self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id)
        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id)
        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id)
        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id)
        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id)
        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id)
        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id)
        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id)
        self.downup = downup
        if downup:
            # Two extra blocks for the additional 512x512 level.
            self.AADBlk8_0 = AAD_ResBlk(64, 32, 32, c_id)
            self.AADBlk8_1 = AAD_ResBlk(32, 3, 32, c_id)
        self.apply(weight_init)

        if finetune:
            # Freeze everything except the extra downup blocks; guard on
            # downup since AADBlk8_0/AADBlk8_1 only exist in that case.
            for param in self.parameters():
                param.requires_grad = False
            if downup:
                self.AADBlk8_0.requires_grad_(True)
                self.AADBlk8_1.requires_grad_(True)
    def forward(self, z_attr, z_id):
        # Seed the decoder from the identity code: (B, c_id) -> (B, 1024, 2, 2),
        # then match the coarsest attribute resolution.
        m = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1))
        scale = z_attr[0].shape[2] // 2  # adapts to 512x512 and 1024x1024 inputs
        m = F.interpolate(m, scale_factor=scale, mode="bilinear", align_corners=True)
        m2, m2_ = self.AADBlk1(m, z_attr[0], z_id)
        m2 = F.interpolate(m2, scale_factor=2, mode="bilinear", align_corners=True)
        m3, m3_ = self.AADBlk2(m2, z_attr[1], z_id)
        m3 = F.interpolate(m3, scale_factor=2, mode="bilinear", align_corners=True)
        m4, m4_ = self.AADBlk3(m3, z_attr[2], z_id)
        m4 = F.interpolate(m4, scale_factor=2, mode="bilinear", align_corners=True)
        m5, m5_ = self.AADBlk4(m4, z_attr[3], z_id)
        m5 = F.interpolate(m5, scale_factor=2, mode="bilinear", align_corners=True)
        m6, m6_ = self.AADBlk5(m5, z_attr[4], z_id)
        m6 = F.interpolate(m6, scale_factor=2, mode="bilinear", align_corners=True)
        m7, m7_ = self.AADBlk6(m6, z_attr[5], z_id)
        m7 = F.interpolate(m7, scale_factor=2, mode="bilinear", align_corners=True)
        m8, m8_ = self.AADBlk7(m7, z_attr[6], z_id)
        m8 = F.interpolate(m8, scale_factor=2, mode="bilinear", align_corners=True)
        if self.downup:
            y0, m9_ = self.AADBlk8_0(m8, z_attr[7], z_id)
            y0 = F.interpolate(y0, scale_factor=2, mode="bilinear", align_corners=True)
            y1, m10_ = self.AADBlk8_1(y0, z_attr[8], z_id)
            y = torch.tanh(y1)
        else:
            y, m9_ = self.AADBlk8(m8, z_attr[7], z_id)
            y = torch.tanh(y)
        # The per-block mask means (m2_ ... m10_) are currently not returned.
        return y

class AEI_Net(nn.Module):
    def __init__(self, c_id=512, finetune=False, downup=False):
        super(AEI_Net, self).__init__()
        self.encoder = MLAttrEncoder(finetune=finetune, downup=downup)
        self.generator = AADGenerator(c_id, finetune=finetune, downup=downup)

    def forward(self, Xt, z_id):
        attr = self.encoder(Xt)
        Y = self.generator(attr, z_id)
        return Y, attr

    def get_attr(self, X):
        return self.encoder(X)

    def trainable_params(self):
        return [param for param in self.parameters() if param.requires_grad]
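
# A minimal fine-tuning sketch. The optimizer choice and learning rate are
# assumptions, not prescribed by this file; with finetune=True only the extra
# downup layers remain trainable, so only they reach the optimizer:
def build_finetune_optimizer(lr=1e-4):
    net = AEI_Net(c_id=512, finetune=True, downup=True)
    return net, torch.optim.Adam(net.trainable_params(), lr=lr)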

if __name__ == "__main__":
    aie = AEI_Net(512).eval()
    x = aie(torch.randn(1, 3, 512, 512), torch.randn(1, 512))

    # def numel(m: torch.nn.Module, only_trainable: bool = False):
    #     """
    #     Returns the total number of parameters used by `m` (only counting
    #     shared parameters once); if `only_trainable` is True, then only
    #     includes parameters with `requires_grad = True`.
    #     """
    #     parameters = list(m.parameters())
    #     if only_trainable:
    #         parameters = [p for p in parameters if p.requires_grad]
    #     unique = {p.data_ptr(): p for p in parameters}.values()
    #     return sum(p.numel() for p in unique)
    #
    # print(numel(aie, True))
    # print(x[0].size())
    # print(len(x[-1]))

    # Profile parameter count and FLOPs with thop.
    import thop

    img = torch.randn(1, 3, 256, 256)
    latent = torch.randn(1, 512)
    flops, params = thop.profile(aie, inputs=(img, latent), verbose=False)
    print("#Params=%.2fM, GFLOPS=%.2f" % (params / 1e6, flops / 1e9))