# Extracted from transformers' WhisperModel to simplify package dependency
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from fireredtts2.codec.utils import make_nonpad_mask
from fireredtts2.codec.audio import mel_filter_bank


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


class WhisperSdpaAttention(nn.Module):
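    """Multi-head self-attention built on torch.nn.functional.scaled_dot_product_attention."""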
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.bias = bias
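        # As in the original Whisper attention, k_proj carries no bias term.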
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
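        """Reshape (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)."""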
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
Args:
attention_mask: Bool mask or float mask. Bool mask, True indicates should attend. Float mask is added to the attention score.
"""
bsz, tgt_len, _ = hidden_states.size()
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
# NOTE sdpa needs a 4-dim attention_mask: (b, nh, tq, tv)
if attention_mask is not None and len(attention_mask.shape) == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
) # (bsz, nh, l, d)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
def forward_chunk(
self,
hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        """Forward self-attention with a KV cache (chunk-wise streaming inference).

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: cached keys and values concatenated along the last dim,
                shape (b, num_heads, t_cache, head_dim * 2); None for the first chunk.
        """
bsz, tgt_len, _ = hidden_states.size()
# shape (b, nh, t, c)
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
# unpack cache
if kv_cache is not None:
k_cache, v_cache = kv_cache.chunk(2, dim=-1)
key_states = torch.cat([k_cache, key_states], dim=2)
value_states = torch.cat([v_cache, value_states], dim=2)
new_kv_cache = torch.cat([key_states, value_states], dim=-1)
# attention
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
) # (bsz, nh, l, d)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, new_kv_cache


class WhisperEncoderLayer(nn.Module):
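    """Pre-norm Transformer encoder layer: self-attention followed by a GELU feed-forward network."""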
def __init__(
self,
embed_dim: int,
num_heads: int,
        ffn_dim: Optional[int] = None,
attn_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
self.dropout = dropout
# Attention
self.self_attn = WhisperSdpaAttention(embed_dim, num_heads, attn_dropout)
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
# FFN
ffn_dim = ffn_dim if ffn_dim is not None else embed_dim * 4
self.fc1 = nn.Linear(embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, embed_dim)
# Output norm
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
):
# Attention
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states, attention_mask)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# FFN
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = F.gelu(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
return hidden_states
def forward_chunk(
self,
hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        """Forward one encoder layer with a KV cache (chunk-wise streaming inference).

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: self-attention cache of keys and values concatenated along the
                last dim, shape (b, num_heads, t_cache, head_dim * 2); None for the first chunk.
        """
# Attention
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, new_kv_cache = self.self_attn.forward_chunk(
hidden_states, kv_cache
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# FFN
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = F.gelu(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
return hidden_states, new_kv_cache


class WhisperEncoder(nn.Module):
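    """Whisper-style audio encoder: two Conv1d layers (the second with stride 2) downsample
    the input features, followed by a stack of pre-norm Transformer layers with fixed
    sinusoidal positional embeddings and a final LayerNorm."""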
def __init__(
self,
in_dim: int,
embed_dim: int,
num_layers: int,
num_heads: int,
        ffn_dim: Optional[int] = None,
attn_dropout: float = 0.0,
dropout: float = 0.0,
max_positions: int = 1500,
):
super().__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.dropout = dropout
# Input downsampling
self.conv1 = nn.Conv1d(in_dim, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
# Fixed positional embedding
self.max_positions = max_positions
self.embed_positions = nn.Embedding(self.max_positions, embed_dim)
self.embed_positions.requires_grad_(False)
# Transformer
self.layers = nn.ModuleList(
[
WhisperEncoderLayer(
embed_dim, num_heads, ffn_dim, attn_dropout, dropout
)
for _ in range(num_layers)
]
)
# Output norm
self.layer_norm = nn.LayerNorm(embed_dim)
# Init weight
self.apply(self._init_weights)
# Init position embedding
self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))
def forward(
self,
hidden_states: torch.Tensor,
hidden_length: torch.Tensor,
apply_position: bool = True,
):
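        """
        Args:
            hidden_states: input features of shape (b, t, in_dim), e.g. 100 Hz mel frames.
            hidden_length: number of valid frames per item, shape (b,).
            apply_position: if True, add the fixed sinusoidal positional embedding.
        """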
# Downsampling
hidden_states = hidden_states.transpose(1, 2)
hidden_states = F.gelu(self.conv1(hidden_states))
hidden_states = F.gelu(self.conv2(hidden_states))
hidden_states = hidden_states.transpose(1, 2)
hidden_length = hidden_length // 2 # from 100Hz -> 50Hz
# Pos encoding
if apply_position:
pos_embed = self.embed_positions(
torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
)
hidden_states = hidden_states + pos_embed
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
# Transformer
attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1) # (b, 1, t)
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
hidden_states = self.layer_norm(hidden_states)
return hidden_states, hidden_length
def _init_weights(self, module):
std = 0.02
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


class WhisperMelExtractor(nn.Module):
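    """Whisper-style log-mel spectrogram extractor (defaults: 128 mels, 25 ms window,
    10 ms hop at 16 kHz)."""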
def __init__(
self,
num_mels: int = 128,
sampling_rate: int = 16000,
hop_length: int = 160,
n_fft: int = 400,
fmin: float = 0,
fmax: float = 8000,
padding_value=0.0,
):
super().__init__()
self.num_mels = num_mels
self.sampling_rate = sampling_rate
self.hop_length = hop_length
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.padding_value = padding_value
self.mel_filters = mel_filter_bank(
num_frequency_bins=(1 + n_fft // 2),
num_mel_filters=num_mels,
min_frequency=fmin,
max_frequency=fmax,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
def extract_fbank(self, audio: torch.Tensor):
"""
Args:
audio: batched audio of shape (b, t)
"""
device = audio.device # compute on cuda if input is on cuda
# Mel
window = torch.hann_window(self.n_fft).to(device)
stft = torch.stft(
audio, self.n_fft, self.hop_length, window=window, return_complex=True
)
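        # Drop the final STFT frame so the frame count equals len(audio) // hop_length,
        # matching the hop-based length computed in __call__ (Whisper convention).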
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(device)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Whisper-style normalization: clamp to within 8 (log10 units) of the
        # per-example max, then shift and rescale.
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
def __call__(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
mel = self.extract_fbank(audio16k).transpose(1, 2)
mel_length = audio16k_length // self.hop_length
        # mel: (b, t, num_mels)
return mel, mel_length


# Pretrained encoder from whisper-large-v3
class PretrainedWhisperEncoder(WhisperEncoder):
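    """Frozen WhisperEncoder in the whisper-large-v3 encoder configuration, paired with a
    mel extractor; maps 16 kHz audio to 50 Hz semantic features."""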
@classmethod
def from_pretrained(cls, pretrained_path: str = None):
encoder = cls(
in_dim=128,
embed_dim=1280,
num_layers=32,
num_heads=20,
ffn_dim=5120,
attn_dropout=0.0,
max_positions=1500,
)
if pretrained_path is not None:
ckpt = torch.load(pretrained_path, map_location="cpu")
encoder.load_state_dict(ckpt)
encoder.eval()
# Disable grad
for p in encoder.parameters():
p.requires_grad_(False)
# Add Mel extractor
encoder.feature_extractor = WhisperMelExtractor(
num_mels=128,
sampling_rate=16000,
hop_length=160,
n_fft=400,
fmin=0,
fmax=8000,
)
return encoder
@torch.inference_mode()
def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
# Extract mel
mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
# Forward model
semantic_feats, semantic_length = super().forward(
mel, mel_length, apply_position=True
)
return semantic_feats, semantic_length


class WhisperAcousticEncoder(WhisperEncoder):
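    """Trainable Whisper-style encoder with its own mel extractor; only the mel
    extraction runs under torch.no_grad()."""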
def __init__(
self,
# Mel extraction params
num_mels: int = 128,
sampling_rate: int = 16000,
hop_length: int = 160,
n_fft: int = 400,
fmin: float = 0.0,
fmax: float = 8000,
# Encoder params
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 8,
        ffn_dim: Optional[int] = None,
attn_dropout: float = 0.0,
dropout: float = 0.0,
max_positions: int = 1500, # 50Hz * 30s
):
super().__init__(
in_dim=num_mels,
embed_dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
ffn_dim=ffn_dim,
attn_dropout=attn_dropout,
dropout=dropout,
max_positions=max_positions,
)
self.feature_extractor = WhisperMelExtractor(
num_mels=num_mels,
sampling_rate=sampling_rate,
hop_length=hop_length,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
)
def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
# Extract mel
with torch.no_grad():
mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
# Forward model
hidden_states, hidden_length = super().forward(
mel, mel_length, apply_position=True
)
return hidden_states, hidden_length
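

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): run the trainable
    # acoustic encoder on random 16 kHz audio to illustrate the expected shapes.
    # PretrainedWhisperEncoder.from_pretrained(path) exposes the same interface,
    # but expects a whisper-large-v3 encoder state dict.
    encoder = WhisperAcousticEncoder()  # defaults: 128 mels, 768-dim, 12 layers
    audio16k = torch.randn(2, 16000)  # (b, t); the second item is padding beyond 12000 samples
    audio16k_length = torch.tensor([16000, 12000])  # valid samples per item
    feats, feat_length = encoder(audio16k, audio16k_length)
    # 16000 samples -> 100 mel frames (100 Hz) -> 50 encoder frames (50 Hz)
    print(feats.shape)  # torch.Size([2, 50, 768])
    print(feat_length)  # tensor([50, 37])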