import os
import json
import math
import torch
import torch.nn as nn
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Union
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer

from fireredtts2.llm.llm import Model, ModelArgs


@dataclass
class Segment:
    """A single utterance: speaker label, text, and the corresponding audio tensor."""

    speaker: str
    text: str
    audio: torch.Tensor


class WarmupDecayLR(LambdaLR):
    """
    Learning rate scheduler with a linear warmup followed by a configurable decay.
    """

    def __init__(
        self, optimizer, warmup_steps: int, total_steps: int, decay_type: str = "linear"
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_type = decay_type
        super().__init__(optimizer, self.lr_lambda, last_epoch=-1)

    def lr_lambda(self, step: int) -> float:
        if step < self.warmup_steps:
            # Linear warmup from 0 up to the base learning rate.
            return step / self.warmup_steps
        else:
            if self.decay_type == "linear":
                return (self.total_steps - step) / (
                    self.total_steps - self.warmup_steps
                )
            elif self.decay_type == "constant":
                return 1.0
            elif self.decay_type == "exponential":
                # Decays from 1.0 down to 0.1 over the post-warmup steps.
                return 0.1 ** (
                    (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
                )
            elif self.decay_type == "cosine":
                # Cosine decay from 1.0 down to 0.0 over the post-warmup steps.
                return 0.5 * (
                    1
                    + math.cos(
                        math.pi
                        * (step - self.warmup_steps)
                        / (self.total_steps - self.warmup_steps)
                    )
                )
            else:
                raise ValueError(f"Invalid decay type: {self.decay_type}")
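
# Example (illustrative only): pairing WarmupDecayLR with an AdamW optimizer.
# The learning rate and step counts below are placeholder values, not values
# taken from this repo's training configs.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   scheduler = WarmupDecayLR(
#       optimizer, warmup_steps=1000, total_steps=100000, decay_type="cosine"
#   )
#   for step in range(100000):
#       ...  # forward / backward / optimizer.step()
#       scheduler.step()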


additional_special_tokens = [
    "<|text_start|>",
    "<|text_end|>",
    "[S1]",
    "[S2]",
    "[S3]",
    "[S4]",
    "[S5]",
    "[S6]",
    "[S7]",
    "[S8]",
    "[S9]",
    "[S10]",
    "[S11]",
    "[S12]",
    "[S13]",
    "[S14]",
    "[S15]",
    "[S16]",
    "[S17]",
    "[S18]",
    "[S19]",
    "[S20]",
    "[S21]",
    "[S22]",
    "[S23]",
    "[S24]",
    "[S25]",
    "[S26]",
    "[S27]",
    "[S28]",
    "[S29]",
    "[S30]",
    "[S31]",
    "[S32]",
    "[S33]",
    "[S34]",
    "[S35]",
    "[S36]",
    "[S37]",
    "[S38]",
    "[S39]",
    "[S40]",
    "[S_PODCAST_1]",
    "[S_PODCAST_2]",
    "[S_PODCAST_3]",
    "[S_PODCAST_4]",
    "[S_PODCAST_5]",
    "[S_PODCAST_6]",
    "[S_PODCAST_7]",
    "[S_PODCAST_8]",
    "[S_PODCAST_9]",
    "[S_PODCAST_10]",
    "[S_DIALOG_1]",
    "[S_DIALOG_2]",
    "[S_DIALOG_3]",
    "[S_DIALOG_4]",
    "[S_DIALOG_5]",
    "[S_DIALOG_6]",
    "[S_DIALOG_7]",
    "[S_DIALOG_8]",
    "[S_DIALOG_9]",
    "[S_DIALOG_10]",
    "<|emotion_neutral|>",
    "<|emotion_happy|>",
    "<|emotion_sad|>",
    "<|emotion_concern|>",
    "<|emotion_confuse|>",
    "<|emotion_angry|>",
    "<|emotion_surprise|>",
    "<|emotion_disgust|>",
    "<|emotion_nervous|>",
    "<|emotion_apology|>",
    "<|emotion_understand|>",
    "<|emotion_fear|>",
    "<|emotion_comfort|>",
    "<|emotion_shy|>",
    "<|emotion_serious|>",
    "<|emotion_extra1|>",
    "<|emotion_extra2|>",
    "<|emotion_extra3|>",
    "<|emotion_extra4|>",
    "<|emotion_extra5|>",
    "<|emotion_extra6|>",
    "<|emotion_extra7|>",
    "<|emotion_extra8|>",
    "<|emotion_extra9|>",
    "<|emotion_extra10|>",
    "<|breath|>",
    "<|humph|>",
    "<|laugh_heng|>",
    "<|hissing|>",
    "<|sniff|>",
    "<|laugh_he|>",
    "<|sigh|>",
    "<|laugh|>",
    "<|laugh_ha|>",
    "<|quick_breath|>",
    "<|laugh_hei|>",
    "<|laugh_speak|>",
    "<|/laugh_speak|>",
    "<|cry|>",
    "<|choking|>",
    "<|cry_speak|>",
    "<|/cry_speak|>",
    "<|slurp|>",
    "<|clucking|>",
    "<|yawning|>",
    "<|cough|>",
    "<|smack|>",
    "<|hem|>",
    "<|stretch|>",
    "<|sneeze|>",
    "<|paralinguistic_extra1|>",
    "<|paralinguistic_extra2|>",
    "<|paralinguistic_extra3|>",
    "<|paralinguistic_extra4|>",
    "<|paralinguistic_extra5|>",
    "<|paralinguistic_extra6|>",
    "<|paralinguistic_extra7|>",
    "<|paralinguistic_extra8|>",
    "<|paralinguistic_extra10|>",
    "<|paralinguistic_extra11|>",
    "<|paralinguistic_extra12|>",
    "<|paralinguistic_extra13|>",
]


def load_custom_tokenizer(qwen2_tokenizer_path: str):
    """Load the Qwen2 tokenizer and register the additional special tokens above."""
    tok = AutoTokenizer.from_pretrained(qwen2_tokenizer_path)
    special_tokens_dict = {
        "additional_special_tokens": additional_special_tokens,
    }
    tok.add_special_tokens(special_tokens_dict)
    return tok
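
# Example (illustrative only): the tokenizer path and prompt format below are
# placeholders, not values taken from this repo's configs.
#
#   tokenizer = load_custom_tokenizer("pretrained_models/Qwen2.5-1.5B")
#   ids = tokenizer("[S1]<|text_start|>Hello there.<|text_end|>")["input_ids"]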


def init_weights(model: nn.Module):
    """
    Initialize the weights of the model.
    - Xavier uniform initialization for linear layers
    - Normal initialization for embeddings
    - Xavier uniform initialization for standalone parameters
    """

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.Parameter):
            # Note: nn.Module.apply only visits sub-modules, so this branch is not
            # reached for bare Parameters; audio_head is handled explicitly below.
            nn.init.xavier_uniform_(m.data)

    model.apply(_init_weights)
    # Special handling for audio_head because it's an nn.Parameter, not a sub-module.
    nn.init.xavier_uniform_(model.audio_head)
    return model


def load_llm_model(
    configs,
    checkpoint_path: Optional[Union[str, Path]] = None,
    device: Union[str, torch.device] = "cuda",
) -> Model:
    """Build the LLM, optionally load a checkpoint, and move the model to the device.

    Args:
        configs: Parsed config dict containing an "llm_models" section.
        checkpoint_path: Optional path to a checkpoint; if not given or not found,
            the model is freshly initialized with init_weights.
        device: Device to move the model to.
    """
    model_arg = ModelArgs(
        backbone_flavor=configs["llm_models"]["backbone_flavor"],
        decoder_flavor=configs["llm_models"]["decoder_flavor"],
        text_vocab_size=configs["llm_models"]["text_vocab_size"],
        audio_vocab_size=configs["llm_models"]["audio_vocab_size"],
        audio_num_codebooks=configs["llm_models"]["audio_num_codebooks"],
        decoder_loss_weight=configs["llm_models"]["decoder_loss_weight"],
        use_text_loss=True,
    )
    model = Model(model_arg)
    if checkpoint_path and os.path.exists(checkpoint_path):
        state_dict = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )["model"]
        model.load_state_dict(state_dict)
    else:
        model = init_weights(model)
    model = model.to(device=device)
    return model
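
# Example (illustrative only): the config and checkpoint paths are placeholders;
# the config file is assumed to contain the "llm_models" keys read above.
#
#   with open("config_llm.json") as f:
#       configs = json.load(f)
#   model = load_llm_model(configs, checkpoint_path="ckpt/llm_step_10000.pt", device="cuda")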


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    """Write scalars, histograms, images, and audio clips to a TensorBoard writer."""
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
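
# Example (illustrative only): the log directory and tag names are placeholders.
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter("logs/llm")
#   summarize(
#       writer,
#       global_step=step,
#       scalars={"train/loss": loss.item(), "train/lr": scheduler.get_last_lr()[0]},
#   )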


def get_grad_norm(model):
    """Return the global gradient 2-norm divided by the number of parameters with gradients."""
    total_norm = 0
    num = 0
    for name, p in model.named_parameters():
        if p.grad is None:
            # Parameter received no gradient this step; report it and skip.
            print(name)
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
        num += 1
    total_norm = total_norm ** (1.0 / 2)
    total_norm = total_norm / num
    return total_norm
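
# Example (illustrative only): typically called after loss.backward() and before
# optimizer.step(), e.g. to log gradient magnitude via summarize().
#
#   loss.backward()
#   grad_norm = get_grad_norm(model)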


def read_jsonl(path):
    """Read a JSON Lines file: one JSON object per line, returned as a list of dicts."""
    path = os.path.expanduser(path)
    data_list = []
    with open(path, "r") as f:
        for line in f:
            data_list.append(json.loads(line))
    return data_list
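
# Example (illustrative only): the path and field names below are placeholders;
# each line of the .jsonl file is a standalone JSON object, e.g.
#   {"audio": "clips/0001.flac", "text": "[S1]Hello.", "speaker": "S1"}
#
#   items = read_jsonl("data/train.jsonl")
#   print(len(items), items[0]["text"])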