Shen Feiyu
init at 250916
71cd91e
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from huggingface_hub import PyTorchModelHubMixin
from fireredtts2.llm.modules import FLAVORS
def _prepare_transformer(model):
    """Strip the token embedding and output layers from a transformer.

    The width of the embedding table is read first, then both
    ``model.tok_embeddings`` and ``model.output`` are replaced in place by
    ``nn.Identity`` so the caller can supply its own embeddings and heads.

    Args:
        model: Module exposing ``tok_embeddings`` (with an ``embedding_dim``
            attribute) and ``output``.

    Returns:
        Tuple ``(model, embed_dim)``: the same, mutated module and the
        embedding width recorded before the replacement.
    """
    width = model.tok_embeddings.embedding_dim
    for attr_name in ("tok_embeddings", "output"):
        setattr(model, attr_name, nn.Identity())
    return model, width
def _create_causal_mask(seq_len: int, device: torch.device):
    """Build a boolean lower-triangular (causal) mask.

    Args:
        seq_len: Side length of the square mask.
        device: Device on which the mask is allocated.

    Returns:
        Bool tensor of shape ``(seq_len, seq_len)`` where entry ``[i, j]`` is
        True exactly when ``j <= i``.
    """
    positions = torch.arange(seq_len, device=device)
    # Row index >= column index  <=>  on or below the main diagonal.
    return positions.unsqueeze(1) >= positions.unsqueeze(0)
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """Gather the mask rows that correspond to the given positions.

    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len) integer positions used as row
            indices into ``mask``.

    Returns:
        (batch_size, seq_len, max_seq_len) tensor whose ``[b, s]`` slice is
        ``mask[input_pos[b, s]]``.
    """
    # Indexing only the leading dimension keeps the full trailing row.
    return mask[input_pos]
# Draws one multinomial sample per row without a CUDA synchronization.
def _multinomial_sample_one_no_sync(probs):
    """Sample one index per row of ``probs`` along the last dimension.

    Each probability is divided by an independent Exponential(1) draw and
    the position of the largest ratio is returned.

    Args:
        probs: Tensor of probabilities; sampling is over the last dimension.

    Returns:
        ``torch.int`` tensor with the last dimension reduced to size 1.
    """
    noise = torch.empty_like(probs).exponential_(1)
    winners = (probs / noise).argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    """Sample one token per row from the ``topk`` highest-scoring logits.

    The logits are divided by ``temperature``, every entry strictly below
    the k-th largest value of its row is set to ``-inf``, and a token is
    drawn from the softmax of what remains. Entries tied with the k-th
    largest value are kept.

    Args:
        logits: Scores with the vocabulary on the last dimension.
        topk: Number of candidates to keep (must be >= 1). Values larger
            than the vocabulary size are clamped, which keeps every token.
        temperature: Divisor applied to the logits before filtering.

    Returns:
        ``torch.int`` tensor with the last dimension reduced to size 1,
        holding the sampled token ids.
    """
    logits = logits / temperature

    # torch.topk raises when k exceeds the size of the last dimension;
    # clamp so an over-large request simply disables the filter.
    k = min(topk, logits.size(-1))
    kth_largest = torch.topk(logits, k)[0][..., -1, None]
    filtered = logits.masked_fill(logits < kth_largest, -float("Inf"))

    # softmax is shift-invariant, so a preceding log_softmax is unnecessary.
    probs = torch.nn.functional.softmax(filtered, dim=-1)
    return _multinomial_sample_one_no_sync(probs)
def sample_top_nsigma(logits: torch.Tensor, n: float, temperature: float):
    """Sample one token per row, keeping logits within ``n`` std of the max.

    After dividing by ``temperature``, a per-row threshold of
    ``max - n * std`` is computed over the last dimension. Entries below
    the threshold are set to ``-inf`` and a token is drawn from the softmax
    of the remaining scores.

    Args:
        logits (torch.Tensor): Scores with the vocabulary on the last
            dimension. The caller's tensor is not modified.
        n (float): Multiplier applied to the per-row standard deviation.
        temperature (float): Divisor applied to the logits before filtering.

    Returns:
        torch.Tensor: ``torch.int`` tensor with the last dimension reduced
        to size 1, holding the sampled token ids.
    """
    scaled = logits / temperature
    row_max = scaled.max(dim=-1, keepdim=True).values
    row_std = scaled.std(dim=-1, keepdim=True)
    cutoff = row_max - n * row_std
    kept = scaled.masked_fill(scaled < cutoff, float("-inf"))
    probs = torch.nn.functional.softmax(kept, dim=-1)
    return _multinomial_sample_one_no_sync(probs)
@dataclass
class ModelArgs:
    """Configuration consumed by ``Model``."""

    # Key into ``FLAVORS`` selecting the backbone transformer constructor.
    backbone_flavor: str
    # Key into ``FLAVORS`` selecting the decoder transformer constructor.
    decoder_flavor: str
    # Number of rows in the text embedding table and outputs of the text head.
    text_vocab_size: int
    # Vocabulary size of a single audio codebook.
    audio_vocab_size: int
    # Number of audio codebooks per frame.
    audio_num_codebooks: int
    # Weight of the decoder (codebooks 1..N-1) loss; the codebook-0 loss is
    # weighted by ``1 - decoder_loss_weight``.
    decoder_loss_weight: float
    # When True, a scaled text loss is added to the training loss.
    use_text_loss: bool
class Model(nn.Module, PyTorchModelHubMixin):
    """Backbone + decoder model over interleaved text and audio tokens.

    Each input frame carries ``audio_num_codebooks`` audio tokens followed
    by one text token (last column). The backbone consumes the sum of the
    embeddings of the unmasked columns of every frame and predicts
    codebook 0 (and, for text frames, the next text token). The decoder
    predicts the remaining codebooks of a frame from the backbone hidden
    state and the embeddings of the preceding codebooks.
    """

    def __init__(self, config: ModelArgs):
        """Build both transformers, the embedding tables and output heads.

        Args:
            config: Flavors, vocabulary sizes and loss settings.
        """
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(
            FLAVORS[config.backbone_flavor]()
        )
        self.decoder, decoder_dim = _prepare_transformer(
            FLAVORS[config.decoder_flavor]()
        )

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        # One shared table for all codebooks; codebook ``i`` owns the rows
        # ``[i * audio_vocab_size, (i + 1) * audio_vocab_size)``.
        self.audio_embeddings = nn.Embedding(
            config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
        )

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.text_head = nn.Linear(backbone_dim, config.text_vocab_size, bias=False)
        self.codebook0_head = nn.Linear(
            backbone_dim, config.audio_vocab_size, bias=False
        )
        # One (decoder_dim, audio_vocab_size) projection per codebook 1..N-1.
        self.audio_head = nn.Parameter(
            torch.empty(
                config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
            )
        )
        # torch.empty leaves the memory uninitialized, which can contain
        # NaN/inf when no checkpoint is loaded afterwards. Use the same
        # uniform bound nn.Linear applies by default (1 / sqrt(fan_in)).
        head_bound = decoder_dim**-0.5
        nn.init.uniform_(self.audio_head, -head_bound, head_bound)

        self.decoder_loss_weight = config.decoder_loss_weight
        self.use_text_loss = config.use_text_loss

    def setup_caches(self, max_batch_size: int) -> None:
        """Set up the KV caches and register the causal masks used at inference.

        Registers the buffers ``backbone_causal_mask`` and
        ``decoder_causal_mask`` on the module; nothing is returned.

        Args:
            max_batch_size: Batch size passed to both transformers'
                ``setup_caches``.
        """
        first_param = next(self.parameters())
        dtype = first_param.dtype
        device = first_param.device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(
                max_batch_size,
                dtype,
                decoder_max_seq_len=self.config.audio_num_codebooks,
            )

        self.register_buffer(
            "backbone_causal_mask",
            _create_causal_mask(self.backbone.max_seq_len, device),
        )
        self.register_buffer(
            "decoder_causal_mask",
            _create_causal_mask(self.config.audio_num_codebooks, device),
        )

    def forward(self, tokens: torch.Tensor, tokens_mask: torch.Tensor):
        """Compute the training losses (forward pass of the CSM-style model).

        Args:
            tokens: (batch_size, seq_len, n_codebooks+1); the last column
                holds the text token, the others the audio codebooks.
            tokens_mask: (batch_size, seq_len, n_codebooks+1) boolean mask.
                Column 0 marks audio frames, the last column marks text
                frames.

        Returns:
            Tuple ``(loss, text_loss, c0_loss, c_loss)``. ``loss`` combines
            the codebook-0 and decoder losses using ``decoder_loss_weight``
            and, when ``use_text_loss`` is set, adds ``0.01 * text_loss``.
            ``text_loss`` is always computed and returned.
        """
        dtype = next(self.parameters()).dtype
        bsz, seq_len, _ = tokens.size()
        device = tokens.device

        # (bsz, seq_len, n_codebooks+1, embed_dim)
        embeds = self._embed_tokens(tokens)

        # Targets and codebook embeddings of the audio frames.
        audio_mask = tokens_mask[:, :, 0]  # [bsz, seq_len]
        target_tokens = tokens[audio_mask][:, :-1]  # [audio_len, n_codebooks]
        # [audio_len, n_codebooks, embed_dim]
        c_embeds = embeds[:, :, :-1, :][audio_mask]

        # Text targets: the frame right after each text frame (the mask is
        # rolled forward by one position along the sequence axis).
        text_mask = tokens_mask[:, :, -1]
        text_target_mask = torch.roll(input=text_mask, shifts=1, dims=1)
        text_target_tokens = tokens[text_target_mask][:, -1]

        # Zero out masked columns, then collapse each frame to one vector.
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)

        # Backbone attention mask: causal, restricted to non-padding frames.
        # [bsz, seq_len]
        padding_mask = tokens_mask[:, :, 0] | tokens_mask[:, :, -1]
        # [bsz, seq_len, seq_len]
        padding_3d = padding_mask.unsqueeze(-1) & padding_mask.unsqueeze(1)
        backbone_attn_mask = (
            _create_causal_mask(seq_len, device).unsqueeze(0) & padding_3d
        )
        # Always allow the diagonal so that no row (padding frames included)
        # is completely masked out.
        backbone_attn_mask = backbone_attn_mask | torch.eye(
            seq_len, dtype=torch.bool, device=device
        )

        input_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(
            bsz, seq_len
        )
        h = self.backbone(h, input_pos=input_pos, mask=backbone_attn_mask).to(
            dtype=dtype
        )

        # Roll the audio mask one step towards the start of the sequence:
        # the hidden state at position t predicts the audio frame at t + 1.
        audio_mask = torch.roll(audio_mask, -1, 1)
        audio_h = h[audio_mask]  # [audio_len, embed_dim]

        # Codebook-0 loss from the backbone hidden states.
        c0_logits = self.codebook0_head(audio_h)  # [audio_len, audio_vocab_size]
        c0_target = target_tokens[:, 0]  # [audio_len]
        c0_loss = F.cross_entropy(c0_logits, c0_target)

        # Text loss; targets equal to 0 are ignored.
        text_h = h[text_mask]
        text_logits = self.text_head(text_h)
        text_loss = F.cross_entropy(text_logits, text_target_tokens, ignore_index=0)

        # "Compute amortization": train the decoder on a random 1/8 subset
        # of the audio frames. Keep at least one frame when any exist so a
        # short batch does not produce a NaN loss from an empty selection.
        num_audio = c_embeds.size(0)
        num_sampled = min(num_audio, max(1, num_audio // 8))
        indices = torch.randperm(num_audio, device=device)[:num_sampled]
        # [num_sampled, n_codebooks-1, embed_dim]
        c_embeds = c_embeds[indices][:, :-1, :]
        audio_h = audio_h[indices]  # [num_sampled, embed_dim]
        target_tokens = target_tokens[indices][:, 1:]  # [num_sampled, n_codebooks-1]

        # Decoder input: backbone hidden state followed by the embeddings of
        # codebooks 0..N-2.  [num_sampled, n_codebooks, embed_dim]
        decoder_embeds = torch.cat([audio_h.unsqueeze(1), c_embeds], dim=1)
        N, n_codebooks, _ = decoder_embeds.size()
        c_pos = torch.arange(n_codebooks, device=device).unsqueeze(0).expand(
            N, n_codebooks
        )
        decoder_causal_mask = _create_causal_mask(n_codebooks, device).expand(
            N, -1, -1
        )
        decoder_h = self.decoder(
            self.projection(decoder_embeds), input_pos=c_pos, mask=decoder_causal_mask
        ).to(dtype=dtype)

        # Position s (s >= 1) of the decoder output is scored with head s-1,
        # i.e. it predicts codebook s.
        c_logits = torch.einsum("bsd,sdv->bsv", decoder_h[:, 1:, :], self.audio_head)
        c_loss = F.cross_entropy(
            c_logits.reshape(-1, c_logits.size(-1)), target_tokens.reshape(-1)
        )

        loss = 2 * (
            (1 - self.decoder_loss_weight) * c0_loss
            + self.decoder_loss_weight * c_loss
        )
        if self.use_text_loss:
            loss = loss + 0.01 * text_loss
        return loss, text_loss, c0_loss, c_loss

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
        decoder_temperature: float = 0.75,
        decoder_topk: int = 10,
    ) -> torch.Tensor:
        """Sample all codebooks of the next audio frame.

        Requires ``setup_caches`` to have been called beforehand.

        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token
            temperature: Sampling temperature for codebook 0.
            topk: Top-k cutoff for codebook 0.
            decoder_temperature: Sampling temperature for codebooks 1..N-1.
                The default matches the value that used to be hard-coded.
            decoder_topk: Top-k cutoff for codebooks 1..N-1. The default
                matches the value that used to be hard-coded.

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)

        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(
            dtype=dtype
        )

        # Codebook 0 comes from the backbone's last hidden state.
        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        # First decoder step sees [backbone state, codebook-0 embedding].
        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = (
            torch.arange(0, curr_h.size(1), device=curr_h.device)
            .unsqueeze(0)
            .repeat(curr_h.size(0), 1)
        )

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
            ).to(dtype=dtype)
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, decoder_topk, decoder_temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            # Later steps feed only the newest codebook embedding; earlier
            # positions are served from the decoder cache.
            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        """Reset the KV caches of both the backbone and the decoder."""
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        """Embed tokens of one codebook via its slice of the shared table."""
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed every column of ``tokens``.

        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)

        Returns:
            (batch_size, seq_len, audio_num_codebooks+1, embed_dim) with the
            audio codebook embeddings first and the text embedding last.
        """
        num_codebooks = self.config.audio_num_codebooks
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        # Offset codebook i by i * audio_vocab_size to address its rows in
        # the shared audio embedding table.
        codebook_offsets = self.config.audio_vocab_size * torch.arange(
            num_codebooks, device=tokens.device
        )
        audio_tokens = tokens[:, :, :-1] + codebook_offsets
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)
if __name__ == "__main__":
    # Smoke test: build the model once from a fixed configuration.
    MIMI_SAMPLE_RATE = 24000  # not referenced by the statements below
    BACKBONE_FLAVOR, DECODER_FLAVOR = "qwen-3b", "qwen-500m"
    TEXT_VOCAB_SIZE, AUDIO_VOCAB_SIZE = 128256, 2051
    AUDIO_NUM_CODEBOOKS = 32

    config = ModelArgs(
        use_text_loss=True,
        decoder_loss_weight=0.5,
        audio_num_codebooks=AUDIO_NUM_CODEBOOKS,
        audio_vocab_size=AUDIO_VOCAB_SIZE,
        text_vocab_size=TEXT_VOCAB_SIZE,
        decoder_flavor=DECODER_FLAVOR,
        backbone_flavor=BACKBONE_FLAVOR,
    )
    model = Model(config)