import os
import time
import json
import torch
import torchaudio
from typing import List, Tuple
from fireredtts2.codec import RedCodecInfer
from fireredtts2.llm import load_llm_model, load_custom_tokenizer
from fireredtts2.llm.utils import Segment
from fireredtts2.utils.spliter import clean_text, split_text, process_text_list
from tqdm import tqdm

class FireRedTTS2:
    """Inference wrapper around the FireRedTTS2 LLM and codec for monologue
    and dialogue speech generation."""

    def __init__(self, pretrained_dir, gen_type, device):
        assert os.path.exists(pretrained_dir)
        assert gen_type in ["monologue", "dialogue"]
        llm_config_path = os.path.join(pretrained_dir, "config_llm.json")
        if gen_type == "monologue":
            llm_ckpt_path = os.path.join(pretrained_dir, "llm_pretrain.pt")
            # llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
        else:
            llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
        codec_config_path = os.path.join(pretrained_dir, "config_codec.json")
        codec_ckpt_path = os.path.join(pretrained_dir, "codec.pt")
        pretrained_qwen_path = os.path.join(pretrained_dir, "Qwen2.5-1.5B")
        # Check that every required file is present
        assert os.path.exists(llm_config_path)
        assert os.path.exists(llm_ckpt_path)
        assert os.path.exists(codec_config_path)
        assert os.path.exists(codec_ckpt_path)
        assert os.path.exists(pretrained_qwen_path)

        # ==== Load Torch LLM ====
        with open(llm_config_path) as f:
            llm_config = json.load(f)
        self._model = load_llm_model(
            configs=llm_config, checkpoint_path=llm_ckpt_path, device=device
        )
        self._model.eval()
        self._model.setup_caches(1)  # KV caches for batch size 1
        print("[INFO] LLM Loaded...")

        # ==== Load Qwen2.5 Text Tokenizer ====
        self._text_tokenizer = load_custom_tokenizer(pretrained_qwen_path)
        print("[INFO] Text Tokenizer Loaded...")

        # ==== Load Torch Audio Tokenizer ====
        torch_codec = RedCodecInfer.from_pretrained(codec_config_path, codec_ckpt_path)
        torch_codec.eval()
        self._audio_tokenizer = torch_codec.to(device)
        print("[INFO] Codec Loaded...")

        self.sample_rate = 16000  # codec input rate; decoded audio is 24 kHz
        self.device = device
        self.max_seq_len = 3100

    def load_prompt_audio(self, audio_path) -> torch.Tensor:
        audio, audio_sr = torchaudio.load(audio_path)
        # Audio must be single channel; keep the first channel if multichannel
        if audio.shape[0] > 1:
            audio = audio[0, :].unsqueeze(0)
        audio16k = torchaudio.functional.resample(audio, audio_sr, 16000)
        return audio16k

    def prepare_prompt(self, text, speaker, audio_path) -> Segment:
        audio_tensor = self.load_prompt_audio(audio_path)
        return Segment(text=text, speaker=speaker, audio=audio_tensor)

    def _tokenize_text_segment(
        self, text: str, speaker: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []
        text = speaker + "<|text_start|>" + text + "<|text_end|>"
        text_tokens = self._text_tokenizer.encode(text)
        # Each frame has 17 slots: the last one carries the text token, the
        # first 16 carry audio codebook tokens (unused for text frames)
        text_frame = torch.zeros(len(text_tokens), 17).long()
        text_frame_mask = torch.zeros(len(text_tokens), 17).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []
        # audio: (1, num_samples) -> audio_tokens: (K, T) codebook tokens
        audio_length = torch.tensor([audio.shape[1]], dtype=torch.long)
        audio_tokens, token_length = self._audio_tokenizer.encode(
            audio.to(self.device),
            audio_length.to(self.device),
            batch_size=48,
        )
        audio_tokens = audio_tokens.squeeze(0)
        # Add an all-zero EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
        # Audio tokens fill the first 16 of the 17 slots; the text slot stays empty
        audio_frame = torch.zeros(audio_tokens.size(1), 17).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 17).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True
        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 17), (seq_len, 17)
        """
        text_tokens, text_masks = self._tokenize_text_segment(
            segment.text, segment.speaker
        )
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat(
            [text_masks, audio_masks], dim=0
        )

    def generate(
        self,
        text: str,
        speaker: str,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 20,
    ) -> torch.Tensor:
        self._model.reset_caches()
        max_generation_len = int(max_audio_length_ms / 80)  # one frame per 80 ms
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)
        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
            text, speaker
        )
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = (
            torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        )

        max_context_len = self.max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError(
                f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
            )

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(
                curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
            )
            # An all-zero frame marks EOS
            if torch.all(sample == 0):
                break
            samples.append(sample)
            # Feed the sampled audio frame back in, with an empty text slot
            curr_tokens = torch.cat(
                [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [
                    torch.ones_like(sample).bool(),
                    torch.zeros(1, 1).bool().to(self.device),
                ],
                dim=1,
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = (
            self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0))
            .squeeze(0)
            .squeeze(0)
        )
        return audio

    def generate_single(
        self, context: List[Segment], temperature: float = 0.9, topk: int = 20
    ):
        self._model.reset_caches()
        max_generation_len = 400
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)
        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
        # Drop the last 3 frames (including the all-zero EOS frame), presumably
        # so generation continues the prompt audio instead of stopping
        prompt_tokens = prompt_tokens[:-3, :]
        prompt_tokens_mask = prompt_tokens_mask[:-3, :]

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = (
            torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        )
        num_token = 0
        start_time = time.time()
        for _ in range(max_generation_len):
            sample = self._model.generate_frame(
                curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
            )
            # An all-zero frame marks EOS
            if torch.all(sample == 0):
                break
            samples.append(sample)
            curr_tokens = torch.cat(
                [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [
                    torch.ones_like(sample).bool(),
                    torch.zeros(1, 1).bool().to(self.device),
                ],
                dim=1,
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1
            num_token += 1
            if num_token == 2:
                end_time = time.time()
                duration = end_time - start_time
                print("---first pack duration:", duration)
        gen_tokens = torch.stack(samples).permute(1, 2, 0)
        return gen_tokens

    # @torch.inference_mode()
    # def generate_stream(
    #     self,
    #     text: str,
    #     speaker: str,
    #     context: List[Segment],
    #     max_audio_length_ms: float = 90_000,
    #     temperature: float = 0.9,
    #     topk: int = 50,
    # ):
    #     self._model.reset_caches()
    #     max_generation_len = int(max_audio_length_ms / 80)
    #     tokens, tokens_mask = [], []
    #     for segment in context:
    #         segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
    #         tokens.append(segment_tokens)
    #         tokens_mask.append(segment_tokens_mask)
    #     gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
    #         text, speaker
    #     )
    #     tokens.append(gen_segment_tokens)
    #     tokens_mask.append(gen_segment_tokens_mask)
    #     prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
    #     prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
    #     samples = []
    #     curr_tokens = prompt_tokens.unsqueeze(0)
    #     curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
    #     curr_pos = (
    #         torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
    #     )
    #     max_seq_len = 3100
    #     max_context_len = max_seq_len - max_generation_len
    #     if curr_tokens.size(1) >= max_context_len:
    #         raise ValueError(
    #             f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
    #         )
    #     # codec cache
    #     codec_cache = {}
    #     prev_sample = None
    #     for _ in range(max_generation_len):
    #         sample = self._model.generate_frame(
    #             curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
    #         )
    #         # eos
    #         if torch.all(sample == 0):
    #             break
    #         # decode one token
    #         if prev_sample is None:
    #             prev_sample = sample  # sample: (b, nq)
    #         else:
    #             audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
    #                 prev_sample.unsqueeze(-1),
    #                 codec_cache,
    #                 last_token=False,
    #             )
    #             yield audio_chunk.squeeze(0)
    #             prev_sample = sample
    #         samples.append(sample)  # sample: (b, nq)
    #         curr_tokens = torch.cat(
    #             [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
    #         ).unsqueeze(1)
    #         curr_tokens_mask = torch.cat(
    #             [
    #                 torch.ones_like(sample).bool(),
    #                 torch.zeros(1, 1).bool().to(self.device),
    #             ],
    #             dim=1,
    #         ).unsqueeze(1)
    #         curr_pos = curr_pos[:, -1:] + 1
    #     audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
    #         prev_sample.unsqueeze(-1),
    #         codec_cache,
    #         last_token=True,
    #     )
    #     yield audio_chunk.squeeze(0)
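
    # A minimal consumption sketch (not in the original) for the streaming
    # path above, assuming generate_stream were re-enabled; chunk sizes
    # depend on the codec's decode_one_token implementation:
    #
    #     tts = FireRedTTS2(pretrained_dir, "dialogue", "cuda")
    #     chunks = [
    #         chunk.cpu()
    #         for chunk in tts.generate_stream(text="Hi there.", speaker="[S1]", context=[])
    #     ]
    #     audio = torch.cat(chunks, dim=-1)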

    def generate_dialogue(
        self,
        text_list,
        prompt_wav_list=None,
        prompt_text_list=None,
        temperature=0.9,
        topk=20,
    ):
        all_generated_segments = []
        all_storage_segments = []
        prompt_segments = []

        text_list = process_text_list(text_list=text_list)
        if prompt_wav_list is not None:
            assert len(prompt_wav_list) == len(prompt_text_list)
            # Prepare prompts
            for i in range(len(prompt_wav_list)):
                prompt_wav = prompt_wav_list[i]
                prompt_text = prompt_text_list[i]
                speaker = prompt_text[:4]
                assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
                prompt_segments.append(
                    self.prepare_prompt(
                        text=prompt_text, speaker=speaker, audio_path=prompt_wav
                    )
                )

        for text in tqdm(text_list):
            speaker = text[:4]
            text = text[4:]
            # print("---speaker:", speaker)
            # print("---text:", text)
            assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
            audio_tensor = self.generate(
                text=text,
                speaker=speaker,
                context=prompt_segments + all_generated_segments,
                max_audio_length_ms=30_000,
                temperature=temperature,
                topk=topk,
            )
            # For context management, resample the generated 24 kHz audio to 16 kHz
            audio_16k = torchaudio.functional.resample(
                audio_tensor.unsqueeze(0), 24000, 16000
            )
            all_generated_segments.append(
                Segment(text=text, speaker=speaker, audio=audio_16k)
            )
            all_storage_segments.append(
                Segment(text=text, speaker=speaker, audio=audio_tensor.unsqueeze(0))
            )

        # Concatenate all generations
        all_audio = torch.cat([seg.audio for seg in all_storage_segments], dim=1)
        all_audio = all_audio.cpu()
        return all_audio
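
    # Expected generate_dialogue input format, per the assertions above; the
    # wav paths and transcripts below are hypothetical:
    #
    #     text_list = ["[S1]Hello, how are you?", "[S2]Fine, thanks."]
    #     prompt_wav_list = ["spk1.wav", "spk2.wav"]
    #     prompt_text_list = ["[S1]Transcript of spk1.wav.", "[S2]Transcript of spk2.wav."]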

    def generate_monologue(
        self, text, prompt_wav=None, prompt_text=None, temperature=0.75, topk=20
    ):
        # step1. construct context
        if prompt_wav is not None:
            assert os.path.exists(prompt_wav)
            assert prompt_text is not None
            prompt_text = clean_text(text=prompt_text)
            text = clean_text(text=text)
            text_list = split_text(text=text, length=400)
            audio_list = []
            for text in text_list:
                text = clean_text(text=text)
                # Splice the prompt transcript (minus its final character,
                # assumed punctuation) and the target text with a comma
                input_text = prompt_text[:-1] + "," + text
                prompt_a = self.prepare_prompt(
                    text=input_text, speaker="[S1]", audio_path=prompt_wav
                )
                context = [prompt_a]
                while True:
                    gen_tokens = self.generate_single(
                        context=context, temperature=temperature, topk=topk
                    )
                    if gen_tokens.shape[2] > 18:
                        break
                    # else: generated result is shorter than 1 s, rerun
                gen_tokens = gen_tokens[:, :, 2:]  # cut leading silence
                audio = self._audio_tokenizer.decode(gen_tokens).squeeze(0).squeeze(0)
                audio_list.append(audio.unsqueeze(0))
            all_audio = torch.cat(tensors=audio_list, dim=1)
            return all_audio
        else:
            # random speaker
            text = clean_text(text=text.strip())
            audio_tensor = self.generate(
                text=text,
                speaker="[S1]",
                context=[],
                max_audio_length_ms=30_000,
                temperature=temperature,
                topk=topk,
            )
            return audio_tensor.unsqueeze(0)
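
# A minimal usage sketch, not part of the original file: the model directory,
# prompt audio, and output path below are placeholder assumptions.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = FireRedTTS2(
        pretrained_dir="pretrained_models/FireRedTTS2",  # hypothetical path
        gen_type="monologue",
        device=device,
    )
    wav = tts.generate_monologue(
        text="Hello from FireRedTTS2.",
        prompt_wav="prompt.wav",  # hypothetical reference recording
        prompt_text="A short transcript of the prompt audio.",
    )
    # Decoded audio is 24 kHz (see the resample in generate_dialogue)
    torchaudio.save("output.wav", wav.cpu(), 24000)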