import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RobustScannerDecoder(nn.Module):

    def __init__(self,
                 out_channels,  # 90 + unknown + start + padding
                 in_channels,
                 enc_out_channels=128,
                 hybrid_dec_rnn_layers=2,
                 hybrid_dec_dropout=0,
                 position_dec_rnn_layers=2,
                 max_len=25,
                 mask=True,
                 encode_value=False,
                 **kwargs):
        super(RobustScannerDecoder, self).__init__()
        start_idx = out_channels - 2
        padding_idx = out_channels - 1
        end_idx = 0

        # encoder module
        self.encoder = ChannelReductionEncoder(in_channels=in_channels,
                                               out_channels=enc_out_channels)

        self.max_text_length = max_len + 1
        self.mask = mask

        # decoder module
        self.decoder = Decoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_out_channels,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_len=max_len + 1,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            end_idx=end_idx,
            encode_value=encode_value)

    def forward(self, inputs, data=None):
        '''
        data: [label, length, valid_ratio]
        '''
        out_enc = self.encoder(inputs)
        bs = out_enc.shape[0]
        valid_ratios = None
        word_positions = torch.arange(0,
                                      self.max_text_length,
                                      device=inputs.device).unsqueeze(0).tile(
                                          [bs, 1])

        if self.mask:
            valid_ratios = data[-1]

        if self.training:
            max_len = int(data[1].max())
            label = data[0][:, :1 + max_len]
            final_out = self.decoder(inputs, out_enc, label, valid_ratios,
                                     word_positions[:, :1 + max_len])
        else:
            final_out = self.decoder(inputs,
                                     out_enc,
                                     label=None,
                                     valid_ratios=valid_ratios,
                                     word_positions=word_positions,
                                     train_mode=False)
        return final_out
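

# Minimal inference-mode sketch (not part of the original model code). It
# assumes a 512-channel backbone feature map of size 1x32 and a 90-character
# charset (out_channels = 90 + unknown + start + padding = 93). In eval mode
# no labels are needed, so only the valid ratios slot of ``data`` is read.
def _demo_robust_scanner_head():
    head = RobustScannerDecoder(out_channels=93, in_channels=512)
    head.eval()
    feat = torch.randn(2, 512, 1, 32)  # (N, C_in, H, W) backbone output
    valid_ratios = torch.ones(2)       # full image width is valid
    with torch.no_grad():
        out = head(feat, data=[None, None, valid_ratios])
    # (2, T, 93) with T <= max_text_length; decoding may stop early once
    # every sequence has emitted the end token.
    print(out.shape)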


class BaseDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      word_positions):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                valid_ratios=None,
                word_positions=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, valid_ratios,
                                      word_positions)
        return self.forward_test(feat, out_enc, valid_ratios, word_positions)


class ChannelReductionEncoder(nn.Module):
    """Change the number of channels with a 1x1 convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ChannelReductionEncoder, self).__init__()

        self.layer = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        # Xavier-normal initialization of the 1x1 projection weights.
        nn.init.xavier_normal_(self.layer.weight, gain=1.0)

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        return self.layer(feat)
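

# Quick shape sanity check (illustrative only): the 1x1 convolution changes
# only the channel dimension, e.g. 512 -> 128, leaving H and W untouched.
def _demo_channel_reduction():
    enc = ChannelReductionEncoder(in_channels=512, out_channels=128)
    print(enc(torch.randn(2, 512, 8, 32)).shape)  # torch.Size([2, 128, 8, 32])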


def masked_fill(x, mask, value):
    # Unused helper, equivalent to ``x.masked_fill(mask, value)``.
    y = torch.full(x.shape, value, dtype=x.dtype, device=x.device)
    return torch.where(mask, y, x)


class DotProductAttentionLayer(nn.Module):

    def __init__(self, dim_model=None):
        super().__init__()
        self.scale = dim_model**-0.5 if dim_model is not None else 1.

    def forward(self, query, key, value, mask=None):
        # query: (n, c, len_q) -> (n, len_q, c); key: (n, c, len_k)
        query = query.permute(0, 2, 1)
        logits = query @ key * self.scale  # (n, len_q, len_k)

        if mask is not None:
            n, seq_len = mask.size()
            mask = mask.view(n, 1, seq_len)
            logits = logits.masked_fill(mask, float('-inf'))

        weights = F.softmax(logits, dim=2)
        value = value.transpose(1, 2)  # (n, len_k, c_v)
        glimpse = weights @ value      # (n, len_q, c_v)
        glimpse = glimpse.permute(0, 2, 1).contiguous()  # (n, c_v, len_q)
        return glimpse
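

# Illustrative sketch of the attention layer's contract (not original code):
# masked key positions get -inf logits, so their softmax weight is zero and
# they contribute nothing to the returned glimpse. All shapes are assumptions.
def _demo_dot_product_attention():
    attn = DotProductAttentionLayer(dim_model=128)
    n, c, len_q, len_k = 2, 128, 5, 48
    query = torch.randn(n, c, len_q)
    key = torch.randn(n, c, len_k)
    value = torch.randn(n, c, len_k)
    mask = torch.zeros(n, len_k, dtype=torch.bool)
    mask[:, 40:] = True                  # hide padded positions beyond width 40
    glimpse = attn(query, key, value, mask)
    print(glimpse.shape)                 # torch.Size([2, 128, 5])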


class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be
            the same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class, which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by the loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        self.embedding = nn.Embedding(self.num_classes,
                                      self.dim_model,
                                      padding_idx=padding_idx)

        self.sequence_layer = nn.LSTM(input_size=dim_model,
                                      hidden_size=dim_model,
                                      num_layers=rnn_layers,
                                      batch_first=True,
                                      dropout=dropout)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of each image.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        tgt_embedding = self.embedding(targets)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)
        query = query.permute(0, 2, 1).contiguous()
        key = out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)
        return out

    def forward_test(self, feat, out_enc, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of each image.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        batch_size = feat.shape[0]

        decode_sequence = (torch.ones((batch_size, self.max_seq_len),
                                      dtype=torch.int64,
                                      device=feat.device) * self.start_idx)

        outputs = []
        for i in range(self.max_seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            max_idx = torch.argmax(step_out, dim=1, keepdim=False)
            if i < self.max_seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = torch.stack(outputs, 1)

        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores the history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): Valid length ratio of each image.

        Returns:
            Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
            tokens at the current time step.
        """
        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model

        query, _ = self.sequence_layer(embed)
        query = query.transpose(1, 2)
        key = torch.reshape(out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = key
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, mask)
        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        out = F.softmax(out, dim=-1)

        return out
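

# Illustrative stand-alone use of the hybrid branch (shapes and index layout
# are assumptions matching the wrapper above): in train mode the target
# characters are embedded and run through the LSTM to form queries, so the
# output length matches the target length.
def _demo_sequence_attention_decoder():
    dec = SequenceAttentionDecoder(num_classes=93, dim_input=512,
                                   dim_model=128, max_seq_len=26,
                                   start_idx=91, padding_idx=92)
    feat = torch.randn(2, 512, 1, 32)
    out_enc = torch.randn(2, 128, 1, 32)
    targets = torch.randint(0, 90, (2, 10), dtype=torch.int64)
    out = dec.forward_train(feat, out_enc, targets, valid_ratios=None)
    print(out.shape)  # torch.Size([2, 10, 92]) -- C-1 classes, <PAD> excluded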


class PositionAwareLayer(nn.Module):

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()
        self.dim_model = dim_model
        self.rnn = nn.LSTM(input_size=dim_model,
                           hidden_size=dim_model,
                           num_layers=rnn_layers,
                           batch_first=True)
        self.mixer = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1), nn.ReLU(True),
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1))

    def forward(self, img_feature):
        n, c, h, w = img_feature.shape
        # Treat each feature-map row as a width-long sequence for the LSTM.
        rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
        rnn_input = rnn_input.view(n * h, w, c)
        rnn_output, _ = self.rnn(rnn_input)
        rnn_output = rnn_output.view(n, h, w, c)
        rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
        out = self.mixer(rnn_output)
        return out
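

# Shape walk-through (illustrative only): a (2, 128, 8, 32) map becomes
# 2*8 = 16 row sequences of length 32 for the LSTM, then is mixed back by the
# 3x3 convolutions, so the output shape matches the input.
def _demo_position_aware_layer():
    layer = PositionAwareLayer(dim_model=128)
    out = layer(torch.randn(2, 128, 8, 32))
    print(out.shape)  # torch.Size([2, 128, 8, 32])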


class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be
            the same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class, which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by the loss.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask

        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)

        self.position_aware_module = PositionAwareLayer(
            self.dim_model, rnn_layers)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        # Helper (not called in this file) that builds a batch of position
        # indices [0, 1, ..., length - 1].
        position_index = torch.arange(length, dtype=torch.int64)
        batch_position_index = position_index.unsqueeze(0).repeat(
            batch_size, 1)
        return batch_position_index

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of each image.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = out_enc.view(n, c_enc, h * w)
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of each image.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = torch.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.transpose(1, 2)  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)
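

# Illustrative stand-alone use of the position branch (assumed shapes): the
# queries come from position embeddings alone, so decoding does not depend on
# previously predicted characters -- the key robustness idea of the paper.
def _demo_position_attention_decoder():
    dec = PositionAttentionDecoder(num_classes=93, dim_input=512,
                                   dim_model=128, max_seq_len=26)
    feat = torch.randn(2, 512, 1, 32)
    out_enc = torch.randn(2, 128, 1, 32)
    position_index = torch.arange(26).unsqueeze(0).tile([2, 1])
    out = dec.forward_test(feat, out_enc, None, position_index)
    print(out.shape)  # torch.Size([2, 26, 92])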


class RobustScannerFusionLayer(nn.Module):

    def __init__(self, dim_model, dim=-1):
        super(RobustScannerFusionLayer, self).__init__()
        self.dim_model = dim_model
        self.dim = dim
        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)

    def forward(self, x0, x1):
        assert x0.shape == x1.shape
        fusion_input = torch.cat((x0, x1), self.dim)
        output = self.linear_layer(fusion_input)
        # GLU halves the fused dimension back to ``dim_model``.
        output = F.glu(output, self.dim)
        return output
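

# Illustrative check of the fusion contract (not original code): two
# ``dim_model``-dim glimpses are concatenated to 2*dim_model, projected, and
# gated by GLU back down to dim_model.
def _demo_fusion_layer():
    fusion = RobustScannerFusionLayer(dim_model=128)
    x0 = torch.randn(2, 5, 128)  # hybrid-branch glimpse
    x1 = torch.randn(2, 5, 128)  # position-branch glimpse
    print(fusion(x0, x1).shape)  # torch.Size([2, 5, 128])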


class Decoder(BaseDecoder):
    """Decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be
            the same as encoder output vector ``out_enc``.
        max_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        end_idx (int): The index of `<EOS>`.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        Unlike the two branch decoders above, whose output size is
        :math:`C - 1`, the fused prediction layer here outputs logits over all
        :math:`C` classes; `<PAD>` is still expected to be ignored by the
        loss.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 end_idx=0,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.end_idx = end_idx
        self.mask = mask

        # init hybrid decoder
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)

        # init position decoder
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)

        self.fusion_module = RobustScannerFusionLayer(
            self.dim_model if encode_value else dim_input)

        pred_num_classes = num_classes
        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
                                    pred_num_classes)

    def forward_train(self, feat, out_enc, target, valid_ratios,
                      word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of each image.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
        """
        hybrid_glimpse = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        position_glimpse = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)

        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)

        out = self.prediction(fusion_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of each image.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: The output sequence tensor of shape :math:`(N, T, C)`,
            with softmax applied along the class dimension. Decoding stops
            early once every sequence has emitted `<EOS>`, so :math:`T` may
            be smaller than ``max_seq_len``.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        decode_sequence = (torch.ones(
            (batch_size, seq_len), dtype=torch.int64, device=feat.device) *
                           self.start_idx)

        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)

        outputs = []
        for i in range(seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, i, valid_ratios)

            fusion_out = self.fusion_module(hybrid_glimpse_step,
                                            position_glimpse[:, i, :])

            char_out = self.prediction(fusion_out)
            char_out = F.softmax(char_out, -1)
            outputs.append(char_out)
            max_idx = torch.argmax(char_out, dim=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx
                # Stop once every sequence in the batch contains <EOS>.
                if (decode_sequence == self.end_idx).any(dim=-1).all():
                    break

        outputs = torch.stack(outputs, 1)

        return outputs
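

# End-to-end training-mode sketch (illustrative; the charset size, shapes, and
# dummy labels are assumptions, not values from the original code). With
# teacher forcing, the padded ground-truth sequence drives the hybrid branch,
# so the fused decoder returns one logit vector per target step.
def _demo_decoder_train_step():
    head = RobustScannerDecoder(out_channels=93, in_channels=512, max_len=25)
    feat = torch.randn(2, 512, 1, 32)
    label = torch.randint(0, 90, (2, 10), dtype=torch.int64)  # padded targets
    length = torch.full((2,), 9, dtype=torch.int64)           # target lengths
    valid_ratios = torch.ones(2)
    out = head(feat, data=[label, length, valid_ratios])
    print(out.shape)  # torch.Size([2, 10, 93])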