MahmoudElsamadony committed
Commit: a3cf4a0
Parent(s): 60c1f8a

Update with new models

Files changed (1):
  1. app.py (+183, -27)
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import (
 )
 from datasets import load_dataset
 from scipy.io.wavfile import write as wav_write
-from huggingface_hub import InferenceClient, snapshot_download
+from huggingface_hub import InferenceClient, snapshot_download, hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 
 # Pre-selected Arabic-focused TTS models on Hugging Face (verified public repos)
@@ -26,14 +26,8 @@ ARABIC_TTS_MODELS = {
         "hosted": False,
         "description": "Official MMS checkpoint for Modern Standard Arabic",
     },
-    "MMS (Arabela) — facebook/mms-tts-arl": {
-        "repo_id": "facebook/mms-tts-arl",
-        "engine": "vits",
-        "hosted": False,
-        "description": "MMS model for the Arabic (Arabela) locale",
-    },
-    "VITS (Community) — wasmdashai/vits-ar": {
-        "repo_id": "wasmdashai/vits-ar",
+    "VITS (Community) — wasmdashai/vits-ar-sa-A": {
+        "repo_id": "wasmdashai/vits-ar-sa-A",
         "engine": "vits",
         "hosted": False,
         "description": "Community-trained VITS voice focused on Arabic",
@@ -44,14 +38,17 @@ ARABIC_TTS_MODELS = {
         "hosted": False,
         "description": "MBZUAI SpeechT5 fine-tune for Classical Arabic",
     },
-    "Kokoro (Arabic) — hexgrad/Kokoro-82M": {
-        "repo_id": "hexgrad/Kokoro-82M",
-        "engine": "kokoro",
+    "Saudi TTS — AhmedEladl/saudi-tts": {
+        "repo_id": "AhmedEladl/saudi-tts",
+        "engine": "xtts",
         "hosted": False,
-        "description": "Kokoro 82M multilingual voice with Arabic support (requires espeak-ng)",
-        "lang_code": "a",
-        "default_voice": "af_heart",
-        "sample_rate": 24000,
+        "description": "Coqui XTTS-style Saudi Arabic model (.pth checkpoint). Provide local paths below.",
+    },
+    "XTTS v2 — coqui/XTTS-v2": {
+        "repo_id": "coqui/XTTS-v2",
+        "engine": "xtts",
+        "hosted": False,
+        "description": "Official Coqui XTTS v2. Use local snapshot and speaker WAV; supports synthesize().",
     },
 }
 
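Note: both "xtts" entries above expect the config, vocabulary, checkpoint (.pth), and a reference speaker WAV to already exist on local disk. A minimal sketch of how those assets could be fetched with snapshot_download, which app.py already imports (the destination directory is an arbitrary example, and the file names assume the standard coqui/XTTS-v2 repo layout):

    from pathlib import Path
    from huggingface_hub import snapshot_download

    # Example destination; any writable directory works.
    local_dir = snapshot_download(repo_id="coqui/XTTS-v2", local_dir="xtts-assets")

    # Paths matching the sidebar inputs added later in this diff.
    config_path = Path(local_dir) / "config.json"
    vocab_path = Path(local_dir) / "vocab.json"
    checkpoint_dir = Path(local_dir)  # contains model.pth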
@@ -124,6 +121,47 @@ with st.sidebar.expander("Model assets", expanded=False):
         st.sidebar.error(f"Download failed: {dl_err}")
         logger.exception("Download failed for %s", model_id)
 
+# Remember last chosen download dir for defaults
+try:
+    st.session_state["_last_download_dir"] = download_dir
+except Exception:
+    pass
+
+# XTTS-specific path inputs now that download_dir is defined
+xtts_config_path = None
+xtts_vocab_path = None
+xtts_checkpoint_dir = None
+xtts_speaker_wav = None
+xtts_temperature = 0.75
+if model_meta["engine"] == "xtts":
+    with st.sidebar.expander("XTTS local paths", expanded=True):
+        base = Path(st.session_state.get("_last_download_dir", DEFAULT_DOWNLOAD_DIR)).expanduser()
+        xtts_config_path = st.text_input(
+            "config.json path",
+            value=str(base / "config.json"),
+            help="Absolute or relative path to XTTS config.json",
+        )
+        xtts_vocab_path = st.text_input(
+            "vocab.json path",
+            value=str(base / "vocab.json"),
+            help="Optional: path to vocab.json (if required by your checkpoint)",
+        )
+        xtts_checkpoint_dir = st.text_input(
+            "Checkpoint directory",
+            value=str(base),
+            help="Directory containing the model .pth checkpoint",
+        )
+        xtts_speaker_wav = st.text_input(
+            "Speaker WAV path",
+            value=str(base / "speaker.wav"),
+            help="Path to a short reference WAV for voice cloning",
+        )
+        xtts_temperature = st.slider("XTTS temperature", 0.1, 1.2, 0.75, 0.05)
+    with st.sidebar.expander("XTTS options", expanded=False):
+        xtts_language = st.text_input("Language code", value="ar", help="e.g., ar, en, fr…")
+        xtts_gpt_cond_len = st.slider("GPT conditioning length", 1, 10, 3, 1)
+        xtts_use_synthesize = st.checkbox("Use synthesize() if available", value=True)
+
 if LOG_FILE.exists():
     with open(LOG_FILE, "rb") as log_file:
         st.sidebar.download_button(
@@ -170,19 +208,29 @@ sample_rate = st.sidebar.number_input("Sample rate", value=16000, min_value=8000
 
 @st.cache_resource(show_spinner=False)
 def load_local_model(repo_id: str, cache_dir: str):
-    model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
-    tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
-    return model, tokenizer
+    try:
+        model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
+        tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
+        return model, tokenizer
+    except OSError as missing_weights:
+        raise RuntimeError(
+            f"Model {repo_id} does not ship a supported checkpoint (pytorch_model.bin/model.safetensors)."
+            " Download the raw .pth manually and convert it to HF format, or pick another model."
+        ) from missing_weights
 
 
 @st.cache_resource(show_spinner=False)
 def load_speecht5_bundle(repo_id: str, cache_dir: str):
-    processor = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
-    model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
-    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=cache_dir)
-    embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-    speaker_embedding = torch.tensor(embeddings[0]["xvector"]).unsqueeze(0)
-    return processor, model, vocoder, speaker_embedding
+    try:
+        processor = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
+        model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
+        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=cache_dir)
+        speaker_embedding = _load_speecht5_speaker_embedding(cache_dir)
+        return processor, model, vocoder, speaker_embedding
+    except ImportError as imp_err:
+        raise RuntimeError(
+            "SpeechT5 needs optional deps (sentencepiece). Run `pip install sentencepiece` then restart the app."
+        ) from imp_err
 
 
 @st.cache_resource(show_spinner=False)
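For context on the loader above: a VITS bundle from load_local_model is typically consumed through the standard transformers VitsModel API, roughly as below. This is an illustrative sketch, not code from this commit; facebook/mms-tts-ara is presumably the registry's MMS Arabic checkpoint.

    import torch

    model, tokenizer = load_local_model("facebook/mms-tts-ara", cache_dir="cache")
    inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")  # "Hello, world" in Arabic

    # VitsModel is end-to-end: the forward pass returns the waveform directly.
    with torch.no_grad():
        waveform = model(**inputs).waveform[0].cpu().numpy()
    sr = model.config.sampling_rate  # 16 kHz for MMS checkpoints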
@@ -190,6 +238,74 @@ def load_kokoro_pipeline(lang_code: str):
     return KPipeline(lang_code=lang_code)
 
 
+def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
+    """Load a speaker embedding for SpeechT5 without using dataset scripts.
+
+    If remote assets are unavailable, return a neutral 512-dim embedding.
+    """
+    # Try a known xvector file if available (no trust_remote_code)
+    try:
+        xvector_path = hf_hub_download(
+            repo_id="Matthijs/cmu-arctic-xvectors",
+            filename="validation/000000.xvector.npy",
+            repo_type="dataset",
+            cache_dir=cache_dir,
+        )
+        arr = np.load(xvector_path)
+        vector = torch.from_numpy(arr)
+        if vector.ndim == 1:
+            vector = vector.unsqueeze(0)
+        return vector
+    except Exception as err:
+        logger.warning("Speaker xvector file not accessible (%s); using neutral embedding.", err)
+
+    # Fallback: neutral speaker embedding (512 dims expected by SpeechT5)
+    neutral = torch.zeros((1, 512), dtype=torch.float32)
+    return neutral
+
+
+@st.cache_resource(show_spinner=False)
+def load_xtts_model(config_path: str, checkpoint_dir: str, vocab_path: str | None, device: str):
+    try:
+        from TTS.tts.configs.xtts_config import XttsConfig
+        from TTS.tts.models.xtts import Xtts
+    except ImportError as e:
+        raise RuntimeError(
+            "XTTS requires the Coqui TTS library. Install via `pip install TTS` and restart the app."
+        ) from e
+
+    cfg_path = Path(config_path)
+    voc_path = Path(vocab_path) if vocab_path else None
+    ckpt_dir = Path(checkpoint_dir)
+    if not cfg_path.exists():
+        raise RuntimeError(f"XTTS config.json not found at {cfg_path}")
+    if voc_path is not None and not voc_path.exists():
+        raise RuntimeError(f"XTTS vocab.json not found at {voc_path}")
+    if not ckpt_dir.exists():
+        raise RuntimeError(f"XTTS checkpoint directory not found at {ckpt_dir}")
+
+    config = XttsConfig()
+    config.load_json(str(cfg_path))
+    model = Xtts.init_from_config(config)
+    if voc_path is not None:
+        model.load_checkpoint(
+            config,
+            checkpoint_dir=str(ckpt_dir),
+            eval=True,
+            vocab_path=str(voc_path),
+        )
+    else:
+        model.load_checkpoint(
+            config,
+            checkpoint_dir=str(ckpt_dir),
+            eval=True,
+        )
+    if device == "cuda":
+        model.cuda()
+    model.eval()
+    return model
+
+
 def ensure_valid_tokens(token_batch: dict):
     seq_len = token_batch["input_ids"].shape[-1]
     if seq_len < 2:
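The SpeechT5 bundle pairs with the speaker embedding from the new _load_speecht5_speaker_embedding helper; a hedged usage sketch against the standard transformers API (repo_id stands in for whichever SpeechT5 entry is selected in ARABIC_TTS_MODELS):

    import torch

    # repo_id: placeholder for the selected SpeechT5 model's repo id.
    processor, model, vocoder, speaker_embedding = load_speecht5_bundle(repo_id, cache_dir="cache")
    inputs = processor(text="السلام عليكم", return_tensors="pt")  # "Peace be upon you"

    # generate_speech runs the acoustic model and the HiFi-GAN vocoder in one call.
    with torch.no_grad():
        waveform = model.generate_speech(
            inputs["input_ids"], speaker_embedding, vocoder=vocoder
        ).cpu().numpy()
    sr = 16000  # SpeechT5/HiFi-GAN output rate

Note that the zero-vector fallback still synthesizes, but with a deliberately neutral voice.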
@@ -258,6 +374,44 @@ if generate:
                 raise RuntimeError("Kokoro pipeline returned no audio. Try a different voice or text.")
             waveform = np.concatenate(audio_chunks).astype(np.float32)
             sr = model_meta.get("sample_rate", 24000)
+        elif model_meta["engine"] == "xtts":
+            model = load_xtts_model(
+                str(Path(xtts_config_path).expanduser()),
+                str(Path(xtts_checkpoint_dir).expanduser()),
+                str(Path(xtts_vocab_path).expanduser()),
+                device,
+            )
+            spk_path = Path(xtts_speaker_wav).expanduser()
+            if not spk_path.exists():
+                raise RuntimeError(f"Speaker WAV not found at {spk_path}")
+            try:
+                if 'xtts_use_synthesize' in locals() and xtts_use_synthesize and hasattr(model, 'synthesize'):
+                    out = model.synthesize(
+                        text,
+                        model.config,
+                        speaker_wav=str(spk_path),
+                        gpt_cond_len=int(xtts_gpt_cond_len),
+                        language=xtts_language,
+                        temperature=float(xtts_temperature),
+                    )
+                    wav = out.get("wav") if isinstance(out, dict) else out
+                    waveform = np.asarray(wav, dtype=np.float32)
+                    sr = 24000
+                else:
+                    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[str(spk_path)])
+                    out = model.inference(
+                        text,
+                        xtts_language,
+                        gpt_cond_latent,
+                        speaker_embedding,
+                        temperature=float(xtts_temperature),
+                    )
+                    waveform = np.asarray(out["wav"], dtype=np.float32)
+                    sr = 24000
+            except Exception as xtts_err:
+                raise RuntimeError(
+                    f"XTTS inference failed. Ensure config, vocab, checkpoint (.pth) and speaker WAV are correct. Error: {xtts_err}"
+                ) from xtts_err
         else:
             raise RuntimeError(f"Engine {model_meta['engine']} not supported locally")
 
@@ -279,9 +433,11 @@ if generate:
         logger.exception("Local inference failed for %s", model_id)
         if hosted_available:
             should_run_hosted = True
-            status_placeholder.warning("Local inference فشل، سيتم استخدام واجهة Hugging Face المستضافة تلقائيًا.")
+            status_placeholder.warning(
+                f"Local inference فشل ({local_err}). سيتم استخدام واجهة Hugging Face المستضافة تلقائيًا عند توفرها."
+            )
         else:
-            status_placeholder.error("Local inference failed. راجع السجلات أو جرّب نموذجًا آخر.")
+            status_placeholder.error(f"Local inference failed: {local_err}. راجع السجلات أو جرّب نموذجًا آخر.")
 
     if not success and should_run_hosted and hosted_available:
         try:
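When local synthesis fails and the selected model is hosted, the code after this hunk retries through the Hugging Face Inference API. For reference, a minimal hosted call with the already-imported InferenceClient looks roughly like this; a sketch, not this repo's exact wrapper, and the model id is only an example:

    from huggingface_hub import InferenceClient
    from huggingface_hub.utils import HfHubHTTPError

    client = InferenceClient()  # picks up HF_TOKEN from the environment if set
    try:
        # Returns raw audio bytes (container format depends on the backend).
        audio_bytes = client.text_to_speech("مرحبا", model="facebook/mms-tts-ara")
    except HfHubHTTPError as err:
        print(f"Hosted inference unavailable: {err}")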