# Shen Feiyu
# init at 250916
# 71cd91e
import json
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer

from fireredtts2.llm.llm import Model, ModelArgs
@dataclass
class Segment:
    """One piece of conversation: a speaker label, its text, and its audio."""

    # Identifier of the speaker for this segment.
    speaker: str
    # Text content of the segment.
    text: str
    # Audio for the segment (shape / sample rate are not fixed by this class).
    audio: torch.Tensor
class WarmupDecayLR(LambdaLR):
    """
    Learning rate scheduler with a linear warmup followed by a configurable decay.

    The multiplier applied to the optimizer's base learning rates rises
    linearly from 0 to 1 during the first ``warmup_steps`` steps, then follows
    ``decay_type``:

    - ``"linear"``: falls linearly from 1 to 0 at ``total_steps`` and stays at 0.
    - ``"constant"``: stays at 1.
    - ``"exponential"``: ``0.1 ** progress``, reaching 0.1 at ``total_steps``.
    - ``"cosine"``: half-cosine from 1 to 0 at ``total_steps``, then stays at 0.

    Here ``progress = (step - warmup_steps) / (total_steps - warmup_steps)``.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the decay reaches its final value.
        decay_type: One of ``DECAY_TYPES``.

    Raises:
        ValueError: If ``decay_type`` is not a supported decay type.
    """

    DECAY_TYPES = ("linear", "constant", "exponential", "cosine")

    def __init__(
        self, optimizer, warmup_steps: int, total_steps: int, decay_type: str = "linear"
    ):
        # Fail fast: otherwise a typo would only surface once warmup is over.
        if decay_type not in self.DECAY_TYPES:
            raise ValueError(f"Invalid decay type: {decay_type}")
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_type = decay_type
        super().__init__(optimizer, self.lr_lambda, last_epoch=-1)

    def lr_lambda(self, step: int) -> float:
        """Return the learning-rate multiplier for ``step`` as a Python float."""
        if step < self.warmup_steps:
            return step / self.warmup_steps

        if self.decay_type == "constant":
            return 1.0

        # Guard against total_steps <= warmup_steps (would divide by zero).
        decay_span = max(1, self.total_steps - self.warmup_steps)

        if self.decay_type == "linear":
            # Clamp so the learning rate never becomes negative past total_steps.
            return max(0.0, (self.total_steps - step) / decay_span)

        progress = (step - self.warmup_steps) / decay_span
        if self.decay_type == "exponential":
            return 0.1**progress
        if self.decay_type == "cosine":
            # math.cos keeps the result a float (torch.cos would return a
            # tensor); clamping stops the cosine from rising again after
            # total_steps.
            return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

        raise ValueError(f"Invalid decay type: {self.decay_type}")
# Extra special tokens registered on the tokenizer by ``load_custom_tokenizer``.
# The list is built in a fixed order: text markers, speaker tags, emotion tags,
# then paralinguistic event tags.
additional_special_tokens = [
    # Text span markers.
    "<|text_start|>",
    "<|text_end|>",
    # Speaker tags: [S1] .. [S40].
    *(f"[S{idx}]" for idx in range(1, 41)),
    # Podcast speaker tags: [S_PODCAST_1] .. [S_PODCAST_10].
    *(f"[S_PODCAST_{idx}]" for idx in range(1, 11)),
    # Dialog speaker tags: [S_DIALOG_1] .. [S_DIALOG_10].
    *(f"[S_DIALOG_{idx}]" for idx in range(1, 11)),
    # Named emotion tags.
    *(
        f"<|emotion_{label}|>"
        for label in (
            "neutral",
            "happy",
            "sad",
            "concern",
            "confuse",
            "angry",
            "surprise",
            "disgust",
            "nervous",
            "apology",
            "understand",
            "fear",
            "comfort",
            "shy",
            "serious",
        )
    ),
    # Emotion placeholder slots: extra1 .. extra10.
    *(f"<|emotion_extra{idx}|>" for idx in range(1, 11)),
    # Named paralinguistic event tags.
    *(
        f"<|{event}|>"
        for event in (
            "breath",
            "humph",
            "laugh_heng",
            "hissing",
            "sniff",
            "laugh_he",
            "sigh",
            "laugh",
            "laugh_ha",
            "quick_breath",
            "laugh_hei",
            "laugh_speak",
            "/laugh_speak",
            "cry",
            "choking",
            "cry_speak",
            "/cry_speak",
            "slurp",
            "clucking",
            "yawning",
            "cough",
            "smack",
            "hem",
            "stretch",
            "sneeze",
        )
    ),
    # Paralinguistic placeholder slots. There is no <|paralinguistic_extra9|>
    # entry: the numbering jumps from 8 to 10 and ends at 13.
    *(f"<|paralinguistic_extra{idx}|>" for idx in range(1, 14) if idx != 9),
]
def load_custom_tokenizer(qwen2_tokenizer_path: str):
    """Load a pretrained tokenizer and register this module's special tokens.

    Args:
        qwen2_tokenizer_path: Name or path handed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer after ``additional_special_tokens`` has been
        added to it via ``add_special_tokens``.
    """
    tokenizer = AutoTokenizer.from_pretrained(qwen2_tokenizer_path)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": additional_special_tokens}
    )
    return tokenizer
def init_weights(model: nn.Module):
    """
    Initialize the weights of the model in place.
    - Xavier uniform weights and zero biases for linear layers
    - Normal initialization (mean 0, std 0.02) for embeddings
    - Xavier uniform initialization for the model's ``audio_head``, if it has one

    Args:
        model: Module to initialize; all of its sub-modules are visited.

    Returns:
        The same ``model`` instance, for call chaining.
    """

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    # ``apply`` only visits modules, never bare ``nn.Parameter`` attributes,
    # so parameters held directly on a module must be handled explicitly.
    model.apply(_init_weights)

    # Special handling for audio_head because it's nn.Parameter directly.
    # Models without such an attribute are simply left as they are.
    audio_head = getattr(model, "audio_head", None)
    if audio_head is not None:
        nn.init.xavier_uniform_(audio_head)
    return model
def load_llm_model(
    configs,
    checkpoint_path: Union[str, Path, None] = None,
    device: Union[str, torch.device] = "cuda",
) -> Model:
    """Build the LLM, load or initialize its weights, and move it to ``device``.

    Args:
        configs: Mapping with an ``"llm_models"`` entry that provides
            ``backbone_flavor``, ``decoder_flavor``, ``text_vocab_size``,
            ``audio_vocab_size``, ``audio_num_codebooks`` and
            ``decoder_loss_weight``.
        checkpoint_path: Optional checkpoint file. When the file exists, its
            ``"model"`` entry is loaded as the state dict. When it is ``None``
            or does not exist, the model is freshly initialized with
            ``init_weights`` (a warning is emitted if a path was given but
            not found).
        device: Device to move the model to.

    Returns:
        The constructed model, located on ``device``.
    """
    llm_cfg = configs["llm_models"]
    model_arg = ModelArgs(
        backbone_flavor=llm_cfg["backbone_flavor"],
        decoder_flavor=llm_cfg["decoder_flavor"],
        text_vocab_size=llm_cfg["text_vocab_size"],
        audio_vocab_size=llm_cfg["audio_vocab_size"],
        audio_num_codebooks=llm_cfg["audio_num_codebooks"],
        decoder_loss_weight=llm_cfg["decoder_loss_weight"],
        use_text_loss=True,
    )
    model = Model(model_arg)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )
        model.load_state_dict(checkpoint["model"])
    else:
        if checkpoint_path:
            # Falling back to random weights is easy to miss, so say so.
            warnings.warn(
                f"Checkpoint not found at {checkpoint_path}; "
                "initializing model weights from scratch."
            )
        model = init_weights(model)

    model = model.to(device=device)
    return model
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Write a batch of named summaries to ``writer`` at ``global_step``.

    Args:
        writer: Object exposing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio`` (e.g. a TensorBoard-style
            summary writer).
        global_step: Step value attached to every entry.
        scalars: Mapping of tag -> scalar value; ``None`` means no entries.
        histograms: Mapping of tag -> histogram values; ``None`` means none.
        images: Mapping of tag -> image, logged with ``dataformats="HWC"``;
            ``None`` means none.
        audios: Mapping of tag -> audio data; ``None`` means none.
        audio_sampling_rate: Sampling rate passed along with each audio entry.
    """
    # ``None`` defaults avoid sharing one mutable dict across calls.
    for tag, value in (scalars or {}).items():
        writer.add_scalar(tag, value, global_step)
    for tag, value in (histograms or {}).items():
        writer.add_histogram(tag, value, global_step)
    for tag, value in (images or {}).items():
        writer.add_image(tag, value, global_step, dataformats="HWC")
    for tag, value in (audios or {}).items():
        writer.add_audio(tag, value, global_step, audio_sampling_rate)
def get_grad_norm(model):
    """Return the model's overall gradient L2 norm, averaged per parameter.

    The L2 norms of all parameter gradients are combined into one global
    L2 norm, which is then divided by the number of parameters that had a
    gradient. Names of parameters without a gradient are printed and those
    parameters are skipped.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        The global gradient norm divided by the number of contributing
        parameters, or ``0.0`` when no parameter has a gradient.
    """
    squared_sum = 0.0
    num = 0
    for name, p in model.named_parameters():
        if p.grad is None:
            # No gradient available for this parameter: report it and move on.
            print(name)
            continue
        param_norm = p.grad.data.norm(2)
        squared_sum += param_norm.item() ** 2
        num += 1

    if num == 0:
        # Nothing contributed; avoid dividing by zero.
        return 0.0
    return squared_sum ** (1.0 / 2) / num
def read_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: File path; a leading ``~`` is expanded to the user's home.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    path = os.path.expanduser(path)
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file rather than ``str.splitlines()``: the latter also
        # splits on characters such as U+2028, which may legally appear
        # unescaped inside JSON strings and would cut a record in half.
        # Blank lines are skipped instead of being fed to the JSON parser.
        return [json.loads(line) for line in f if line.strip()]