File size: 3,654 Bytes
2573c20
1034391
 
2573c20
1034391
 
 
 
2573c20
 
 
 
 
1034391
2573c20
 
 
c476e57
1034391
2573c20
1034391
2573c20
 
 
 
 
 
 
 
 
1034391
2573c20
 
 
1034391
2573c20
 
c476e57
2573c20
 
 
 
 
 
 
 
 
c476e57
2573c20
1034391
2573c20
c476e57
1a5ca4d
2573c20
 
 
 
 
 
 
c476e57
2573c20
 
 
 
 
 
 
 
 
1034391
 
c476e57
 
 
 
 
2573c20
1034391
 
2573c20
 
 
 
c476e57
2573c20
 
 
602fbb1
2573c20
 
4aa0f34
2573c20
c476e57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Standard library
import logging
import os
import sys

# Third-party
import gradio as gr
import numpy as np
from groq import Groq
from dia.model import Dia


# Log to stdout so messages are visible in hosted environments (e.g. Spaces).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    stream=sys.stdout
)

# Default reference audio used for voice cloning when the user supplies none.
DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
# Default text to synthesize, shown pre-filled in the UI.
DEFAULT_GEN_TEXT = "Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible."
# Local folder Gradio is allowed to serve files from.
SAMPLES_PATH = os.path.join(os.getcwd(), "samples")
# Transcript matching DEFAULT_REF_PATH; must be kept in sync with that clip.
DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."

# Loaded once at import time; downloads weights on first run.
model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626")

def transcribe(file_path: str) -> str:
    """Transcribe an audio file using Groq's hosted Whisper model.

    Args:
        file_path: Path to the audio file to transcribe.

    Returns:
        The transcribed text. May be empty if transcription produced nothing;
        an empty result is logged as a warning rather than raised.
    """
    client = Groq()
    with open(file_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3-turbo",
            temperature=0,
            response_format="verbose_json",
        )

    # logging.warn is deprecated — use logging.warning; truthiness covers "".
    if not transcription.text:
        logging.warning("Error while transcribing the reference audio.")
    else:
        logging.info(f"Transcribed: {transcription.text}")
    return transcription.text

def infer(
    gen_text: str,
    ref_text: str = DEFAULT_REF_TEXT,
    ref_audio_path: str = DEFAULT_REF_PATH,
) -> tuple[int, np.ndarray]:
    """
    Generates speech using Dia-1.6B given a reference audio and text, and new text to synthesize.

    Args:
        gen_text (str): The new text to synthesize.
        ref_text (str): The text corresponding to the reference audio. If empty,
            or left at the default while a custom reference audio is used, it is
            re-transcribed automatically via Groq Whisper.
        ref_audio_path (str): The file path to the reference audio.

    Returns:
        tuple [int, np.ndarray]: A tuple containing the sample rate (44100) and the generated audio waveform as a numpy array.

    Raises:
        ValueError: If gen_text is empty or None.
    """

    if not gen_text:
        raise ValueError("Please insert the new text to synthesize.")
    # If the user swapped out the default reference audio but left the default
    # reference text untouched, the text is stale — clear it so it gets
    # re-transcribed below instead of conditioning on the wrong transcript.
    if "female_shadowheart4.flac" not in ref_audio_path and ref_text == DEFAULT_REF_TEXT:
        ref_text = ""
    # Truthiness also handles None (Gradio passes None for a cleared textbox,
    # which would crash len(ref_text)).
    if not ref_text:
        ref_text = transcribe(ref_audio_path)

    logging.info(f"Using reference: {ref_audio_path}")
    gr.Info("Starting inference request!")
    gr.Info("Encoding reference...")

    # model.generate may return a single ndarray or a list of ndarray chunks.
    output = model.generate(
        ref_text + gen_text,
        audio_prompt=ref_audio_path,
        use_torch_compile=False,
        verbose=True,
        cfg_scale=4.0,
        temperature=1.8,
        top_p=0.90,
        cfg_filter_top_k=50,
    )

    if isinstance(output, list):
        output = np.concatenate(output, axis=-1)  # Join the audio chunks
    elif not isinstance(output, np.ndarray):
        output = np.array(output, dtype=np.float32)

    return (44100, output)


# Gradio UI: text to synthesize + optional reference transcript + reference
# audio clip in; synthesized speech out.
_input_components = [
    gr.Textbox(label="Text to Generate", value=DEFAULT_GEN_TEXT),
    gr.Textbox(label="Reference Text (Optional)", value=DEFAULT_REF_TEXT),
    gr.Audio(type="filepath", label="Reference Audio", value=DEFAULT_REF_PATH),
]
_output_component = gr.Audio(type="numpy", label="Generated Speech")

demo = gr.Interface(
    fn=infer,
    inputs=_input_components,
    outputs=_output_component,
    title="Dia-1.6B-0626",
    description="Upload a reference audio sample, provide the reference text, and enter new text to synthesize.",
)

if __name__ == "__main__":
    # Cap concurrent requests at 10; allow Gradio to serve the local samples
    # folder and open the browser on launch.
    queued_app = demo.queue(max_size=10)
    queued_app.launch(allowed_paths=[SAMPLES_PATH], mcp_server=False, inbrowser=True)