# Extracted from transformers' WhisperModel to simplify package dependency
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Literal
from fireredtts2.codec.utils import make_nonpad_mask
from fireredtts2.codec.audio import mel_filter_bank


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
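
# A minimal shape sketch for the sinusoidal table above (illustrative, not part of the
# original module): each position gets channels // 2 sine values followed by
# channels // 2 cosine values, matching Whisper's fixed positional embedding.
#
#     pos_table = sinusoids(length=1500, channels=1280)
#     assert pos_table.shape == (1500, 1280)
#     # position 0 is all sin(0) = 0 followed by all cos(0) = 1
#     assert torch.allclose(pos_table[0, :640], torch.zeros(640))
#     assert torch.allclose(pos_table[0, 640:], torch.ones(640))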


class WhisperSdpaAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.bias = bias
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            attention_mask: Bool mask or float mask. For a bool mask, True marks
                positions that should be attended to; a float mask is added to the
                attention scores.
        """
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
        value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
        # NOTE sdpa needs a 4-dim attention_mask: (b, nh, tq, tv)
        if attention_mask is not None and len(attention_mask.shape) == 3:
            attention_mask = attention_mask.unsqueeze(1)
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )  # (bsz, nh, l, d)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output

    def forward_chunk(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        """Forward self-attention with kv cache.

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: shape (b, nh, t_cache, head_dim * 2)
        """
        bsz, tgt_len, _ = hidden_states.size()
        # shape (b, nh, t, head_dim)
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
        value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
        # unpack cache and prepend it to the current keys/values
        if kv_cache is not None:
            k_cache, v_cache = kv_cache.chunk(2, dim=-1)
            key_states = torch.cat([k_cache, key_states], dim=2)
            value_states = torch.cat([v_cache, value_states], dim=2)
        new_kv_cache = torch.cat([key_states, value_states], dim=-1)
        # attention over the full (cached + current) sequence
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )  # (bsz, nh, l, d)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, new_kv_cache
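
# Illustrative streaming use of WhisperSdpaAttention.forward_chunk (a sketch with
# made-up sizes, not part of the original module): the returned new_kv_cache packs
# keys and values along the last dim, so feeding it back on the next chunk attends
# over the full history.
#
#     attn = WhisperSdpaAttention(embed_dim=768, num_heads=8)  # head_dim = 96
#     cache = None
#     for chunk in torch.randn(4, 2, 50, 768).unbind(dim=1):  # two chunks of 50 frames
#         out, cache = attn.forward_chunk(chunk, kv_cache=cache)
#     assert cache.shape == (4, 8, 100, 96 * 2)  # (b, nh, t_total, head_dim * 2)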


class WhisperEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ffn_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dropout = dropout
        # Attention
        self.self_attn = WhisperSdpaAttention(embed_dim, num_heads, attn_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        # FFN
        ffn_dim = ffn_dim if ffn_dim is not None else embed_dim * 4
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        # Output norm
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        # Attention
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        # FFN
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = F.gelu(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        return hidden_states

    def forward_chunk(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        """Forward self-attention with kv cache.

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: shape (b, nh, t_cache, head_dim * 2)
        """
        # Attention
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, new_kv_cache = self.self_attn.forward_chunk(
            hidden_states, kv_cache
        )
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        # FFN
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = F.gelu(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        return hidden_states, new_kv_cache


class WhisperEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        ffn_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        max_positions: int = 1500,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        # Input downsampling
        self.conv1 = nn.Conv1d(in_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        # Fixed positional embedding
        self.max_positions = max_positions
        self.embed_positions = nn.Embedding(self.max_positions, embed_dim)
        self.embed_positions.requires_grad_(False)
        # Transformer
        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(
                    embed_dim, num_heads, ffn_dim, attn_dropout, dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Output norm
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Init weight
        self.apply(self._init_weights)
        # Init position embedding
        self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_length: torch.Tensor,
        apply_position: bool = True,
    ):
        # Downsampling
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = F.gelu(self.conv1(hidden_states))
        hidden_states = F.gelu(self.conv2(hidden_states))
        hidden_states = hidden_states.transpose(1, 2)
        hidden_length = hidden_length // 2  # from 100Hz -> 50Hz
        # Pos encoding
        if apply_position:
            pos_embed = self.embed_positions(
                torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
            )
            hidden_states = hidden_states + pos_embed
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        # Transformer
        attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1)  # (b, 1, t)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states, hidden_length

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
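
# Illustrative call into WhisperEncoder (a sketch with made-up sizes): conv2 uses
# stride 2, so a 100 Hz mel sequence comes out at 50 Hz and hidden_length is halved.
#
#     encoder = WhisperEncoder(in_dim=128, embed_dim=768, num_layers=2, num_heads=8)
#     mel = torch.randn(2, 1000, 128)         # (b, t=1000 frames @ 100 Hz, n_mels)
#     mel_length = torch.tensor([1000, 800])
#     out, out_length = encoder(mel, mel_length)
#     # out: (2, 500, 768), out_length: tensor([500, 400])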


class WhisperMelExtractor(nn.Module):
    def __init__(
        self,
        num_mels: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        n_fft: int = 400,
        fmin: float = 0,
        fmax: float = 8000,
        padding_value=0.0,
    ):
        super().__init__()
        self.num_mels = num_mels
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.padding_value = padding_value
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=(1 + n_fft // 2),
            num_mel_filters=num_mels,
            min_frequency=fmin,
            max_frequency=fmax,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def extract_fbank(self, audio: torch.Tensor):
        """
        Args:
            audio: batched audio of shape (b, t)
        """
        device = audio.device  # compute on cuda if input is on cuda
        # Mel
        window = torch.hann_window(self.n_fft).to(device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        magnitudes = stft[..., :-1].abs() ** 2
        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(device)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Norm
        max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
        log_spec = torch.maximum(log_spec, max_val - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def __call__(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        mel = self.extract_fbank(audio16k).transpose(1, 2)
        mel_length = audio16k_length // self.hop_length
        # mel: (b, t, c=128)
        return mel, mel_length
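
# Illustrative use of WhisperMelExtractor (a sketch): one second of 16 kHz audio maps
# to audio16k_length // hop_length = 100 frames, and the log-mel features are clamped
# to within 8 log10 units of the per-sample maximum and rescaled, as in Whisper's
# preprocessing.
#
#     extractor = WhisperMelExtractor()
#     audio = torch.randn(3, 16000)                      # (b, t) at 16 kHz
#     audio_length = torch.tensor([16000, 16000, 8000])
#     mel, mel_length = extractor(audio, audio_length)
#     # mel: (3, 100, 128), mel_length: tensor([100, 100, 50])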


# Pretrained encoder from whisper-large-v3
class PretrainedWhisperEncoder(WhisperEncoder):
    @classmethod
    def from_pretrained(cls, pretrained_path: Optional[str] = None):
        encoder = cls(
            in_dim=128,
            embed_dim=1280,
            num_layers=32,
            num_heads=20,
            ffn_dim=5120,
            attn_dropout=0.0,
            max_positions=1500,
        )
        if pretrained_path is not None:
            ckpt = torch.load(pretrained_path, map_location="cpu")
            encoder.load_state_dict(ckpt)
        encoder.eval()
        # Disable grad
        for p in encoder.parameters():
            p.requires_grad_(False)
        # Add Mel extractor
        encoder.feature_extractor = WhisperMelExtractor(
            num_mels=128,
            sampling_rate=16000,
            hop_length=160,
            n_fft=400,
            fmin=0,
            fmax=8000,
        )
        return encoder

    def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        # Extract mel
        mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
        # Forward model
        semantic_feats, semantic_length = super().forward(
            mel, mel_length, apply_position=True
        )
        return semantic_feats, semantic_length


class WhisperAcousticEncoder(WhisperEncoder):
    def __init__(
        self,
        # Mel extraction params
        num_mels: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        n_fft: int = 400,
        fmin: float = 0.0,
        fmax: float = 8000,
        # Encoder params
        embed_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 8,
        ffn_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        max_positions: int = 1500,  # 50Hz * 30s
    ):
        super().__init__(
            in_dim=num_mels,
            embed_dim=embed_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            attn_dropout=attn_dropout,
            dropout=dropout,
            max_positions=max_positions,
        )
        self.feature_extractor = WhisperMelExtractor(
            num_mels=num_mels,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            n_fft=n_fft,
            fmin=fmin,
            fmax=fmax,
        )

    def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        # Extract mel
        with torch.no_grad():
            mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
        # Forward model
        hidden_states, hidden_length = super().forward(
            mel, mel_length, apply_position=True
        )
        return hidden_states, hidden_length
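

# A minimal end-to-end sketch (not part of the original module). It runs both encoders
# on random audio; loading real whisper-large-v3 weights would require passing a
# checkpoint path to PretrainedWhisperEncoder.from_pretrained. Note the semantic
# encoder builds the full 32-layer, 1280-dim configuration, so this is slow on CPU.
if __name__ == "__main__":
    # Frozen semantic encoder with whisper-large-v3 dimensions (random weights here).
    semantic_encoder = PretrainedWhisperEncoder.from_pretrained(pretrained_path=None)
    # Trainable acoustic encoder with the default, smaller configuration.
    acoustic_encoder = WhisperAcousticEncoder()

    audio16k = torch.randn(2, 32000)                # padded batch: 2 s and 1.5 s at 16 kHz
    audio16k_length = torch.tensor([32000, 24000])

    semantic_feats, semantic_length = semantic_encoder(audio16k, audio16k_length)
    acoustic_feats, acoustic_length = acoustic_encoder(audio16k, audio16k_length)
    # Both encoders emit 50 Hz features: 2 s of audio -> 100 frames.
    print(semantic_feats.shape, semantic_length)    # (2, 100, 1280), tensor([100, 75])
    print(acoustic_feats.shape, acoustic_length)    # (2, 100, 768), tensor([100, 75])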