| from torchaudio.models import Conformer | |
| from torchaudio.models.rnnt import _TimeReduction | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| import torch | |
| from torch import nn | |
| from typing import List, Tuple, Optional | |
class ConformerConfig(PretrainedConfig):
    """Configuration for :class:`ConformerEncoder`.

    Registers the ``'conformer'`` model type with the Hugging Face config
    machinery.  Expected attributes (read by ``ConformerEncoder.__init__`` and
    ``forward``): ``time_reduction_stride``, ``input_dim``,
    ``conformer_input_dim``, ``conformer_num_layers``, ``conformer_ffn_dim``,
    ``conformer_num_heads``, ``conformer_depthwise_conv_kernel_size``,
    ``conformer_dropout``, ``output_dim``, ``pad_token_id``,
    ``ctc_loss_reduction``, and ``ctc_zero_infinity``.
    """

    model_type = 'conformer'
class ConformerEncoder(PreTrainedModel):
    """Conformer acoustic encoder with a CTC head.

    Pipeline: frame stacking (time reduction) -> input projection ->
    torchaudio ``Conformer`` stack -> output projection to logits.  When
    ``labels`` are supplied, a CTC loss is computed over the logits.
    """

    config_class = ConformerConfig

    def __init__(
        self,
        config,
    ) -> None:
        """Build the encoder from a :class:`ConformerConfig`.

        Args:
            config: Configuration object; see :class:`ConformerConfig` for the
                attributes that are read here.
        """
        super().__init__(config)
        # Stacks `time_reduction_stride` consecutive frames into one, shortening
        # the sequence and widening the feature dim by the same factor.
        self.time_reduction = _TimeReduction(config.time_reduction_stride)
        # Project the stacked features down/up to the Conformer's model dim.
        self.input_linear = torch.nn.Linear(
            config.input_dim * config.time_reduction_stride,
            config.conformer_input_dim,
        )
        self.conformer = Conformer(
            num_layers=config.conformer_num_layers,
            input_dim=config.conformer_input_dim,
            ffn_dim=config.conformer_ffn_dim,
            num_heads=config.conformer_num_heads,
            depthwise_conv_kernel_size=config.conformer_depthwise_conv_kernel_size,
            dropout=config.conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        # Final projection to the output vocabulary (CTC label space).
        self.output_linear = torch.nn.Linear(config.conformer_input_dim, config.output_dim)

    def forward(self, inputs, lengths, labels=None):
        """Run the encoder and optionally compute the CTC loss.

        Args:
            inputs: Input feature tensor; assumed shape
                ``(batch, time, input_dim)`` as required by torchaudio's
                ``_TimeReduction`` — TODO confirm against callers.
            lengths: Per-example valid lengths along the time axis.
            labels: Optional label tensor; positions with values ``< 0`` are
                treated as padding and excluded from the CTC targets.

        Returns:
            ``(logits, output_lengths)`` when ``labels`` is ``None``, else
            ``(loss, logits, output_lengths)``.
        """
        time_reduction_out, time_reduction_lengths = self.time_reduction(inputs, lengths)
        input_linear_out = self.input_linear(time_reduction_out)
        x, input_lengths = self.conformer(input_linear_out, time_reduction_lengths)
        logits = self.output_linear(x)

        loss = None
        if labels is not None:
            # Negative label values mark padding; keep only the real targets.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss expects (time, batch, classes) log-probs; compute the
            # softmax in float32 for numerical stability regardless of the
            # model's dtype.
            log_probs = nn.functional.log_softmax(
                logits,
                dim=-1,
                dtype=torch.float32,
            ).transpose(0, 1)

            # cuDNN's CTC kernel has restrictive input requirements; disable it
            # so the native implementation is used.
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        output = (logits, input_lengths)
        return ((loss,) + output) if loss is not None else output