# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class Seq2Seq(nn.Module):
    """
    Build a Sequence-to-Sequence model.

    Parameters:
        * `encoder`   - encoder of the seq2seq model, e.g. RoBERTa.
        * `decoder`   - decoder of the seq2seq model, e.g. a transformer decoder.
        * `config`    - configuration of the encoder model.
        * `beam_size` - beam size for beam search.
        * `max_length`- max length of the target for beam search.
        * `sos_id`    - start-of-sequence symbol id in the target for beam search.
        * `eos_id`    - end-of-sequence symbol id in the target for beam search.

    A usage sketch follows this class definition.
    """

    def __init__(
        self,
        encoder,
        decoder,
        config,
        beam_size=None,
        max_length=None,
        sos_id=None,
        eos_id=None,
    ):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular ones matrix, sliced later to build causal attention masks.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()
        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """Tie or clone module weights depending on whether we are using TorchScript."""
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing, so we clone them instead.
        """
        self._tie_or_clone_weights(
            self.lm_head, self.encoder.embeddings.word_embeddings
        )

    def forward(
        self,
        source_ids=None,
        source_mask=None,
        target_ids=None,
        target_mask=None,
        args=None,
    ):
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        # [batch, seq, hidden] -> [seq, batch, hidden], the layout expected by a
        # (non-batch-first) TransformerDecoder-style decoder.
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        if target_ids is not None:
            # Causal mask: each target position may only attend to itself and
            # earlier positions.
            attn_mask = -1e4 * (
                1 - self.bias[: target_ids.shape[1], : target_ids.shape[1]]
            )
            tgt_embeddings = (
                self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
            )
            out = self.decoder(
                tgt_embeddings,
                encoder_output,
                tgt_mask=attn_mask,
                memory_key_padding_mask=(1 - source_mask).bool(),
            )
            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = target_ids[..., 1:].contiguous()
            # Flatten the tokens and compute the loss only on non-padding positions.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                shift_labels.view(-1)[active_loss],
            )
            outputs = loss, loss * active_loss.sum(), active_loss.sum()
            return outputs
        else:
            # Predict with beam search, one source example at a time.
            preds = []
            # Padding tensor, created on the same device the Beam tensors will use.
            if torch.cuda.is_available():
                zero = torch.cuda.LongTensor(1).fill_(0)
            else:
                zero = torch.LongTensor(1).fill_(0)
            for i in range(source_ids.shape[0]):
                context = encoder_output[:, i : i + 1]
                context_mask = source_mask[i : i + 1, :]
                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
                input_ids = beam.getCurrentState()
                context = context.repeat(1, self.beam_size, 1)
                context_mask = context_mask.repeat(self.beam_size, 1)
                for _ in range(self.max_length):
                    if beam.done():
                        break
                    attn_mask = -1e4 * (
                        1 - self.bias[: input_ids.shape[1], : input_ids.shape[1]]
                    )
                    tgt_embeddings = (
                        self.encoder.embeddings(input_ids)
                        .permute([1, 0, 2])
                        .contiguous()
                    )
                    out = self.decoder(
                        tgt_embeddings,
                        context,
                        tgt_mask=attn_mask,
                        memory_key_padding_mask=(1 - context_mask).bool(),
                    )
                    out = torch.tanh(self.dense(out))
                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    # Reorder beams by their backpointers, then append the newly
                    # chosen tokens.
                    input_ids.data.copy_(
                        input_ids.data.index_select(0, beam.getCurrentOrigin())
                    )
                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[: self.beam_size]
                # Right-pad every hypothesis with zeros up to max_length.
                pred = [
                    torch.cat(
                        [x.view(-1) for x in p] + [zero] * (self.max_length - len(p))
                    ).view(1, -1)
                    for p in pred
                ]
                preds.append(torch.cat(pred, 0).unsqueeze(0))
            preds = torch.cat(preds, 0)
            return preds
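
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one plausible
# way to assemble the Seq2Seq wrapper above, assuming a Hugging Face RoBERTa
# encoder and a plain nn.TransformerDecoder. The checkpoint name and the
# hyper-parameters below are placeholders, not values taken from this file.
#
#   from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
#
#   model_name = "microsoft/codebert-base"
#   config = RobertaConfig.from_pretrained(model_name)
#   tokenizer = RobertaTokenizer.from_pretrained(model_name)
#   encoder = RobertaModel.from_pretrained(model_name, config=config)
#   decoder_layer = nn.TransformerDecoderLayer(
#       d_model=config.hidden_size, nhead=config.num_attention_heads
#   )
#   decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
#   model = Seq2Seq(
#       encoder=encoder,
#       decoder=decoder,
#       config=config,
#       beam_size=10,
#       max_length=128,
#       sos_id=tokenizer.cls_token_id,
#       eos_id=tokenizer.sep_token_id,
#   )
# ---------------------------------------------------------------------------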


class Beam(object):
    def __init__(self, size, sos, eos):
        self.size = size
        if torch.cuda.is_available():
            self.tt = torch.cuda
        else:
            self.tt = torch
        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size).fill_(0)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given the log-probabilities over words for every current beam `wordLk`,
        compute and update the beam search.

        Parameters:
            * `wordLk`- probs of advancing from the last step (K x words)
        """
        numWords = wordLk.size(1)
        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.scores = bestScores
        # bestScoresId is a flattened beam x word array, so calculate which
        # word and beam each score came from.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))
        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[: self.size - len(self.finished)]
        return self.finished[: self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
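
# ---------------------------------------------------------------------------
# Decoding sketch (illustrative, assumes the `model` and `tokenizer` from the
# usage sketch above, plus `source_ids`/`source_mask` tensors produced by the
# tokenizer): calling the model without `target_ids` runs beam search and
# returns a [batch, beam_size, max_length] tensor of token ids, right-padded
# with zeros, which can be turned back into text roughly as follows.
#
#   model.eval()
#   with torch.no_grad():
#       preds = model(source_ids=source_ids, source_mask=source_mask)
#   for beam_preds in preds:
#       ids = beam_preds[0].tolist()   # best-scoring beam for this example
#       if 0 in ids:                   # 0 is used as right-padding above
#           ids = ids[: ids.index(0)]
#       print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))
# ---------------------------------------------------------------------------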