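"""Streamlit app: Arabic text-to-speech with open Hugging Face models.

Offers MMS/VITS and SpeechT5 checkpoints via Transformers, a Kokoro pipeline
engine, and local Coqui XTTS checkpoints, with an optional hosted-inference
fallback through the HF Inference API.
"""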
import io
from pathlib import Path
import logging

import streamlit as st
import torch
import numpy as np
import soundfile as sf
from kokoro import KPipeline
from transformers import (
    VitsModel,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
from scipy.io.wavfile import write as wav_write
from huggingface_hub import InferenceClient, snapshot_download, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
# Pre-selected Arabic-focused TTS models on Hugging Face (verified public repos)
ARABIC_TTS_MODELS = {
    "MMS (MSA) — facebook/mms-tts-ara": {
        "repo_id": "facebook/mms-tts-ara",
        "engine": "vits",
        "hosted": False,
        "description": "Official MMS checkpoint for Modern Standard Arabic",
    },
    "VITS (Community) — wasmdashai/vits-ar-sa-A": {
        "repo_id": "wasmdashai/vits-ar-sa-A",
        "engine": "vits",
        "hosted": False,
        "description": "Community-trained VITS voice focused on Arabic",
    },
    "SpeechT5 (CLAra) — MBZUAI/speecht5_tts_clartts_ar": {
        "repo_id": "MBZUAI/speecht5_tts_clartts_ar",
        "engine": "speecht5",
        "hosted": False,
        "description": "MBZUAI SpeechT5 fine-tune for Classical Arabic",
    },
    "Saudi TTS — AhmedEladl/saudi-tts": {
        "repo_id": "AhmedEladl/saudi-tts",
        "engine": "xtts",
        "hosted": False,
        "description": "Coqui XTTS-style Saudi Arabic model (.pth checkpoint). Provide local paths below.",
    },
    "XTTS v2 — coqui/XTTS-v2": {
        "repo_id": "coqui/XTTS-v2",
        "engine": "xtts",
        "hosted": False,
        "description": "Official Coqui XTTS v2. Use local snapshot and speaker WAV; supports synthesize().",
    },
}
LOG_FILE = Path("app.log")
DEFAULT_DOWNLOAD_DIR = Path("models_cache")


def _init_logger() -> logging.Logger:
    """Configure logging once per Streamlit session."""
    if not st.session_state.get("_logger_configured"):
        LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            handlers=[
                logging.FileHandler(LOG_FILE, encoding="utf-8"),
                logging.StreamHandler(),
            ],
        )
        st.session_state["_logger_configured"] = True
    return logging.getLogger("arabic_tts_app")

logger = _init_logger()

st.set_page_config(page_title="Arabic TTS (Hugging Face)", page_icon="🗣️", layout="centered")
st.title("🗣️ Arabic Text-to-Speech — Hugging Face + Streamlit")
st.caption("Generate Arabic speech from text using five open-source Arabic-focused models (MMS, community VITS, SpeechT5, and Coqui XTTS).")
# Sidebar configuration
st.sidebar.header("Model & Settings")
model_label = st.sidebar.selectbox("Choose a TTS model", list(ARABIC_TTS_MODELS.keys()))
model_meta = ARABIC_TTS_MODELS[model_label]
model_id = model_meta["repo_id"]
st.sidebar.markdown(
    f"Selected: `{model_id}`\n\n"
    f"{model_meta['description']}"
)
hf_token = st.sidebar.text_input(
    "Optional: Hugging Face access token",
    type="password",
    help="Provide a token if you hit rate limits or want private usage.",
)
# Model download controls
with st.sidebar.expander("Model assets", expanded=False):
    download_dir = st.text_input(
        "Local download directory",
        value=str(DEFAULT_DOWNLOAD_DIR),
        help="Where downloaded model files will be stored (relative or absolute path).",
    )
    download_now = st.button("⬇️ Download selected model", key="download_model_button")

if download_now:
    status = st.sidebar.empty()
    status.info("Downloading… please wait.")
    try:
        local_path = snapshot_download(
            repo_id=model_id,
            local_dir=download_dir,
            token=hf_token or None,
        )
        status.empty()
        st.sidebar.success(f"Model cached at {local_path}")
        logger.info("Downloaded model %s to %s", model_id, local_path)
    except HfHubHTTPError as hub_err:
        status.empty()
        st.sidebar.error(f"Hugging Face download error: {hub_err}")
        logger.exception("HF download failed for %s", model_id)
    except Exception as dl_err:
        status.empty()
        st.sidebar.error(f"Download failed: {dl_err}")
        logger.exception("Download failed for %s", model_id)

# Remember the last chosen download directory so XTTS path defaults can reuse it
st.session_state["_last_download_dir"] = download_dir
# XTTS-specific settings; defaults are defined up front so every reference
# below is valid even when the XTTS sidebar widgets never render
xtts_config_path = None
xtts_vocab_path = None
xtts_checkpoint_dir = None
xtts_speaker_wav = None
xtts_temperature = 0.75
xtts_language = "ar"
xtts_gpt_cond_len = 3
xtts_use_synthesize = True
if model_meta["engine"] == "xtts":
    with st.sidebar.expander("XTTS local paths", expanded=True):
        base = Path(st.session_state.get("_last_download_dir", DEFAULT_DOWNLOAD_DIR)).expanduser()
        xtts_config_path = st.text_input(
            "config.json path",
            value=str(base / "config.json"),
            help="Absolute or relative path to XTTS config.json",
        )
        xtts_vocab_path = st.text_input(
            "vocab.json path",
            value=str(base / "vocab.json"),
            help="Optional: path to vocab.json (if required by your checkpoint)",
        )
        xtts_checkpoint_dir = st.text_input(
            "Checkpoint directory",
            value=str(base),
            help="Directory containing the model .pth checkpoint",
        )
        xtts_speaker_wav = st.text_input(
            "Speaker WAV path",
            value=str(base / "speaker.wav"),
            help="Path to a short reference WAV for voice cloning",
        )
        xtts_temperature = st.slider("XTTS temperature", 0.1, 1.2, 0.75, 0.05)
    with st.sidebar.expander("XTTS options", expanded=False):
        xtts_language = st.text_input("Language code", value="ar", help="e.g., ar, en, fr…")
        xtts_gpt_cond_len = st.slider("GPT conditioning length", 1, 10, 3, 1)
        xtts_use_synthesize = st.checkbox("Use synthesize() if available", value=True)
if LOG_FILE.exists():
    with open(LOG_FILE, "rb") as log_file:
        st.sidebar.download_button(
            label="Download app logs",
            data=log_file,
            file_name=LOG_FILE.name,
            mime="text/plain",
        )
# Backend selection & device info.
# XTTS counts as a local engine too: it loads via the Coqui TTS library,
# so without it here XTTS repos (all hosted=False) could never run.
supports_local = model_meta["engine"] in {"vits", "speecht5", "kokoro", "xtts"}
hosted_available = model_meta.get("hosted", False)
backend_options = []
if supports_local:
    backend_options.append("Local (Transformers)")
if hosted_available:
    backend_options.append("Hosted (HF Inference)")
if not backend_options:
    backend_options = ["Local (Transformers)"]
backend = st.sidebar.radio("Inference backend", backend_options, index=0)
kokoro_lang = model_meta.get("lang_code", "a")
kokoro_voice = model_meta.get("default_voice", "af_heart")
if model_meta["engine"] == "kokoro":
    kokoro_lang = st.sidebar.text_input(
        "Kokoro language code",
        value=kokoro_lang,
        help="See the Kokoro docs for the supported language codes.",
    )
    kokoro_voice = st.sidebar.text_input(
        "Kokoro voice ID",
        value=kokoro_voice,
        help="Default voice is af_heart. See Kokoro repo for available voices.",
    )
device = "cuda" if torch.cuda.is_available() else "cpu"
st.sidebar.markdown(f"**Device:** `{device}`")

# Fallback sample rate, used when a model config does not report one
sample_rate = st.sidebar.number_input("Sample rate", value=16000, min_value=8000, max_value=48000, step=1000)
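
# ---- Model loading helpers ----
# Each loader raises RuntimeError with an actionable message when a dependency
# or checkpoint is missing, so the UI below can surface it directly.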
def load_local_model(repo_id: str, cache_dir: str):
    """Load a VITS/MMS checkpoint and its tokenizer from the Hub."""
    try:
        model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
        tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
        return model, tokenizer
    except OSError as missing_weights:
        raise RuntimeError(
            f"Model {repo_id} does not ship a supported checkpoint (pytorch_model.bin/model.safetensors)."
            " Download the raw .pth manually and convert it to HF format, or pick another model."
        ) from missing_weights
def load_speecht5_bundle(repo_id: str, cache_dir: str):
    """Load the SpeechT5 processor, model, HiFi-GAN vocoder, and a speaker embedding."""
    try:
        processor = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
        model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=cache_dir)
        speaker_embedding = _load_speecht5_speaker_embedding(cache_dir)
        return processor, model, vocoder, speaker_embedding
    except ImportError as imp_err:
        raise RuntimeError(
            "SpeechT5 needs optional deps (sentencepiece). Run `pip install sentencepiece` then restart the app."
        ) from imp_err


def load_kokoro_pipeline(lang_code: str):
    """Build a Kokoro text-to-speech pipeline for the given language code."""
    return KPipeline(lang_code=lang_code)
def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
    """Load a speaker embedding for SpeechT5 without using dataset scripts.

    If remote assets are unavailable, return a neutral 512-dim embedding.
    """
    # Try a known xvector file if available (no trust_remote_code)
    try:
        xvector_path = hf_hub_download(
            repo_id="Matthijs/cmu-arctic-xvectors",
            filename="validation/000000.xvector.npy",
            repo_type="dataset",
            cache_dir=cache_dir,
        )
        arr = np.load(xvector_path)
        vector = torch.from_numpy(arr)
        if vector.ndim == 1:
            vector = vector.unsqueeze(0)
        return vector
    except Exception as err:
        logger.warning("Speaker xvector file not accessible (%s); using neutral embedding.", err)
    # Fallback: neutral speaker embedding (512 dims expected by SpeechT5)
    return torch.zeros((1, 512), dtype=torch.float32)
def load_xtts_model(config_path: str, checkpoint_dir: str, vocab_path: str | None, device: str):
    """Load a Coqui XTTS model from local config, vocab, and checkpoint files."""
    try:
        from TTS.tts.configs.xtts_config import XttsConfig
        from TTS.tts.models.xtts import Xtts
    except ImportError as e:
        raise RuntimeError(
            "XTTS requires the Coqui TTS library. Install via `pip install TTS` and restart the app."
        ) from e
    cfg_path = Path(config_path)
    voc_path = Path(vocab_path) if vocab_path else None
    ckpt_dir = Path(checkpoint_dir)
    if not cfg_path.exists():
        raise RuntimeError(f"XTTS config.json not found at {cfg_path}")
    if voc_path is not None and not voc_path.exists():
        raise RuntimeError(f"XTTS vocab.json not found at {voc_path}")
    if not ckpt_dir.exists():
        raise RuntimeError(f"XTTS checkpoint directory not found at {ckpt_dir}")
    config = XttsConfig()
    config.load_json(str(cfg_path))
    model = Xtts.init_from_config(config)
    # Pass vocab_path only when one was supplied
    extra = {"vocab_path": str(voc_path)} if voc_path is not None else {}
    model.load_checkpoint(config, checkpoint_dir=str(ckpt_dir), eval=True, **extra)
    if device == "cuda":
        model.cuda()
    model.eval()
    return model
def ensure_valid_tokens(token_batch: dict):
    """Reject inputs whose tokenization is too short to synthesize."""
    seq_len = token_batch["input_ids"].shape[-1]
    if seq_len < 2:
        # User-facing message (Arabic): the text produced no valid tokens for this
        # model; add clear Arabic letters or a longer sentence and retry.
        raise ValueError(
            "النص المدخل لم ينتج أي رموز صالحة لهذا النموذج. أضف حروفًا عربية واضحة أو جملة أطول ثم أعد المحاولة."
        )
# Main input area
st.subheader("Input Arabic Text")
text = st.text_area(
    "Enter Arabic text",
    placeholder="اكتب النص العربي هنا لتحويله إلى كلام",
    height=150,
)

# Generate button
generate = st.button("🔊 Generate Speech")

# Output area
audio_placeholder = st.empty()
status_placeholder = st.empty()
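
# Inference flow: run the selected backend locally first; if local synthesis
# fails and the repo supports hosted inference, fall back to the HF Inference API.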
if generate:
    if not text.strip():
        # User-facing message (Arabic): "Please enter Arabic text first."
        st.warning("من فضلك أدخل نصًا عربيًا أولاً.")
    else:
        status_placeholder.info("Running inference… This may take a few seconds.")
        success = False
        should_run_hosted = backend.startswith("Hosted") and hosted_available
        if backend.startswith("Local") and supports_local:
            cache_dir = Path(download_dir).expanduser()
            cache_dir.mkdir(parents=True, exist_ok=True)
            try:
                if model_meta["engine"] == "vits":
                    model, tokenizer = load_local_model(model_id, str(cache_dir))
                    model.to(device)
                    model.eval()
                    inputs = tokenizer(text, return_tensors="pt")
                    ensure_valid_tokens(inputs)
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    with torch.inference_mode():
                        outputs = model(**inputs)
                    waveform = outputs.waveform.squeeze(0).cpu().numpy()
                    sr = getattr(model.config, "sampling_rate", sample_rate)
                elif model_meta["engine"] == "speecht5":
                    processor, model, vocoder, speaker = load_speecht5_bundle(model_id, str(cache_dir))
                    model.to(device)
                    vocoder.to(device)
                    inputs = processor(text=text, return_tensors="pt")
                    ensure_valid_tokens(inputs)
                    input_ids = inputs["input_ids"].to(device)
                    speaker_embedding = speaker.to(device)
                    with torch.inference_mode():
                        speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
                    waveform = speech.cpu().numpy()
                    sr = getattr(model.config, "sampling_rate", 16000)
                elif model_meta["engine"] == "kokoro":
                    pipeline = load_kokoro_pipeline(kokoro_lang)
                    generator = pipeline(text, voice=kokoro_voice)
                    audio_chunks = []
                    for _, _, audio in generator:
                        if audio is not None:
                            audio_chunks.append(audio)
                    if not audio_chunks:
                        raise RuntimeError("Kokoro pipeline returned no audio. Try a different voice or text.")
                    waveform = np.concatenate(audio_chunks).astype(np.float32)
                    sr = model_meta.get("sample_rate", 24000)
                elif model_meta["engine"] == "xtts":
                    model = load_xtts_model(
                        str(Path(xtts_config_path).expanduser()),
                        str(Path(xtts_checkpoint_dir).expanduser()),
                        str(Path(xtts_vocab_path).expanduser()) if xtts_vocab_path else None,
                        device,
                    )
                    spk_path = Path(xtts_speaker_wav).expanduser()
                    if not spk_path.exists():
                        raise RuntimeError(f"Speaker WAV not found at {spk_path}")
                    try:
                        if xtts_use_synthesize and hasattr(model, "synthesize"):
                            out = model.synthesize(
                                text,
                                model.config,
                                speaker_wav=str(spk_path),
                                gpt_cond_len=int(xtts_gpt_cond_len),
                                language=xtts_language,
                                temperature=float(xtts_temperature),
                            )
                            wav = out.get("wav") if isinstance(out, dict) else out
                            waveform = np.asarray(wav, dtype=np.float32)
                            sr = 24000
                        else:
                            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                                audio_path=[str(spk_path)]
                            )
                            out = model.inference(
                                text,
                                xtts_language,
                                gpt_cond_latent,
                                speaker_embedding,
                                temperature=float(xtts_temperature),
                            )
                            waveform = np.asarray(out["wav"], dtype=np.float32)
                            sr = 24000
                    except Exception as xtts_err:
                        raise RuntimeError(
                            f"XTTS inference failed. Ensure config, vocab, checkpoint (.pth) and speaker WAV are correct. Error: {xtts_err}"
                        ) from xtts_err
                else:
                    raise RuntimeError(f"Engine {model_meta['engine']} not supported locally")

                wav_io = io.BytesIO()
                if model_meta["engine"] == "kokoro":
                    sf.write(wav_io, waveform, int(sr), format="WAV", closefd=False)
                else:
                    wav_write(wav_io, int(sr), waveform)
                wav_io.seek(0)
                audio_placeholder.audio(wav_io, format="audio/wav")
                status_placeholder.success("Done! Press play above to listen.")
                logger.info("Local inference succeeded for %s", model_id)
                success = True
            except ValueError as token_err:
                status_placeholder.error(str(token_err))
                logger.warning("Tokenization failed for %s: %s", model_id, token_err)
                st.stop()
            except Exception as local_err:
                logger.exception("Local inference failed for %s", model_id)
                if hosted_available:
                    should_run_hosted = True
                    status_placeholder.warning(
                        f"Local inference failed ({local_err}). Falling back to the hosted Hugging Face API."
                    )
                else:
                    status_placeholder.error(f"Local inference failed: {local_err}. Check the logs or try another model.")
        if not success and should_run_hosted and hosted_available:
            try:
                client = InferenceClient(model=model_id, token=hf_token or None)
                audio_bytes = client.text_to_speech(text)
                # st.audio ignores sample_rate for raw bytes, so it is not passed here
                audio_placeholder.audio(io.BytesIO(audio_bytes), format="audio/wav")
                status_placeholder.success("Done! Press play above to listen.")
                logger.info("Hosted inference succeeded for %s", model_id)
                success = True
            except HfHubHTTPError as hub_error:
                status_placeholder.error(f"Hugging Face inference error: {hub_error}")
                logger.exception("HF inference failed for %s", model_id)
            except Exception:
                status_placeholder.error("Inference failed. Check app.log for details.")
                logger.exception("Inference failed for %s", model_id)
st.markdown("---")
st.markdown(
    "Notes:\n"
    "- For best performance, run on a GPU (CUDA) so MMS/VITS/SpeechT5 models synthesize faster.\n"
    "- MMS + community VITS checkpoints cover different Arabic dialects; try several to match your accent.\n"
    "- SpeechT5 downloads an additional HiFi-GAN vocoder and speaker embedding on first use.\n"
    "- Kokoro requires the system package `espeak-ng` for phonemization.\n"
    "- Hosted Hugging Face inference is disabled for these repos, so keep local copies handy.\n"
    "- Use the sidebar to download model weights and export app logs if you need support.\n"
)
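
# Suggested environment (a sketch based on the imports above, not a pinned spec):
#   pip install streamlit torch transformers soundfile scipy numpy huggingface_hub kokoro TTS
#   Kokoro additionally needs the system package `espeak-ng` for phonemization.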