# dccrn-demo / app.py
# Last commit 265d119 ("Update app.py") by chenxie95 (verified)
import os
import numpy as np
import torch
import gradio as gr
import librosa
from huggingface_hub import hf_hub_download
from model import DCCRN # requires model.py and utils/ dependencies
# ===== Basic config =====
# Inference device: CUDA when torch reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Working sample rate in Hz (env: SAMPLE_RATE). Input clips are resampled to
# this rate and the enhanced waveform is returned tagged with it.
SR = int(os.environ.get("SAMPLE_RATE", "16000"))

# Hugging Face Hub location of the checkpoint; each value can be overridden via
# the environment variable named in the lookup.
REPO_ID = os.environ.get("MODEL_REPO_ID", "Ada312/DCCRN")  # change default if needed
FILENAME = os.environ.get("MODEL_FILENAME", "dccrn.ckpt")
# Hub access token; leave HF_TOKEN unset for a public model repo (yields None).
TOKEN = os.environ.get("HF_TOKEN")
# ===== Download & load weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
net = DCCRN()  # if you trained with custom args, instantiate with the same args here
# SECURITY NOTE: torch.load unpickles the file, and unpickling can execute
# arbitrary code. Only point MODEL_REPO_ID / MODEL_FILENAME at repos you trust.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Use the nested "state_dict" entry when present (Lightning-style checkpoint);
# otherwise treat the loaded object itself as the state dict.
state = ckpt.get("state_dict", ckpt)

# Leading key prefixes typically added by training wrappers (an attribute named
# `model`, or DataParallel/DDP's `module.`) that the bare DCCRN does not have.
_WRAPPER_PREFIXES = ("model.", "module.")


def _strip_wrapper_prefixes(key: str) -> str:
    """Return ``key`` with any leading wrapper prefixes removed.

    Only *leading* prefixes are stripped, so a parameter name that merely
    contains "model." or "module." further in is left intact (``str.replace``
    would mangle it). Stacked prefixes such as "module.model.x" are all removed.
    """
    changed = True
    while changed:
        changed = False
        for prefix in _WRAPPER_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}
# strict=False keeps loading best-effort, but it would also silently accept a
# checkpoint that does not match the model at all -- report any mismatch so an
# untrained network is not served unnoticed.
incompatible = net.load_state_dict(state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        "[load] checkpoint/model key mismatch: "
        f"{len(incompatible.missing_keys)} missing {incompatible.missing_keys[:5]}, "
        f"{len(incompatible.unexpected_keys)} unexpected {incompatible.unexpected_keys[:5]}"
    )
net.to(DEVICE).eval()
# ===== Inference =====
def enhance(audio_path: str):
    """Denoise one audio clip with the loaded DCCRN model.

    Args:
        audio_path: Path to the input clip, as delivered by the
            ``gr.Audio(type="filepath")`` input. Gradio passes ``None`` when
            Enhance is clicked without any audio.

    Returns:
        ``(SR, waveform)`` -- the ``(sample_rate, numpy_array)`` tuple that a
        numpy-typed ``gr.Audio`` output accepts.

    Raises:
        gr.Error: if no audio was provided or the decoded clip is empty, so
            the user sees a readable message instead of a stack trace.
    """
    if audio_path is None:
        raise gr.Error("Please upload or record an audio clip first.")
    # Resample to SR and downmix to mono.
    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        raise gr.Error("The provided audio clip is empty.")
    x = torch.from_numpy(wav).float().to(DEVICE)
    if x.ndim == 1:
        x = x.unsqueeze(0)  # [1, T]
    with torch.no_grad():
        # Many DCCRNs expect [B, 1, T]; try that first, fall back to [B, T].
        # The broad catch is deliberate: the exception type raised for a wrong
        # input rank depends on the model implementation. If the fallback also
        # fails, Python chains both tracebacks, so the first error is not lost.
        try:
            y = net(x.unsqueeze(1))  # [1, 1, T]
        except Exception:
            y = net(x)  # [1, T]
    # Inside no_grad no graph is recorded, so .detach() is unnecessary.
    y = y.squeeze().cpu().numpy()
    return (SR, y)
# ===== Gradio UI =====
# Sample clips offered under the input, as paths relative to the working
# directory (make sure these files exist in the repo).
EXAMPLE_FILES = [
    "examples/noisy_1.wav",
    "examples/noisy_2.wav",
    "examples/noisy_3.wav",
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🎧 DCCRN Speech Enhancement (Demo)
**How to use:** drag & drop a noisy audio clip (or upload / record) → click **Enhance** → listen & download the result.

**Sample audio:** click a sample below to auto-fill the input, then click **Enhance**.
"""
    )
    with gr.Row():
        inp = gr.Audio(
            sources=["upload", "microphone"],  # drag & drop supported by default
            type="filepath",
            label="Input: noisy speech (drag & drop or upload / record)",
        )
        out = gr.Audio(
            label="Output: enhanced speech (downloadable)",
            show_download_button=True,
        )
    enhance_btn = gr.Button("Enhance")
    # Only list sample clips that are actually present: Gradio may fail at
    # start-up when an example path does not exist.
    available_examples = [[path] for path in EXAMPLE_FILES if os.path.isfile(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=inp,
            label="Sample audio",
            examples_per_page=3,
        )
    # Gradio 4.x configures concurrency on the event listener; a limit of 1
    # means at most one enhancement runs at a time, bounding peak memory.
    enhance_btn.click(enhance, inputs=inp, outputs=out, concurrency_limit=1)

# Cap how many requests may wait in the queue behind the running job.
demo.queue(max_size=16)

if __name__ == "__main__":
    demo.launch()