import os

import gradio as gr
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from model import DCCRN  # requires model.py and its utils/ dependencies

# ===== Basic config =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = int(os.getenv("SAMPLE_RATE", "16000"))

# Read the model repo and filename from environment variables
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")  # change the default if needed
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
TOKEN = os.getenv("HF_TOKEN")  # only required if the model repo is private
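# Example values (assumption: on Hugging Face Spaces these would be set under
# Settings → Variables and secrets rather than hard-coded):
#   MODEL_REPO_ID  = your-username/your-dccrn-repo
#   MODEL_FILENAME = best.ckpt
#   HF_TOKEN       = add as a secret; only needed for private model repos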

# ===== Download & load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
net = DCCRN()  # if you trained with custom args, instantiate with the same args here
ckpt = torch.load(ckpt_path, map_location=DEVICE)
state = ckpt.get("state_dict", ckpt)
# Strip common Lightning ("model.") / DataParallel ("module.") key prefixes
state = {k.replace("model.", "").replace("module.", ""): v for k, v in state.items()}
report = net.load_state_dict(state, strict=False)
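# strict=False silently skips mismatched keys, which can hide a wrong
# checkpoint/architecture pairing; printing the report surfaces mismatches
# in the Space logs (optional sanity check):
if report.missing_keys or report.unexpected_keys:
    print("[DCCRN] missing keys:", report.missing_keys)
    print("[DCCRN] unexpected keys:", report.unexpected_keys)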
net.to(DEVICE).eval()

# ===== Inference =====
def enhance(audio_path: str):
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    x = torch.from_numpy(wav).float().to(DEVICE)
    if x.ndim == 1:
        x = x.unsqueeze(0)  # [1, T]
    with torch.no_grad():
        # Many DCCRN variants expect [B, 1, T]; try that first, fall back to [B, T]
        try:
            y = net(x.unsqueeze(1))  # [1, 1, T]
        except Exception:
            y = net(x)  # [1, T]
    y = y.squeeze().detach().cpu().numpy()
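    # Guard against playback clipping (assumption: the model outputs a roughly
    # [-1, 1] waveform; clamp any overshoot before handing it to Gradio):
    y = np.clip(y, -1.0, 1.0)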
    return (SR, y)
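
# Full-clip inference on long recordings can exhaust GPU memory. A minimal
# chunked variant (assumption: hard cuts at chunk boundaries are acceptable
# for a demo; overlap-add stitching would be cleaner but is omitted here):
def enhance_chunked(audio_path: str, chunk_sec: int = 30):
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    step = chunk_sec * SR
    pieces = []
    with torch.no_grad():
        for start in range(0, len(wav), step):
            x = torch.from_numpy(wav[start:start + step]).float().to(DEVICE).unsqueeze(0)
            try:
                y = net(x.unsqueeze(1))  # [1, 1, T_chunk]
            except Exception:
                y = net(x)  # [1, T_chunk]
            pieces.append(y.squeeze().detach().cpu().numpy())
    return (SR, np.clip(np.concatenate(pieces), -1.0, 1.0))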

# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🎧 DCCRN Speech Enhancement (Demo)

        **How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.

        **Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
        """
    )
    with gr.Row():
        inp = gr.Audio(
            sources=["upload", "microphone"],  # drag & drop supported by default
            type="filepath",
            label="Input: noisy speech (drag & drop or upload / record)",
        )
        out = gr.Audio(
            label="Output: enhanced speech (downloadable)",
            show_download_button=True,
        )
    enhance_btn = gr.Button("Enhance")

    # On-page sample clips (make sure these files exist in the repo)
    gr.Examples(
        examples=[
            ["examples/noisy_1.wav"],
            ["examples/noisy_2.wav"],
            ["examples/noisy_3.wav"],
        ],
        inputs=inp,
        label="Sample audio",
        examples_per_page=3,
    )

    # Gradio ≥ 4.44: set concurrency on the event listener
    enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)

# Keep a small queue to avoid OOM under concurrent requests
demo.queue(max_size=16)
demo.launch()
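
# Assumed requirements.txt for this Space (illustrative; versions are not
# pinned from the original repo):
#   gradio>=4.44
#   torch
#   librosa
#   huggingface_hub
#   numpy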