import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .abinet_decoder import PositionAttention
from .nrtr_decoder import PositionalEncoding, TransformerBlock


class Trans(nn.Module):
    """Stack of self-attention Transformer blocks applied to flattened 2-D
    visual features, with sinusoidal positional encoding and an optional
    location mask derived from character attention maps."""

    def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
        super().__init__()
        self.d_model = dim
        self.nhead = nhead
        self.pos_encoder = PositionalEncoding(dropout=0.0,
                                              dim=self.d_model,
                                              max_len=512)
        self.transformer = nn.ModuleList([
            TransformerBlock(
                dim,
                nhead,
                dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])

    def forward(self, feature, attn_map=None, use_mask=False):
        n, c, h, w = feature.shape
        feature = feature.flatten(2).transpose(1, 2)  # (N, H*W, C)
        if use_mask:
            _, t, h, w = attn_map.shape
            # Binarize the character attention maps and mark every pair of
            # spatial positions that attend to a common character.
            location_mask = (attn_map.view(n, t, -1).transpose(1, 2) >
                             0.05).type(torch.float)  # (N, H*W, T)
            location_mask = location_mask.bmm(location_mask.transpose(
                1, 2))  # (N, H*W, H*W)
            # Additive attention mask: positions sharing a character region
            # are blocked (-inf) in self-attention.
            location_mask = location_mask.new_zeros(
                (h * w, h * w)).masked_fill(location_mask > 0, float('-inf'))
            location_mask = location_mask.unsqueeze(1)  # (N, 1, H*W, H*W)
        else:
            location_mask = None
        feature = self.pos_encoder(feature)
        for layer in self.transformer:
            feature = layer(feature, self_mask=location_mask)
        feature = feature.transpose(1, 2).reshape(n, c, h, w)
        return feature, location_mask
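

# Shape sketch (hypothetical values, not from the original file): with
# dim=512 and a backbone feature map of size (2, 512, 8, 32), Trans flattens
# the map to a (2, 256, 512) token sequence, runs the self-attention blocks
# (optionally restricted by the location mask built from `attn_map`), and
# reshapes the result back to (2, 512, 8, 32):
#
#     trans = Trans(dim=512, nhead=8, dim_feedforward=1024,
#                   dropout=0.1, num_layers=2)
#     feat = torch.randn(2, 512, 8, 32)
#     out, mask = trans(feat)  # out: (2, 512, 8, 32); mask is None here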


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LPVDecoder(nn.Module):
    """LPV decoder: alternates position attention over the visual features
    with Transformer-based feature refinement, emitting one set of character
    logits per refinement stage."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=3,
                 max_len=25,
                 use_mask=False,
                 dim_feedforward=1024,
                 nhead=8,
                 dropout=0.1,
                 trans_layer=2):
        super().__init__()
        self.use_mask = use_mask
        self.max_len = max_len
        attn_layer = PositionAttention(max_length=max_len + 1,
                                       mode='nearest',
                                       in_channels=in_channels,
                                       num_channels=in_channels // 8)
        trans_layer = Trans(dim=in_channels,
                            nhead=nhead,
                            dim_feedforward=dim_feedforward,
                            dropout=dropout,
                            num_layers=trans_layer)
        cls_layer = nn.Linear(in_channels, out_channels - 2)
        self.attention = _get_clones(attn_layer, num_layer)
        self.trans = _get_clones(trans_layer, num_layer - 1)
        self.cls = _get_clones(cls_layer, num_layer)

    def forward(self, x, data=None):
        if data is not None:
            max_len = data[1].max()  # longest label length in the batch
        else:
            max_len = self.max_len
        features = x  # (N, E, H, W)
        # Stage 0: position attention directly on the backbone features.
        attn_vecs, attn_scores_map = self.attention[0](features)
        attn_vecs = attn_vecs[:, :max_len + 1, :]
        if not self.training:
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
            # Inference: only the last stage's prediction is returned.
            return F.softmax(self.cls[-1](attn_vecs), -1)
        else:
            # Training: collect the logits of every refinement stage.
            logits = []
            logit = self.cls[0](attn_vecs)  # (N, T, C)
            logits.append(logit)
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
                logit = self.cls[i](attn_vecs)  # (N, T, C)
                logits.append(logit)
            return logits
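

# Minimal usage sketch (hypothetical channel/class counts and input size; the
# real values come from the surrounding recognizer config). Because of the
# relative imports above, this only runs when the file is executed as part of
# its package.
if __name__ == '__main__':
    decoder = LPVDecoder(in_channels=512, out_channels=39)
    decoder.eval()
    feats = torch.randn(2, 512, 8, 32)  # (N, E, H, W) backbone feature map
    with torch.no_grad():
        probs = decoder(feats)  # (N, max_len + 1, out_channels - 2)
    print(probs.shape)  # expected: torch.Size([2, 26, 37])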