File size: 4,009 Bytes
c2e60bb
b389fb6
 
 
7796889
20b52a3
7796889
20b52a3
b389fb6
c2e60bb
 
 
 
 
b389fb6
c2e60bb
 
b389fb6
c2e60bb
 
b389fb6
c2e60bb
 
b389fb6
 
c2e60bb
 
 
 
f718747
 
b389fb6
 
7796889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2e60bb
 
 
b389fb6
 
7796889
 
c2e60bb
 
 
 
 
 
 
 
 
 
 
 
b389fb6
c2e60bb
 
7796889
 
20b52a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7796889
 
 
 
 
 
 
 
20b52a3
 
7796889
 
 
 
 
 
 
 
 
20b52a3
 
7796889
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# This module handles model inference

import torch
from transformers import AutoProcessor, AutoModelForCTC
from espnet2.bin.s2t_inference import Speech2Text
from inference_huberphoneme import HuBERTPhoneme, Tokenizer

# Model-type labels accepted by load_model() and transcribe().
MODEL_TYPES = ["Transformers CTC", "POWSM", "HuBERTPhoneme"]

# Default torch device string: CUDA when available, then Apple MPS, else CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Set the espeak library path for macOS (Homebrew install location below).
import os
import sys


def _pick_espeak_library(default_path):
    """Return the espeak library path to force on phonemizer, or ``None``.

    ``None`` means "leave phonemizer's own lookup alone". That is the result
    when the user already set ``PHONEMIZER_ESPEAK_LIBRARY`` or when
    ``default_path`` does not exist on this machine, so a stale hard-coded
    path never replaces a working configuration.
    """
    if os.environ.get("PHONEMIZER_ESPEAK_LIBRARY"):
        return None
    if not os.path.isfile(default_path):
        return None
    return default_path


if sys.platform == "darwin":
    from phonemizer.backend.espeak.wrapper import EspeakWrapper

    _ESPEAK_LIBRARY = "/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib"
    _espeak_override = _pick_espeak_library(_ESPEAK_LIBRARY)
    if _espeak_override is not None:
        EspeakWrapper.set_library(_espeak_override)


def clear_cache():
    """Release cached accelerator memory on whichever backends are available.

    On CUDA this calls ``empty_cache`` followed by ``ipc_collect``; on MPS it
    calls ``torch.mps.empty_cache``. Nothing happens on a CPU-only machine.
    """
    if torch.cuda.is_available():
        for release in (torch.cuda.empty_cache, torch.cuda.ipc_collect):
            release()

    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# ================================== POWSM ==================================
def load_powsm(model_id, language="<eng>", device=DEVICE):
    """Create a ``Speech2Text`` instance for a POWSM checkpoint.

    Args:
        model_id: Identifier passed to ``Speech2Text.from_pretrained``.
        language: Language symbol forwarded as ``lang_sym``.
        device: Target device string; defaults to the module-level ``DEVICE``.

    Returns:
        The configured ``Speech2Text`` object (task symbol ``<pr>``).
    """
    # from_pretrained is never given "mps": that substring is swapped for
    # "cpu" here, and the move to MPS is done by hand further down.
    load_kwargs = {
        "device": device.replace("mps", "cpu"),
        "lang_sym": language,
        "task_sym": "<pr>",
    }
    recognizer = Speech2Text.from_pretrained(model_id, **load_kwargs)
    if device != "mps":
        return recognizer

    # Move both the model and the beam search to MPS as float32, then keep
    # the wrapper's own dtype/device bookkeeping in sync.
    for component in (recognizer.s2t_model, recognizer.beam_search):
        component.to(device=device, dtype=torch.float32)
    recognizer.dtype = "float32"
    recognizer.device = device
    return recognizer


def transcribe_powsm(audio, model):
    """Transcribe ``audio`` with a POWSM recognizer and clean up the output.

    Args:
        audio: Input passed unchanged to ``model``.
        model: Callable such as the object returned by ``load_powsm``; it is
            invoked as ``model(audio, text_prev="<na>")`` and the first field
            of its first result is used as the raw prediction text.

    Returns:
        The text following the first ``<notimestamps>`` marker (up to the
        next marker, if any), stripped of surrounding whitespace and with
        every ``/`` removed. If the marker is missing, the whole prediction
        is cleaned and returned instead of raising ``IndexError``.
    """
    pred = model(audio, text_prev="<na>")[0][0]
    pieces = pred.split("<notimestamps>")
    # pieces[0] is whatever precedes the marker (or the entire prediction
    # when the marker never appears).
    body = pieces[1] if len(pieces) > 1 else pieces[0]
    return body.strip().replace("/", "")


# ===========================================================================
# ============================= Transformers CTC ============================
def load_transformers_ctc(model_id, device=DEVICE):
    """Load a CTC model and its processor from ``model_id``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        device: Device the model is moved to; defaults to ``DEVICE``.

    Returns:
        A ``(model, processor)`` tuple, in that order.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    network = AutoModelForCTC.from_pretrained(model_id)
    network = network.to(device)
    return network, processor


def transcribe_transformers_ctc(audio, model) -> str:
    """Greedy-decode ``audio`` with a CTC model.

    Args:
        audio: A single waveform; it is wrapped in a one-element batch.
        model: ``(model, processor)`` tuple as returned by
            ``load_transformers_ctc``.

    Returns:
        The processor's decoding of the per-frame argmax token ids.
    """
    network, processor = model

    # Featurize at the sampling rate the processor's feature extractor reports.
    features = processor(
        [audio],
        sampling_rate=processor.feature_extractor.sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    batch = features.input_values.to(dtype=torch.float32).to(network.device)

    with torch.no_grad():
        frame_logits = network(batch).logits

    best_ids = frame_logits.argmax(dim=-1)
    # Only one item was batched, so decode the first (and only) row.
    return processor.decode(best_ids[0])


# ===========================================================================
# ============================== HuBERTPhoneme ==============================
def load_hubert_phoneme(model_id, device=DEVICE):
    """Load a ``HuBERTPhoneme`` model in eval mode plus a matching tokenizer.

    Args:
        model_id: Identifier passed to ``HuBERTPhoneme.from_pretrained``.
        device: Device the model is moved to; defaults to ``DEVICE``.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the device is carried along
        so the transcription step can place inputs on it.
    """
    network = HuBERTPhoneme.from_pretrained(model_id)
    network = network.to(device).eval()
    # The tokenizer's blank handling follows the model's ctc_training flag.
    tokenizer = Tokenizer(with_blank=network.ctc_training)
    return network, tokenizer, device


def transcribe_hubert_phoneme(audio, model) -> str:
    """Transcribe ``audio`` with a HuBERTPhoneme model.

    Args:
        audio: NumPy array holding the waveform (converted with
            ``torch.from_numpy``).
        model: ``(model, tokenizer, device)`` tuple as returned by
            ``load_hubert_phoneme``.

    Returns:
        The tokenizer's decoding of the argmax predictions after consecutive
        duplicates are collapsed.
    """
    network, tokenizer, device = model
    with torch.inference_mode():
        # Add a leading batch dimension of size one before inference.
        batch = torch.from_numpy(audio).to(device).unsqueeze(0)
        frame_scores, _ = network.inference(batch)
        frame_ids = frame_scores.argmax(dim=-1).squeeze().cpu()
        collapsed = torch.unique_consecutive(frame_ids)
        return tokenizer.decode(collapsed)


# ===========================================================================


def load_model(model_id, type, device=DEVICE):
    """Load the model identified by ``model_id`` using the loader for ``type``.

    Args:
        model_id: Identifier forwarded to the type-specific loader.
        type: One of "POWSM", "Transformers CTC" or "HuBERTPhoneme".
        device: Device forwarded to the loader; defaults to ``DEVICE``.

    Returns:
        Whatever the selected loader returns (its shape differs per type).

    Raises:
        ValueError: If ``type`` is not one of the supported labels.
    """
    if type == "POWSM":
        return load_powsm(model_id, device=device)
    if type == "Transformers CTC":
        return load_transformers_ctc(model_id, device=device)
    if type == "HuBERTPhoneme":
        return load_hubert_phoneme(model_id, device=device)
    raise ValueError(f"Unsupported model type: {type!s}")


def transcribe(audio, type, model) -> str:
    """Transcribe ``audio`` with the routine that matches ``type``.

    Args:
        audio: Input forwarded unchanged to the type-specific routine.
        type: One of "POWSM", "Transformers CTC" or "HuBERTPhoneme".
        model: The object returned by ``load_model`` for the same ``type``.

    Returns:
        The transcription produced by the selected routine.

    Raises:
        ValueError: If ``type`` is not one of the supported labels.
    """
    if type == "POWSM":
        return transcribe_powsm(audio, model)
    if type == "Transformers CTC":
        return transcribe_transformers_ctc(audio, model)
    if type == "HuBERTPhoneme":
        return transcribe_hubert_phoneme(audio, model)
    raise ValueError(f"Unsupported model type: {type!s}")