import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from openrec.modeling.common import Mlp


class NRTRDecoder(nn.Module):
    """A transformer decoder. Users can modify the attributes as needed.

    The architecture is based on the paper "Attention Is All You Need". Ashish
    Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N.
    Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you
    need. In Advances in Neural Information Processing Systems, pages
    6000-6010.

    Args:
        in_channels: dimensionality of the encoder features, used as d_model.
        out_channels: size of the output vocabulary, including the BOS/EOS and
            padding symbols.
        nhead: number of attention heads (default: in_channels // 32).
        num_encoder_layers: number of self-attention blocks applied to the
            visual features; 0 disables the encoder (default=6).
        beam_size: beam width used for decoding (default=0).
        num_decoder_layers: number of decoder blocks (default=6).
        max_len: maximum decoded sequence length (default=25).
        attention_dropout_rate: dropout on the attention weights (default=0.0).
        residual_dropout_rate: dropout on the residual and embedding paths
            (default=0.1).
        scale_embedding: whether to scale token embeddings by sqrt(d_model)
            (default=True).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(NRTRDecoder, self).__init__()
        self.out_channels = out_channels
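        # Special symbol ids: the last class is the padding/ignore index, the
        # second-to-last is BOS, and index 0 serves as EOS.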
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
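        # The decoder width follows the encoder feature channels; the
        # feed-forward layer is 4x wider and the head count defaults to one
        # head per 32 channels.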
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32

        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        if num_encoder_layers > 0:
            self.encoder = nn.ModuleList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False,
                ) for i in range(num_encoder_layers)
            ])
        else:
            self.encoder = None

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for i in range(num_decoder_layers)
        ])

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
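        # Teacher forcing: the ground-truth sequence (minus its last token) is
        # embedded and run through the decoder in a single pass; a causal mask
        # keeps each position from attending to future tokens.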
        tgt = tgt[:, :-1]
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.get_device())

        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src  # B N C
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        """Run the decoder on encoder features.

        Args:
            src: encoder output features of shape :math:`(B, sN, C)` (required).
            data: during training, ``data[0]`` holds the padded target ids of
                shape :math:`(B, tN)` and ``data[1]`` the target lengths;
                unused at inference.

        Returns:
            Logits over the vocabulary during training, per-step probabilities
            from greedy decoding during inference.
        """
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
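        # Autoregressive greedy decoding: every sequence starts with BOS, one
        # token is predicted per step, and decoding stops early once every
        # sequence in the batch has produced at least one EOS token.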
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.get_device())
        dec_seq[:, 0] = self.bos
        logits = []
        self.attn_maps = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # (bs, step + 1, dim)
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.get_device())
            tgt = dec_seq_embed  # (bs, step + 1, dim): bos, t1, t2, ...
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
            self.attn_maps.append(
                self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decode: append the most likely next token index to
                # the decoder input for the following step.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
                # Efficient batch decoding: if every output sequence contains
                # at least one EOS token, end decoding.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions
        are filled with float(0.0).
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)


class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
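        # query: (B, qN, C); key: (B, kN, C) for cross-attention (ignored when
        # self_attn=True); attn_mask is added to the (B, num_heads, qN, kN)
        # attention scores before the softmax.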
        B, qN = query.shape[:2]

        if self.self_attn:
            qkv = self.qkv(query)
            qkv = qkv.reshape(B, qN, 3, self.num_heads,
                              self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        else:
            kN = key.shape[1]
            q = self.q(query)
            q = q.reshape(B, qN, self.num_heads, self.head_dim).transpose(1, 2)
            kv = self.kv(key)
            kv = kv.reshape(B, kN, 2, self.num_heads,
                            self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)

        attn = (q.matmul(k.transpose(2, 3))) * self.scale
        if attn_mask is not None:
            attn += attn_mask
        attn = F.softmax(attn, dim=-1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2)
        x = x.reshape(B, qN, self.embed_dim)
        x = self.out_proj(x)
        return x


class TransformerBlock(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            # Cross-attention of the decoder over the encoder memory.
            self.cross_attn = MultiheadAttention(d_model,
                                                 nhead,
                                                 dropout=attention_dropout_rate)
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )

        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)
        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
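        # Post-norm residual blocks: optional masked self-attention, optional
        # cross-attention over `memory`, then a feed-forward MLP.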
        if self.with_self_attn:
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))

        if self.with_cross_attn:
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
        return tgt


class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the
    tokens in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout probability (required).
        dim: the embedding dimension (required).
        max_len: the maximum length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.unsqueeze(pe, 0)
        # pe = torch.permute(pe, [1, 0, 2])
        self.register_buffer('pe', pe)
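        # The buffer `pe` has shape (1, max_len, dim), so it broadcasts over
        # batch-first inputs of shape (B, N, dim).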

    def forward(self, x):
        """Inputs of forward function.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        # x = x.permute([1, 0, 2])
        # x = x + self.pe[:x.shape[0], :]
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)  # .permute([1, 0, 2])


class PositionalEncoding_2d(nn.Module):
    r"""Inject some information about the position of the tokens along both
    the width and the height of a 2D feature map. The positional encodings
    have the same dimension as the embeddings, so that the two can be summed.
    Here, we use sine and cosine functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the position and :math:`i` is the embedding index.

    Args:
        dropout: the dropout probability (required).
        dim: the embedding dimension (required).
        max_len: the maximum length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.0)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.0)

    def forward(self, x):
        """Inputs of forward function.

        Args:
            x: the 2D feature map fed to the positional encoder model (required).

        Shape:
            x: [batch size, embed dim, H, W]
            output: [H * W, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        w_pe = self.pe[:x.shape[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = torch.permute(w_pe, [1, 2, 0])
        w_pe = torch.unsqueeze(w_pe, 2)

        h_pe = self.pe[:x.shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = torch.permute(h_pe, [1, 2, 0])
        h_pe = torch.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        x = torch.permute(
            torch.reshape(x,
                          [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1],
        )
        return self.dropout(x)


class Embeddings(nn.Module):

    def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
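        # Optionally scale the token embeddings by sqrt(d_model), as in
        # "Attention Is All You Need".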
        if self.scale_embedding:
            x = self.embedding(x)
            return x * math.sqrt(self.d_model)
        return self.embedding(x)
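

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module):
    # in_channels=512 and out_channels=40 are assumed values, and the random
    # `feats` tensor stands in for real encoder features. forward_test() calls
    # src.get_device(), so this sketch assumes a CUDA device is available.
    if torch.cuda.is_available():
        decoder = NRTRDecoder(in_channels=512, out_channels=40, max_len=25).cuda()
        decoder.eval()
        feats = torch.randn(2, 64, 512).cuda()  # (B, sN, C) encoder features
        with torch.no_grad():
            probs = decoder(feats)  # (B, <=max_len, out_channels - 2) step-wise probabilities
        print(probs.shape)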