# fireredtts2/fireredtts2.py
import os
import json
import time
from typing import List, Tuple

import torch
import torchaudio
from tqdm import tqdm

from fireredtts2.codec import RedCodecInfer
from fireredtts2.llm import load_llm_model, load_custom_tokenizer
from fireredtts2.llm.utils import Segment
from fireredtts2.utils.spliter import clean_text, split_text, process_text_list


class FireRedTTS2:
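    """FireRedTTS2 inference pipeline.

    An autoregressive token LM predicts 17-slot frames (16 audio codebook
    tokens plus one text token); RedCodecInfer encodes prompt audio into
    tokens and decodes generated tokens back to waveform. Text is tokenized
    with the Qwen2.5 tokenizer.
    """
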
def __init__(self, pretrained_dir, gen_type, device):
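        """Load the token LM, text tokenizer, and audio codec.

        Args:
            pretrained_dir: Directory containing config_llm.json, the LLM
                checkpoint (llm_pretrain.pt or llm_posttrain.pt),
                config_codec.json, codec.pt, and the Qwen2.5-1.5B tokenizer.
            gen_type: "monologue" or "dialogue"; selects the LLM checkpoint.
            device: torch device to load the models onto, e.g. "cuda".
        """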
assert os.path.exists(pretrained_dir)
assert gen_type in ["monologue", "dialogue"]
llm_config_path = os.path.join(pretrained_dir, "config_llm.json")
if gen_type == "monologue":
llm_ckpt_path = os.path.join(pretrained_dir, "llm_pretrain.pt")
# llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
else:
llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
codec_config_path = os.path.join(pretrained_dir, "config_codec.json")
codec_ckpt_path = os.path.join(pretrained_dir, "codec.pt")
pretrained_qwen_path = os.path.join(pretrained_dir, "Qwen2.5-1.5B")
# check
assert os.path.exists(llm_config_path)
assert os.path.exists(llm_ckpt_path)
assert os.path.exists(codec_config_path)
assert os.path.exists(codec_ckpt_path)
assert os.path.exists(pretrained_qwen_path)
# ==== Load Torch LLM ====
llm_config = json.load(open(llm_config_path))
self._model = load_llm_model(
configs=llm_config, checkpoint_path=llm_ckpt_path, device=device
)
self._model.eval()
self._model.setup_caches(1)
print("[INFO] LLM Loaded...")
# ==== Load Qwen2.5 Text Tokenizer ====
self._text_tokenizer = load_custom_tokenizer(pretrained_qwen_path)
print("[INFO] Text Tokenizer Loaded...")
# ==== Load Torch Audio Tokenizer ====
torch_codec = RedCodecInfer.from_pretrained(codec_config_path, codec_ckpt_path)
torch_codec.eval()
self._audio_tokenizer = torch_codec.to(device)
print("[INFO] Codec Loaded...")
        self.sample_rate = 16000  # rate expected by the codec encoder; decoded audio is 24 kHz
        self.device = device
        self.max_seq_len = 3100

def load_prompt_audio(self, audio_path) -> torch.Tensor:
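        """Load a prompt wav, keep only the first channel, and resample to 16 kHz."""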
audio, audio_sr = torchaudio.load(audio_path)
# Audio must be single channel
if audio.shape[0] > 1:
audio = audio[0, :].unsqueeze(0)
audio16k = torchaudio.functional.resample(audio, audio_sr, 16000)
return audio16k
def prepare_prompt(self, text, speaker, audio_path) -> Segment:
audio_tensor = self.load_prompt_audio(audio_path)
return Segment(text=text, speaker=speaker, audio=audio_tensor)
def _tokenize_text_segment(
self, text: str, speaker: str
) -> Tuple[torch.Tensor, torch.Tensor]:
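        """Pack text tokens into (seq_len, 17) frames plus a boolean mask.

        Each frame has 17 slots: the first 16 carry audio codebook tokens
        (unused here) and the last carries a text token, so text frames fill
        only column -1 and the mask marks only that slot as valid.
        """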
frame_tokens = []
frame_masks = []
text = speaker + "<|text_start|>" + text + "<|text_end|>"
text_tokens = self._text_tokenizer.encode(text)
text_frame = torch.zeros(len(text_tokens), 17).long()
text_frame_mask = torch.zeros(len(text_tokens), 17).bool()
text_frame[:, -1] = torch.tensor(text_tokens)
text_frame_mask[:, -1] = True
frame_tokens.append(text_frame.to(self.device))
frame_masks.append(text_frame_mask.to(self.device))
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
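        """Encode a (1, T) waveform into (num_frames + 1, 17) audio frames and
        a mask, appending one all-zero EOS frame."""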
frame_tokens = []
frame_masks = []
        # The encoder returns tokens of shape (1, K, T'), squeezed to (K, T') below
        audio_length = torch.tensor([audio.shape[1]], dtype=torch.long)
audio_tokens, token_length = self._audio_tokenizer.encode(
audio.to(self.device),
audio_length.to(self.device),
batch_size=48,
)
audio_tokens = audio_tokens.squeeze(0)
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
audio_frame = torch.zeros(audio_tokens.size(1), 17).long().to(self.device)
audio_frame_mask = torch.zeros(audio_tokens.size(1), 17).bool().to(self.device)
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :-1] = True
frame_tokens.append(audio_frame)
frame_masks.append(audio_frame_mask)
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
(seq_len,17), (seq_len, 17)
"""
text_tokens, text_masks = self._tokenize_text_segment(
segment.text, segment.speaker
)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat(
[text_masks, audio_masks], dim=0
)
@torch.inference_mode()
def generate(
self,
text: str,
speaker: str,
context: List[Segment],
max_audio_length_ms: float = 90_000,
temperature: float = 0.9,
topk: int = 20,
) -> torch.Tensor:
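        """Generate one utterance for `speaker`, conditioned on `context`.

        Each frame covers 80 ms of audio, so max_audio_length_ms / 80 bounds
        the number of autoregressive steps. Returns the decoded 1-D waveform
        (the codec decodes at 24 kHz, per the 24000 -> 16000 resampling in
        generate_dialogue).
        """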
self._model.reset_caches()
max_generation_len = int(max_audio_length_ms / 80)
tokens, tokens_mask = [], []
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
text, speaker
)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
samples = []
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = (
torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
)
        max_seq_len = self.max_seq_len
max_context_len = max_seq_len - max_generation_len
if curr_tokens.size(1) >= max_context_len:
raise ValueError(
f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
)
for _ in range(max_generation_len):
sample = self._model.generate_frame(
curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
)
# eos
if torch.all(sample == 0):
break
samples.append(sample)
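            # Feed the sampled frame back as the next input: the 16 audio
            # tokens plus a zeroed text slot whose mask is set to False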
curr_tokens = torch.cat(
[sample, torch.zeros(1, 1).long().to(self.device)], dim=1
).unsqueeze(1)
curr_tokens_mask = torch.cat(
[
torch.ones_like(sample).bool(),
torch.zeros(1, 1).bool().to(self.device),
],
dim=1,
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
audio = (
self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0))
.squeeze(0)
.squeeze(0)
)
return audio
def generate_single(
self, context: List[Segment], temperature: float = 0.9, topk: int = 20
):
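        """Continue the audio of the last context segment and return raw codec
        tokens of shape (1, num_codebooks, num_frames) rather than a waveform.
        Also prints first-packet latency after the first two frames.
        """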
self._model.reset_caches()
        max_generation_len = 400  # 400 frames * 80 ms = 32 s cap
tokens, tokens_mask = [], []
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
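        # Drop the final 3 frames (the EOS frame and the last two audio
        # frames) so generation continues directly from the prompt audio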
prompt_tokens = prompt_tokens[:-3, :]
prompt_tokens_mask = prompt_tokens_mask[:-3, :]
samples = []
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = (
torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
)
num_token = 0
start_time = time.time()
for _ in range(max_generation_len):
sample = self._model.generate_frame(
curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
)
# eos
if torch.all(sample == 0):
break
samples.append(sample)
curr_tokens = torch.cat(
[sample, torch.zeros(1, 1).long().to(self.device)], dim=1
).unsqueeze(1)
curr_tokens_mask = torch.cat(
[
torch.ones_like(sample).bool(),
torch.zeros(1, 1).bool().to(self.device),
],
dim=1,
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
            num_token += 1
            if num_token == 2:
                # First-packet latency: time to produce the first two frames
                end_time = time.time()
                duration = end_time - start_time
                print("[INFO] first pack duration:", duration)
gen_tokens = torch.stack(samples).permute(1, 2, 0)
return gen_tokens
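
    # Streaming variant, kept commented out: decodes each sampled frame
    # incrementally with decode_one_token and yields audio chunks as they
    # are produced.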
# @torch.inference_mode()
# def generate_stream(
# self,
# text: str,
# speaker: str,
# context: List[Segment],
# max_audio_length_ms: float = 90_000,
# temperature: float = 0.9,
# topk: int = 50,
# ):
# self._model.reset_caches()
# max_generation_len = int(max_audio_length_ms / 80)
# tokens, tokens_mask = [], []
# for segment in context:
# segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
# tokens.append(segment_tokens)
# tokens_mask.append(segment_tokens_mask)
# gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
# text, speaker
# )
# tokens.append(gen_segment_tokens)
# tokens_mask.append(gen_segment_tokens_mask)
# prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
# prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
# samples = []
# curr_tokens = prompt_tokens.unsqueeze(0)
# curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
# curr_pos = (
# torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
# )
# max_seq_len = 3100
# max_context_len = max_seq_len - max_generation_len
# if curr_tokens.size(1) >= max_context_len:
# raise ValueError(
# f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
# )
# # codec cache
# codec_cache = {}
# prev_sample = None
# for _ in range(max_generation_len):
# sample = self._model.generate_frame(
# curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
# )
# # eos
# if torch.all(sample == 0):
# break
# # decode one token
# if prev_sample is None:
# prev_sample = sample # sample: (b, nq)
# else:
# audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
# prev_sample.unsqueeze(-1),
# codec_cache,
# last_token=False,
# )
# yield audio_chunk.squeeze(0)
# prev_sample = sample
# samples.append(sample) # sample: (b, nq)
# curr_tokens = torch.cat(
# [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
# ).unsqueeze(1)
# curr_tokens_mask = torch.cat(
# [
# torch.ones_like(sample).bool(),
# torch.zeros(1, 1).bool().to(self.device),
# ],
# dim=1,
# ).unsqueeze(1)
# curr_pos = curr_pos[:, -1:] + 1
# audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
# prev_sample.unsqueeze(-1),
# codec_cache,
# last_token=True,
# )
# yield audio_chunk.squeeze(0)
@torch.inference_mode()
def generate_dialogue(
self,
text_list,
prompt_wav_list=None,
prompt_text_list=None,
temperature=0.9,
topk=20,
):
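        """Synthesize a multi-speaker dialogue.

        Every entry in text_list (and prompt_text_list) must start with a
        speaker tag [S1]..[S4]. Prompt segments and previously generated turns
        are passed back as context so each voice stays consistent across
        turns. Returns the concatenated waveform as a (1, T) CPU tensor.
        """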
all_generated_segments = []
all_storage_segments = []
prompt_segments = []
text_list = process_text_list(text_list=text_list)
if prompt_wav_list is not None:
assert len(prompt_wav_list) == len(prompt_text_list)
# Prepare prompts
for i in range(len(prompt_wav_list)):
prompt_wav = prompt_wav_list[i]
prompt_text = prompt_text_list[i]
speaker = prompt_text[:4]
assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
prompt_segments.append(
self.prepare_prompt(
text=prompt_text, speaker=speaker, audio_path=prompt_wav
)
)
for text in tqdm(text_list):
speaker = text[:4]
text = text[4:]
# print("---speaker:", speaker)
# print("---text:", text)
assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
audio_tensor = self.generate(
text=text,
speaker=speaker,
context=prompt_segments + all_generated_segments,
max_audio_length_ms=30_000,
temperature=temperature,
topk=topk,
)
            # For context management, resample the generated audio back to
            # 16 kHz (the codec encoder's input rate)
audio_16k = torchaudio.functional.resample(
audio_tensor.unsqueeze(0), 24000, 16000
)
all_generated_segments.append(
Segment(text=text, speaker=speaker, audio=audio_16k)
)
all_storage_segments.append(
Segment(text=text, speaker=speaker, audio=audio_tensor.unsqueeze(0))
)
# Concatenate all generations
all_audio = torch.cat([seg.audio for seg in all_storage_segments], dim=1)
all_audio = all_audio.cpu()
return all_audio
@torch.inference_mode()
def generate_monologue(
self, text, prompt_wav=None, prompt_text=None, temperature=0.75, topk=20
):
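        """Synthesize a monologue as speaker [S1].

        With a prompt wav and transcript, the text is split into chunks and
        each chunk is generated as a continuation of the prompt (voice
        cloning); chunks under ~1 s are regenerated. Without a prompt, a
        random speaker voice is used. Returns a (1, T) waveform.
        """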
# step1. construct context
if prompt_wav is not None:
assert os.path.exists(prompt_wav)
assert prompt_text is not None
all_generated_segments = []
all_storage_segments = []
prompt_segments = []
prompt_text = clean_text(text=prompt_text)
text = clean_text(text=text)
text_list = split_text(text=text, length=400)
audio_list = []
for text in text_list:
text = clean_text(text=text)
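                # Splice the prompt transcript (minus its trailing
                # punctuation) onto the chunk with a comma, so the chunk
                # reads as a continuation of the prompt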
input_text = prompt_text[:-1] + "," + text
prompt_a = self.prepare_prompt(
text=input_text, speaker="[S1]", audio_path=prompt_wav
)
context = [prompt_a]
while True:
gen_tokens = self.generate_single(
context=context, temperature=temperature, topk=topk
)
if gen_tokens.shape[2] > 18:
break
                    # else:
                    #     print("Generated result is under 1 s; rerunning")
gen_tokens = gen_tokens[:, :, 2:] # cut leading silence
audio = self._audio_tokenizer.decode(gen_tokens).squeeze(0).squeeze(0)
audio_list.append(audio.unsqueeze(0))
all_audio = torch.cat(tensors=audio_list, dim=1)
return all_audio
else:
# random speaker
text = clean_text(text=text.strip())
audio_tensor = self.generate(
text=text,
speaker="[S1]",
context=[],
max_audio_length_ms=30_000,
temperature=temperature,
topk=topk,
)
return audio_tensor.unsqueeze(0)
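

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint directory and output path are
    # placeholders; substitute your own.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = FireRedTTS2(
        pretrained_dir="./pretrained_models/FireRedTTS2",  # placeholder path
        gen_type="dialogue",
        device=device,
    )
    audio = tts.generate_dialogue(
        text_list=[
            "[S1]Hello! Did you catch the game last night?",
            "[S2]I did, what a finish in the final minute.",
        ]
    )
    # generate_dialogue returns a (1, T) CPU tensor; decoded audio is 24 kHz
    torchaudio.save("dialogue_demo.wav", audio, 24000)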