| """This code is refer from: | |
| https://github.com/MelosY/CAM | |
| """ | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.init import trunc_normal_ | |
| from .convnextv2 import ConvNeXtV2, Block, LayerNorm | |
| from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33 | |


class Swish(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
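
# Note: Swish(x) = x * sigmoid(x) is mathematically identical to nn.SiLU
# (available since PyTorch 1.7); the explicit module is kept as written.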


class UNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        A UNet block that upsamples its input by `stride` (each axis may be
        upsampled 1x, 2x, or 4x).
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        conv = nn.Conv2d
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
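
# Usage sketch for UNetBlock (illustrative values; nn.BatchNorm2d stands in
# for the SyncBatchNorm that is passed in at training time):
#   block = UNetBlock(cin=64, cout=32, bn2d=nn.BatchNorm2d, stride=(2, 1))
#   block(torch.randn(1, 64, 8, 32)).shape  # -> torch.Size([1, 32, 16, 32])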


class DepthWiseUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        A UNet block that upsamples its input by `stride`, built from
        depthwise-separable convolutions.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
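
# Each 3x3 depthwise + 1x1 pointwise pair above factorizes a dense 3x3
# convolution, cutting its parameter count from 9*cin*cout to 9*cin + cin*cout.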


class SFTLayer(nn.Module):

    def __init__(self, dim_in, dim_out):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_scale_conv1 = nn.Linear(dim_in, dim_out)
        self.SFT_shift_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_shift_conv1 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(
            F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(
            F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift
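
# Spatial feature transform: the condition branch x[1] predicts an affine
# modulation of the feature branch x[0], out = fea * (scale + 1) + shift,
# in the style of SFT-GAN.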


class MoreUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        A deeper UNet block that upsamples its input by `stride`, stacking
        four depthwise-separable convolution stages.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cout,
                      cout,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cout),
            nn.Conv2d(cout,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
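
# MoreUNetBlock trades extra depth for a larger receptive field per
# upsampling step, at the same output resolution as the other blocks.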


class BinaryDecoder(nn.Module):

    def __init__(self,
                 dim,
                 num_classes,
                 strides,
                 use_depthwise_unet=False,
                 use_more_unet=False,
                 binary_loss_type='DiceLoss') -> None:
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        self.linear_enc2binary = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(dim),
        )
        self.strides = strides
        self.use_deformable = False
        self.binary_decoder = nn.ModuleList()
        unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
        unet = MoreUNetBlock if use_more_unet else unet
        for i in range(3):
            up_sample_stride = self.strides[::-1][i]
            cin, cout = channels[i], channels[i + 1]
            self.binary_decoder.append(
                unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
                     self.use_deformable))
        last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
        self.binary_decoder.append(
            unet(cout, cout, nn.SyncBatchNorm, last_stride,
                 self.use_deformable))
        if binary_loss_type in ('CrossEntropyDiceLoss',
                                'BanlanceMultiClassCrossEntropyLoss'):
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
        self.binary_pred = nn.Conv2d(channels[-1],
                                     segm_num_cls,
                                     kernel_size=1,
                                     stride=1,
                                     bias=True)
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, p_h * p_w * 3), where (p_h, p_w) is half the first stride
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        h = imgs.shape[2] // p_h
        w = imgs.shape[3] // p_w
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
        return x
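
    # Shape sketch (assuming the default strides, so p_h = p_w = 2):
    #   imgs (N, 3, 32, 128) -> patchify -> (N, 16*64, 2*2*3) = (N, 1024, 12)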
    def unpatchify(self, x):
        """
        x: (N, p_h * p_w, h, w)
        imgs: (N, h * p_h, w * p_w), a single-channel map
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        _, _, h, w = x.shape
        assert p_h * p_w == x.shape[1]
        x = x.permute(0, 2, 3, 1)  # [N, h, w, p_h*p_w]
        x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
        return imgs
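
    # unpatchify inverts patchify for a single-channel target: each p_h*p_w
    # channel vector is folded back into its spatial patch.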
    def forward(self, x, time=None):
        """
        x: the encoder feature used to initialize the query for binary
            prediction.
        time: unused; kept for interface compatibility.
        """
        binary_feats = []
        x = self.linear_enc2binary(x)
        binary_feats.append(x.clone())
        for d in self.binary_decoder:
            x = d(x)
            binary_feats.append(x.clone())
        x = self.binary_pred(x)
        if self.training:
            return x, binary_feats
        else:
            return x.softmax(1), binary_feats
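
# BinaryDecoder returns raw logits during training, so the segmentation loss
# can apply its own normalization, and per-class probabilities (softmax over
# channel dim 1) during evaluation.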


class LayerNormProxy(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)
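
# LayerNormProxy applies nn.LayerNorm over the channel dimension of an NCHW
# tensor by permuting to NHWC and back.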


class DAttentionFuse(nn.Module):

    def __init__(
        self,
        q_size=(4, 32),
        kv_size=(4, 32),
        n_heads=8,
        n_head_channels=80,
        n_groups=4,
        attn_drop=0.0,
        proj_drop=0.0,
        stride=2,
        offset_range_factor=2,
        use_pe=True,
        stage_idx=0,
    ):
        '''
        Deformable attention for fusing two feature maps;
        stage_idx (from 2 to 3) selects the offset-network kernel size.
        '''
        super().__init__()
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.offset_range_factor = offset_range_factor
        ksizes = [9, 7, 5, 3]
        kk = ksizes[stage_idx]
        self.conv_offset = nn.Sequential(
            nn.Conv2d(2 * self.n_group_channels,
                      2 * self.n_group_channels,
                      kk,
                      stride,
                      kk // 2,
                      groups=self.n_group_channels),
            LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
            nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))
        self.proj_q = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_k = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_v = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_out = nn.Conv2d(self.nc,
                                  self.nc,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        if self.use_pe:
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.kv_h * 2 - 1,
                            self.kv_w * 2 - 1))
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
                           device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
                           device=device))
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1,
                                    -1)  # B * g H W 2
        return ref
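
    # _get_ref_points builds pixel-center reference coordinates normalized to
    # [-1, 1], the convention F.grid_sample expects, so predicted offsets can
    # simply be added to them.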
    def forward(self, x, y):
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device
        # Predict a 2-D offset field per group from the concatenated inputs
        # (equivalent to einops.rearrange(torch.cat((x, y), dim=1),
        #  'b (g c) h w -> (b g) c h w', g=self.n_groups)).
        q_off = torch.cat((x, y), dim=1).reshape(
            B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(0, 1)
        offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk
        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
                                        device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(
                self.offset_range_factor)
        offset = offset.permute(0, 2, 3, 1)  # b p h w -> b h w p
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)
        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()
        # Queries come from y; keys/values are sampled from x at the
        # deformed positions.
        q = self.proj_q(y)
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
            grid=pos[..., (1, 0)],  # y, x -> x, y
            mode='bilinear',
            align_corners=False)  # B * g, Cg, Hg, Wg
        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)
        if self.use_pe:
            # Relative position bias, bilinearly sampled at the displacement
            # between each query location and each deformed key location.
            rpe_table = self.rpe_table
            rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
            q_grid = self._get_ref_points(H, W, B, dtype, device)
            displacement = (
                q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
                pos.reshape(B * self.n_groups, n_sample,
                            2).unsqueeze(1)).mul(0.5)
            attn_bias = F.grid_sample(input=rpe_bias.reshape(
                B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                                      grid=displacement[..., (1, 0)],
                                      mode='bilinear',
                                      align_corners=False)
            attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            attn = attn + attn_bias
        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)
        out = torch.einsum('b m n, b c n -> b c m', attn, v)
        out = out.reshape(B, C, H, W)
        out = self.proj_drop(self.proj_out(out))
        return out, pos.reshape(B, self.n_groups, Hk, Wk,
                                2), reference.reshape(B, self.n_groups, Hk,
                                                      Wk, 2)


class FuseModel(nn.Module):

    def __init__(self,
                 dim,
                 deform_stride=2,
                 stage_idx=2,
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=(2, 32)):
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        refine_conv = nn.Conv2d
        self.deform_stride = deform_stride
        in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]
        self.binary_condition_layer = DAttentionFuse(
            q_size=q_size,
            kv_size=q_size,
            stride=self.deform_stride,
            n_head_channels=dim // 8,
            stage_idx=stage_idx)
        self.binary2refine_linear_norm = nn.ModuleList()
        for i in range(len(k_size)):
            self.binary2refine_linear_norm.append(
                nn.Sequential(
                    Block(dim=channels[in_out_ch[i][0]]),
                    LayerNorm(channels[in_out_ch[i][0]],
                              eps=1e-6,
                              data_format='channels_first'),
                    refine_conv(channels[in_out_ch[i][0]],
                                channels[in_out_ch[i][1]],
                                kernel_size=k_size[i],
                                stride=k_size[i])),  # [8, 32]
            )

    def forward(self, recog_feat, binary_feats, dec_in=None):
        multi_feat = []
        binary_feat = binary_feats[-1]
        for i in range(len(self.binary2refine_linear_norm)):
            binary_feat = self.binary2refine_linear_norm[i](binary_feat)
            multi_feat.append(binary_feat)
        binary_feat = binary_feat + binary_feats[0]
        multi_feat[3] += binary_feats[0]
        binary_refined_feat, pos, _ = self.binary_condition_layer(
            recog_feat, binary_feat)
        binary_refined_feat = binary_refined_feat + binary_feat
        return binary_refined_feat, binary_feat
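
# FuseModel progressively downsamples the finest binary-decoder feature back
# to encoder resolution, adds the encoder-level binary feature, then fuses it
# with the recognition feature through DAttentionFuse (queries come from the
# binary feature; keys/values are sampled from the recognition feature), with
# a residual connection back to the binary feature.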


class CAMEncoder(nn.Module):
    """
    Args:
        in_channels (int): Number of input image channels. Default: 3
        encoder_config (dict): Backbone config; 'name' selects the backbone
            class. Default: {'name': 'ConvNeXtV2'}
        nb_classes (int): Number of classes used to size the binary decoder's
            segmentation head. Default: 71
        strides (list(tuple)): Per-stage (H, W) strides, used to build the
            binary decoder. Default: [(4, 4), (2, 1), (2, 1), (1, 1)]
        d_embedding (int): Dimension of the output token embedding.
            Default: 384
    """
    def __init__(self,
                 in_channels=3,
                 encoder_config={'name': 'ConvNeXtV2'},
                 nb_classes=71,
                 strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=[2, 32],
                 deform_stride=2,
                 stage_idx=2,
                 use_depthwise_unet=True,
                 use_more_unet=False,
                 binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
                 mid_size=True,
                 d_embedding=384):
        super().__init__()
        # Resolve the backbone class by name; an explicit registry over the
        # two imported backbones avoids eval() on config strings.
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
        backbones = {
            'ConvNeXtV2': ConvNeXtV2,
            'SVTRv2LNConvTwo33': SVTRv2LNConvTwo33,
        }
        self.backbone = backbones[encoder_name](**encoder_config)
        dim = self.backbone.out_channels
        self.mid_size = mid_size
        if self.mid_size:
            self.enc_downsample = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim // 2),
            )
            dim = dim // 2
            # recognition decoder
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim),
            )
        else:
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
                nn.SyncBatchNorm(dim),
            )
        self.linear_norm = nn.Sequential(
            nn.Linear(dim, d_embedding),
            nn.LayerNorm(d_embedding, eps=1e-6),
        )
        self.out_channels = d_embedding
        self.binary_decoder = BinaryDecoder(
            dim,
            nb_classes,
            strides,
            use_depthwise_unet=use_depthwise_unet,
            use_more_unet=use_more_unet,
            binary_loss_type=binary_loss_type)
        self.fuse_model = FuseModel(dim,
                                    deform_stride=deform_stride,
                                    stage_idx=stage_idx,
                                    k_size=k_size,
                                    q_size=q_size)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.LayerNorm, nn.SyncBatchNorm, nn.BatchNorm2d)):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {}
    def forward(self, x):
        output = {}
        enc_feat = self.backbone(x)
        if self.mid_size:
            enc_feat = self.enc_downsample(enc_feat)
        output['enc_feat'] = enc_feat
        # binary mask
        pred_binary, binary_feats = self.binary_decoder(enc_feat)
        output['pred_binary'] = pred_binary
        reg_feat = self.linear_enc2recog(enc_feat)
        B, C, H, W = reg_feat.shape
        last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)
        dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
        dec_in = self.linear_norm(dec_in)
        output['refined_feat'] = dec_in
        output['binary_feat'] = binary_feats[-1]
        return output
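
# Minimal usage sketch (illustrative, not part of the original module): the
# SyncBatchNorm layers require CUDA tensors and an initialized process group,
# so a single-process smoke test might look like:
#
#   import torch.distributed as dist
#   dist.init_process_group('nccl', init_method='tcp://127.0.0.1:29500',
#                           rank=0, world_size=1)
#   encoder = CAMEncoder(encoder_config={'name': 'ConvNeXtV2'}).cuda()
#   out = encoder(torch.randn(2, 3, 32, 128).cuda())
#   # out['refined_feat']: [2, H*W, 384] tokens for the recognition decoder
#   # out['pred_binary']:  per-pixel class logits from the binary decoder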