# MuseControlLite / app.py
import copy
import os
import subprocess
import time
from typing import Dict, List, Optional, Tuple
import spaces
import gradio as gr
import soundfile as sf
import torch
from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
from config_inference import get_config
# Stable Audio Open generates fixed-length clips of 2097152 samples at 44.1 kHz (~47.55 s)
TOTAL_AUDIO_SECONDS = 2097152 / 44100
DEFAULT_CONFIG = get_config()
DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
CHECKPOINT_EXPECTED = [
"./checkpoints/woSDD-all/model_3.safetensors",
"./checkpoints/woSDD-all/model_1.safetensors",
"./checkpoints/woSDD-all/model_2.safetensors",
"./checkpoints/woSDD-all/model.safetensors",
]
os.makedirs(OUTPUT_ROOT, exist_ok=True)
def ensure_checkpoints() -> None:
"""Download checkpoints with gdown if they are missing."""
if all(os.path.exists(path) for path in CHECKPOINT_EXPECTED):
return
os.makedirs("checkpoints", exist_ok=True)
try:
subprocess.run(
["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"],
check=True,
)
except Exception as exc: # pylint: disable=broad-except
# Do not crash the space on startup; inference will surface an error later if checkpoints are missing.
print(f"[warn] Checkpoint download failed: {exc}")
ensure_checkpoints()
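# Note: ensure_checkpoints() only tests for the files' existence, so if the gdown
# download is unavailable the checkpoints can also be placed manually at the
# CHECKPOINT_EXPECTED paths (an assumption about the intended layout, not a
# documented workflow).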
class ModelCache:
"""Lazy loader for heavy pipelines and condition extractors."""
def __init__(self) -> None:
self.cache: Dict[Tuple, Dict] = {}
    def get(self, config: Dict) -> Dict:
        """Return the cached pipeline bundle for `config`, loading it on first use."""
key = (
tuple(sorted(config["condition_type"])),
config["weight_dtype"],
float(config["ap_scale"]),
config["apadapter"],
)
if key in self.cache:
return self.cache[key]
weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
if config["apadapter"]:
condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
payload = {
"pipe": pipe,
"condition_extractors": condition_extractors,
"weight_dtype": weight_dtype,
"mode": "musecontrol",
}
else:
from diffusers import StableAudioPipeline
pipe = StableAudioPipeline.from_pretrained(
"stabilityai/stable-audio-open-1.0",
torch_dtype=weight_dtype,
).to("cuda")
payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
self.cache[key] = payload
return payload
model_cache = ModelCache()
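# Pipelines are cached per (condition set, dtype, ap_scale, adapter on/off);
# changing any of those in the UI triggers a full model reload on the next run,
# while per-run settings such as the scheduler sigmas are reapplied in run_inference.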
def _build_base_config() -> Dict:
    """Return a fresh deep copy of the default config so per-request edits never leak."""
    return copy.deepcopy(DEFAULT_CONFIG)
def _create_run_dir() -> str:
run_dir = os.path.join(OUTPUT_ROOT, f"run_{int(time.time() * 1000)}")
os.makedirs(run_dir, exist_ok=True)
return run_dir
def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
    """Convert the optional UI seed into a torch.Generator (None means random)."""
    if seed is None or seed == "":
        return None
try:
seed_int = int(seed)
except (TypeError, ValueError):
return None
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
return generator.manual_seed(seed_int)
def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
    """Reject condition combinations that the adapters cannot use together."""
    condition_type = condition_type or []
if "melody_stereo" in condition_type and any(
choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
):
raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
return condition_type
@spaces.GPU
def run_inference(
prompt_text: str,
condition_audio: Optional[str],
condition_type: Optional[List[str]],
use_musecontrol: bool,
no_text: bool,
negative_text_prompt: str,
guidance_scale_text: float,
guidance_scale_con: float,
guidance_scale_audio: float,
denoise_step: int,
weight_dtype: str,
ap_scale: float,
sigma_min: float,
sigma_max: float,
audio_mask_start: float,
audio_mask_end: float,
musical_mask_start: float,
musical_mask_end: float,
seed: Optional[float],
):
    """Run one generation with the selected conditions and return (wav_path, status_md)."""
    condition_type = _validate_condition_choices(condition_type)
config = _build_base_config()
config.update(
{
"text": [prompt_text or ""],
"audio_files": [condition_audio or ""],
"apadapter": use_musecontrol,
"no_text": bool(no_text),
"negative_text_prompt": negative_text_prompt or "",
"guidance_scale_text": float(guidance_scale_text),
"guidance_scale_con": float(guidance_scale_con),
"guidance_scale_audio": float(guidance_scale_audio),
"denoise_step": int(denoise_step),
"weight_dtype": weight_dtype,
"ap_scale": float(ap_scale),
"sigma_min": float(sigma_min),
"sigma_max": float(sigma_max),
"audio_mask_start_seconds": float(audio_mask_start or 0),
"audio_mask_end_seconds": float(audio_mask_end or 0),
"musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
"musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
"show_result_and_plt": False,
}
)
config["condition_type"] = condition_type
if config["apadapter"]:
if not condition_type:
raise gr.Error("Select at least one condition type when using MuseControlLite.")
if not condition_audio:
raise gr.Error("Upload an audio file for conditioning.")
if not os.path.exists(condition_audio):
raise gr.Error("Condition audio file not found.")
run_dir = _create_run_dir()
config["output_dir"] = run_dir
generator = _seed_to_generator(seed)
try:
models = model_cache.get(config)
pipe = models["pipe"].to("cuda")
pipe.enable_attention_slicing()
pipe.scheduler.config.sigma_min = config["sigma_min"]
pipe.scheduler.config.sigma_max = config["sigma_max"]
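        # The cached pipeline is shared across requests, so the scheduler sigmas
        # are reapplied on every call rather than baked into the cache key.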
prompt_for_model = "" if config["no_text"] else (prompt_text or "")
with torch.no_grad():
if config["apadapter"]:
final_condition, final_condition_audio = process_musical_conditions(
config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
)
waveform = pipe(
extracted_condition=final_condition,
extracted_condition_audio=final_condition_audio,
prompt=prompt_for_model,
negative_prompt=config["negative_text_prompt"],
num_inference_steps=config["denoise_step"],
guidance_scale_text=config["guidance_scale_text"],
guidance_scale_con=config["guidance_scale_con"],
guidance_scale_audio=config["guidance_scale_audio"],
num_waveforms_per_prompt=1,
audio_end_in_s=TOTAL_AUDIO_SECONDS,
generator=generator,
).audios
output = waveform[0].T.float().cpu().numpy()
sr = pipe.vae.sampling_rate
else:
audio = pipe(
prompt=prompt_for_model,
negative_prompt=config["negative_text_prompt"],
num_inference_steps=config["denoise_step"],
guidance_scale=config["guidance_scale_text"],
num_waveforms_per_prompt=1,
audio_end_in_s=TOTAL_AUDIO_SECONDS,
generator=generator,
).audios
output = audio[0].T.float().cpu().numpy()
sr = pipe.vae.sampling_rate
generated_path = os.path.join(run_dir, "generated.wav")
sf.write(generated_path, output, sr)
status_lines = [
f"Run directory: `{run_dir}`",
f"Mode: {'MuseControlLite' if config['apadapter'] else 'Stable Audio base'}",
f"Condition type: {', '.join(condition_type) if condition_type else 'text only'}",
f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
]
if config["apadapter"]:
status_lines.append(
f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
)
if generator is not None:
status_lines.append(f"Seed: {int(seed)}")
status_md = "\n".join(f"- {line}" for line in status_lines)
return generated_path, status_md
except gr.Error:
raise
except Exception as err: # pylint: disable=broad-except
raise gr.Error(f"Generation failed: {err}") from err
EXAMPLES = [
[
"Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
"melody_condition_audio/49_piano.mp3",
["melody_stereo"],
True,
False,
"",
7.0,
1.5,
1.0,
50,
"fp16",
1.0,
0.3,
500,
0,
0,
0,
0,
42,
],
[
"fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
"melody_condition_audio/610_bass.mp3",
["melody_mono", "dynamics", "rhythm"],
True,
False,
"",
7.0,
1.5,
1.0,
50,
"fp16",
1.0,
0.3,
500,
0,
0,
0,
0,
7,
],
]
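# Illustrative sketch (not executed here): because each EXAMPLES row follows the
# run_inference parameter order, a generation can also be triggered without the UI, e.g.:
#
#   wav_path, details_md = run_inference(*EXAMPLES[0])
#
# On a ZeroGPU Space the @spaces.GPU decorator requests a GPU for that call.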
def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire it to run_inference."""
with gr.Blocks(title="MuseControlLite") as demo:
gr.Markdown(
"""
## MuseControlLite demo
UI for MuseControlLite (47.5s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
"""
)
with gr.Row():
prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)
condition_audio = gr.Audio(
label="Condition audio (required for MuseControlLite)", type="filepath", sources=["upload", "microphone"]
)
condition_type = gr.CheckboxGroup(
CONDITION_CHOICES, label="Condition types", value=DEFAULT_CONFIG.get("condition_type", [])
)
with gr.Accordion("Advanced controls", open=False):
negative_prompt = gr.Textbox(label="Negative prompt", lines=2, value=DEFAULT_CONFIG.get("negative_text_prompt", ""))
with gr.Row():
guidance_scale_text = gr.Slider(
minimum=0.0,
maximum=12.0,
value=DEFAULT_CONFIG["guidance_scale_text"],
step=0.1,
label="Guidance scale (text)",
)
guidance_scale_con = gr.Slider(
minimum=0.0,
maximum=5.0,
value=DEFAULT_CONFIG["guidance_scale_con"],
step=0.1,
label="Guidance scale (conditions)",
)
guidance_scale_audio = gr.Slider(
minimum=0.0,
maximum=5.0,
value=DEFAULT_CONFIG["guidance_scale_audio"],
step=0.1,
label="Guidance scale (audio)",
)
with gr.Row():
denoise_step = gr.Slider(
minimum=10, maximum=100, value=DEFAULT_CONFIG["denoise_step"], step=1, label="Denoising steps"
)
weight_dtype = gr.Radio(["fp16", "fp32"], value=DEFAULT_CONFIG["weight_dtype"], label="Weight dtype")
ap_scale = gr.Slider(
minimum=0.5, maximum=2.0, value=DEFAULT_CONFIG["ap_scale"], step=0.05, label="AP scale"
)
with gr.Row():
sigma_min = gr.Slider(
minimum=0.1, maximum=5.0, value=DEFAULT_CONFIG["sigma_min"], step=0.05, label="Scheduler sigma min"
)
sigma_max = gr.Slider(
minimum=50, maximum=700, value=DEFAULT_CONFIG["sigma_max"], step=1, label="Scheduler sigma max"
)
seed = gr.Number(label="Seed (optional)", precision=0)
with gr.Row():
audio_mask_start = gr.Number(
label="Audio mask start (s)", value=DEFAULT_CONFIG["audio_mask_start_seconds"]
)
audio_mask_end = gr.Number(label="Audio mask end (s)", value=DEFAULT_CONFIG["audio_mask_end_seconds"])
with gr.Row():
musical_mask_start = gr.Number(
label="Musical attribute mask start (s)", value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"]
)
musical_mask_end = gr.Number(
label="Musical attribute mask end (s)", value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"]
)
generate_btn = gr.Button("Generate", variant="primary")
generated_audio = gr.Audio(label="Generated audio", type="filepath")
status = gr.Markdown(label="Run details")
generate_btn.click(
fn=run_inference,
inputs=[
prompt,
condition_audio,
condition_type,
use_musecontrol,
no_text,
negative_prompt,
guidance_scale_text,
guidance_scale_con,
guidance_scale_audio,
denoise_step,
weight_dtype,
ap_scale,
sigma_min,
sigma_max,
audio_mask_start,
audio_mask_end,
musical_mask_start,
musical_mask_end,
seed,
],
outputs=[generated_audio, status],
)
gr.Examples(
examples=EXAMPLES,
inputs=[
prompt,
condition_audio,
condition_type,
use_musecontrol,
no_text,
negative_prompt,
guidance_scale_text,
guidance_scale_con,
guidance_scale_audio,
denoise_step,
weight_dtype,
ap_scale,
sigma_min,
sigma_max,
audio_mask_start,
audio_mask_end,
musical_mask_start,
musical_mask_end,
seed,
],
label="Quick start examples (click to populate the form)",
)
return demo
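# demo.queue() could be chained in before launch() if concurrent requests need to
# be serialized; the plain launch below is assumed to be sufficient for a
# single-GPU Space.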
if __name__ == "__main__":
demo = build_interface()
demo.launch()