| """This code is refer from: | |
| https://github.com/jjwei66/BUSNet | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F

from .nrtr_decoder import PositionalEncoding, TransformerBlock
from .abinet_decoder import _get_mask, _get_length
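
# Note: _get_mask builds the cross-attention mask that hides positions beyond
# each sequence's length, and _get_length infers the decoded length from the
# position of the end token in the logits (both defined in abinet_decoder).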


class BUSDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels

        # ----------------------------------------------------------------------
        # decoder specifics
        self.proj = nn.Linear(out_channels, d_model, False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for i in range(num_layers)
        ])

        # learnable mask tokens that stand in for the masked-out modality
        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        # learnable modality embeddings added to vision / language features
        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)

        self.cls = nn.Linear(d_model, out_channels)
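
    # The same positional queries are decoded three times in forward():
    # (1) against vision features with the language side masked,
    # (2) against language token embeddings with the vision side masked,
    # (3) against the fused vision + language sequence.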
    def forward_decoder(self, q, x, mask=None):
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits

    def forward(self, img_feat, data=None):
        """
        Args:
            img_feat: (B, L, C) visual features, where B is the batch size,
                L is the feature length and C equals in_channels.
            data: optional ground-truth tensors; data[0] holds the target
                token ids and data[-1] the target lengths, used only when
                training in pretraining mode.
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # ----------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        query = self.pos_encoder(zeros)

        # 1. vision decode
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
                            dim=1)  # vision features + masked language slots
        padding_mask = _get_mask(
            self.max_length + zeros_len,
            self.max_length)  # mask padding beyond the token length; (B, 1, maxlen, maxlen)
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1,
                                                           1])  # (B, 1, maxlen, L)
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode
        if self.training and self.pretraining:
            # use ground-truth tokens as the language input
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            # use the vision prediction as the language input
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
            tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (T, N, E)
        token_embed = token_embed + self.l_embeding
        padding_mask = _get_mask(lengths,
                                 self.max_length)  # mask padding beyond the token length
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision language decode
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
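

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# Shapes below are assumed placeholders: a 2-image batch of 256 visual
# tokens with d_model=512 and a 97-class charset. Because of the relative
# imports above, run it from within the parent package, e.g.
# `python -m <package>.bus_decoder` with the actual module path.
if __name__ == '__main__':
    decoder = BUSDecoder(in_channels=512, out_channels=97)
    decoder.eval()  # inference path: returns softmax over the fused logits
    img_feat = torch.randn(2, 256, 512)  # (B, L, C) dummy visual features
    with torch.no_grad():
        probs = decoder(img_feat)
    print(probs.shape)  # (2, 26, 97): max_length + 1 steps, 97 classes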