# Model inference backends for the app: Transformers CTC, POWSM (ESPnet),
# and HuBERTPhoneme.
import sys

import torch
from transformers import AutoProcessor, AutoModelForCTC
from espnet2.bin.s2t_inference import Speech2Text

from inference_huberphoneme import HuBERTPhoneme, Tokenizer

MODEL_TYPES = ["Transformers CTC", "POWSM", "HuBERTPhoneme"]

DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Set the espeak library path explicitly for macOS.
if sys.platform == "darwin":
    from phonemizer.backend.espeak.wrapper import EspeakWrapper

    _ESPEAK_LIBRARY = "/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib"
    EspeakWrapper.set_library(_ESPEAK_LIBRARY)

def clear_cache():
    # Free cached accelerator memory between inference runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

# ================================== POWSM ==================================
def load_powsm(model_id, language="<eng>", device=DEVICE):
    s2t = Speech2Text.from_pretrained(
        model_id,
        # ESPnet cannot load directly onto MPS; start on CPU and move below.
        device=device.replace("mps", "cpu"),
        lang_sym=language,
        task_sym="<pr>",
    )
    if device == "mps":
        # Move to MPS in float32 (MPS has no float64 support).
        s2t.s2t_model.to(device=device, dtype=torch.float32)
        s2t.beam_search.to(device=device, dtype=torch.float32)
        s2t.dtype = "float32"
        s2t.device = device
    return s2t


def transcribe_powsm(audio, model):
    # Take the text of the top hypothesis, keep what follows the
    # <notimestamps> tag, and strip the "/" phone delimiters.
    pred = model(audio, text_prev="<na>")[0][0]
    return pred.split("<notimestamps>")[1].strip().replace("/", "")
# ===========================================================================
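
# Usage sketch for the POWSM path (the model id below is illustrative, not
# pinned by this module; `audio` is assumed to be 16 kHz mono float32):
#
#   s2t = load_powsm("espnet/powsm")
#   phones = transcribe_powsm(audio, s2t)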
# ============================= Transformers CTC ============================
def load_transformers_ctc(model_id, device=DEVICE):
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    return model, processor


def transcribe_transformers_ctc(audio, model) -> str:
    model, processor = model  # unpack the (model, processor) pair from the loader
    input_values = (
        processor(
            [audio],
            sampling_rate=processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        .input_values.type(torch.float32)
        .to(model.device)
    )
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy CTC decoding: argmax per frame; the processor collapses
    # repeats and blanks during decode.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
# ===========================================================================
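
# Usage sketch for the CTC path; "facebook/wav2vec2-lv-60-espeak-cv-ft" is
# one example of a phoneme-level CTC checkpoint, not one this app pins:
#
#   bundle = load_transformers_ctc("facebook/wav2vec2-lv-60-espeak-cv-ft")
#   phones = transcribe_transformers_ctc(audio, bundle)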
# ============================== HuBERTPhoneme ==============================
def load_hubert_phoneme(model_id, device=DEVICE):
    model = HuBERTPhoneme.from_pretrained(model_id).to(device).eval()
    tokenizer = Tokenizer(with_blank=model.ctc_training)
    return model, tokenizer, device


def transcribe_hubert_phoneme(audio, model) -> str:
    model, tokenizer, device = model  # unpack the loader's (model, tokenizer, device)
    with torch.inference_mode():
        output, _ = model.inference(torch.from_numpy(audio).to(device).unsqueeze(0))
    predictions = output.argmax(dim=-1).squeeze().cpu()
    # Collapse consecutive repeats before decoding to ARPABET symbols.
    arpabet = tokenizer.decode(predictions.unique_consecutive())
    return arpabet
# ===========================================================================
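
# Usage sketch for the HuBERTPhoneme path (the model id is a placeholder;
# `audio` must be a numpy array, since inference calls torch.from_numpy):
#
#   bundle = load_hubert_phoneme("your-org/hubert-phoneme")
#   arpabet = transcribe_hubert_phoneme(audio, bundle)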
def load_model(model_id, model_type, device=DEVICE):
    if model_type == "POWSM":
        return load_powsm(model_id, device=device)
    elif model_type == "Transformers CTC":
        return load_transformers_ctc(model_id, device=device)
    elif model_type == "HuBERTPhoneme":
        return load_hubert_phoneme(model_id, device=device)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def transcribe(audio, model_type, model) -> str:
    if model_type == "POWSM":
        return transcribe_powsm(audio, model)
    elif model_type == "Transformers CTC":
        return transcribe_transformers_ctc(audio, model)
    elif model_type == "HuBERTPhoneme":
        return transcribe_hubert_phoneme(audio, model)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
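

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the app: run one
    # second of silence through an illustrative CTC checkpoint.
    import numpy as np

    audio = np.zeros(16_000, dtype=np.float32)  # 1 s of 16 kHz silence
    bundle = load_model("facebook/wav2vec2-lv-60-espeak-cv-ft", "Transformers CTC")
    print(transcribe(audio, "Transformers CTC", bundle))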