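"""Gradio demo for MuseControlLite.

Generates fixed-length (~47.5 s) music with Stable Audio Open, optionally steered
by MuseControlLite adapters using melody, dynamics, rhythm, or audio conditions.
"""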
import copy
import os
import subprocess
import time
from typing import Dict, List, Optional, Tuple

import spaces
import gradio as gr
import soundfile as sf
import torch

from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
from config_inference import get_config
# Stable Audio generates fixed-length 47.5 s chunks (2097152 samples / 44100 Hz).
TOTAL_AUDIO_SECONDS = 2097152 / 44100

DEFAULT_CONFIG = get_config()
DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
CHECKPOINT_EXPECTED = [
    "./checkpoints/woSDD-all/model_3.safetensors",
    "./checkpoints/woSDD-all/model_1.safetensors",
    "./checkpoints/woSDD-all/model_2.safetensors",
    "./checkpoints/woSDD-all/model.safetensors",
]

os.makedirs(OUTPUT_ROOT, exist_ok=True)
def ensure_checkpoints() -> None:
    """Download checkpoints with gdown if they are missing."""
    if all(os.path.exists(path) for path in CHECKPOINT_EXPECTED):
        return
    os.makedirs("checkpoints", exist_ok=True)
    try:
        subprocess.run(
            ["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"],
            check=True,
        )
    except Exception as exc:  # pylint: disable=broad-except
        # Do not crash the Space on startup; inference will surface an error
        # later if checkpoints are missing.
        print(f"[warn] Checkpoint download failed: {exc}")


ensure_checkpoints()
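
# Loading a pipeline is expensive, so instances are cached keyed by the settings
# that force a rebuild (condition set, dtype, AP scale, adapter on/off); repeat
# requests with the same settings reuse the already-loaded models.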
class ModelCache:
    """Lazy loader for heavy pipelines and condition extractors."""

    def __init__(self) -> None:
        self.cache: Dict[Tuple, Dict] = {}

    def get(self, config: Dict) -> Dict:
        key = (
            tuple(sorted(config["condition_type"])),
            config["weight_dtype"],
            float(config["ap_scale"]),
            config["apadapter"],
        )
        if key in self.cache:
            return self.cache[key]
        weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
        if config["apadapter"]:
            condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
            pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
            payload = {
                "pipe": pipe,
                "condition_extractors": condition_extractors,
                "weight_dtype": weight_dtype,
                "mode": "musecontrol",
            }
        else:
            from diffusers import StableAudioPipeline

            pipe = StableAudioPipeline.from_pretrained(
                "stabilityai/stable-audio-open-1.0",
                torch_dtype=weight_dtype,
            ).to("cuda")
            payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
        self.cache[key] = payload
        return payload


model_cache = ModelCache()
def _build_base_config() -> Dict:
    return copy.deepcopy(DEFAULT_CONFIG)


def _create_run_dir() -> str:
    run_dir = os.path.join(OUTPUT_ROOT, f"run_{int(time.time() * 1000)}")
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
    if seed is None or seed == "":
        return None
    try:
        seed_int = int(seed)
    except (TypeError, ValueError):
        return None
    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
    return generator.manual_seed(seed_int)


def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
    condition_type = condition_type or []
    if "melody_stereo" in condition_type and any(
        choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
    ):
        raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
    return condition_type
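
# Assumption: this module runs on a Hugging Face ZeroGPU Space (the otherwise
# unused `spaces` import suggests the decorator was intended). `@spaces.GPU`
# attaches a GPU for the duration of each call; drop it on a persistent-GPU runtime.
@spaces.GPU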
def run_inference(
    prompt_text: str,
    condition_audio: Optional[str],
    condition_type: Optional[List[str]],
    use_musecontrol: bool,
    no_text: bool,
    negative_text_prompt: str,
    guidance_scale_text: float,
    guidance_scale_con: float,
    guidance_scale_audio: float,
    denoise_step: int,
    weight_dtype: str,
    ap_scale: float,
    sigma_min: float,
    sigma_max: float,
    audio_mask_start: float,
    audio_mask_end: float,
    musical_mask_start: float,
    musical_mask_end: float,
    seed: Optional[float],
):
    condition_type = _validate_condition_choices(condition_type)
    config = _build_base_config()
    config.update(
        {
            "text": [prompt_text or ""],
            "audio_files": [condition_audio or ""],
            "apadapter": use_musecontrol,
            "no_text": bool(no_text),
            "negative_text_prompt": negative_text_prompt or "",
            "guidance_scale_text": float(guidance_scale_text),
            "guidance_scale_con": float(guidance_scale_con),
            "guidance_scale_audio": float(guidance_scale_audio),
            "denoise_step": int(denoise_step),
            "weight_dtype": weight_dtype,
            "ap_scale": float(ap_scale),
            "sigma_min": float(sigma_min),
            "sigma_max": float(sigma_max),
            "audio_mask_start_seconds": float(audio_mask_start or 0),
            "audio_mask_end_seconds": float(audio_mask_end or 0),
            "musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
            "musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
            "show_result_and_plt": False,
        }
    )
    config["condition_type"] = condition_type
    if config["apadapter"]:
        if not condition_type:
            raise gr.Error("Select at least one condition type when using MuseControlLite.")
        if not condition_audio:
            raise gr.Error("Upload an audio file for conditioning.")
        if not os.path.exists(condition_audio):
            raise gr.Error("Condition audio file not found.")
    run_dir = _create_run_dir()
    config["output_dir"] = run_dir
    generator = _seed_to_generator(seed)
    try:
        models = model_cache.get(config)
        pipe = models["pipe"].to("cuda")
        pipe.enable_attention_slicing()
        pipe.scheduler.config.sigma_min = config["sigma_min"]
        pipe.scheduler.config.sigma_max = config["sigma_max"]
        prompt_for_model = "" if config["no_text"] else (prompt_text or "")
        with torch.no_grad():
            if config["apadapter"]:
                final_condition, final_condition_audio = process_musical_conditions(
                    config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
                )
                waveform = pipe(
                    extracted_condition=final_condition,
                    extracted_condition_audio=final_condition_audio,
                    prompt=prompt_for_model,
                    negative_prompt=config["negative_text_prompt"],
                    num_inference_steps=config["denoise_step"],
                    guidance_scale_text=config["guidance_scale_text"],
                    guidance_scale_con=config["guidance_scale_con"],
                    guidance_scale_audio=config["guidance_scale_audio"],
                    num_waveforms_per_prompt=1,
                    audio_end_in_s=TOTAL_AUDIO_SECONDS,
                    generator=generator,
                ).audios
                output = waveform[0].T.float().cpu().numpy()
                sr = pipe.vae.sampling_rate
            else:
                audio = pipe(
                    prompt=prompt_for_model,
                    negative_prompt=config["negative_text_prompt"],
                    num_inference_steps=config["denoise_step"],
                    guidance_scale=config["guidance_scale_text"],
                    num_waveforms_per_prompt=1,
                    audio_end_in_s=TOTAL_AUDIO_SECONDS,
                    generator=generator,
                ).audios
                output = audio[0].T.float().cpu().numpy()
                sr = pipe.vae.sampling_rate
        generated_path = os.path.join(run_dir, "generated.wav")
        sf.write(generated_path, output, sr)
        status_lines = [
            f"Run directory: `{run_dir}`",
            f"Mode: {'MuseControlLite' if config['apadapter'] else 'Stable Audio base'}",
            f"Condition type: {', '.join(condition_type) if condition_type else 'text only'}",
            f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
        ]
        if config["apadapter"]:
            status_lines.append(
                f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
            )
        if generator is not None:
            status_lines.append(f"Seed: {int(seed)}")
        status_md = "\n".join(f"- {line}" for line in status_lines)
        return generated_path, status_md
    except gr.Error:
        raise
    except Exception as err:  # pylint: disable=broad-except
        raise gr.Error(f"Generation failed: {err}") from err
EXAMPLES = [
    [
        "Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
        "melody_condition_audio/49_piano.mp3",
        ["melody_stereo"],
        True,
        False,
        "",
        7.0,
        1.5,
        1.0,
        50,
        "fp16",
        1.0,
        0.3,
        500,
        0,
        0,
        0,
        0,
        42,
    ],
    [
        "fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
        "melody_condition_audio/610_bass.mp3",
        ["melody_mono", "dynamics", "rhythm"],
        True,
        False,
        "",
        7.0,
        1.5,
        1.0,
        50,
        "fp16",
        1.0,
        0.3,
        500,
        0,
        0,
        0,
        0,
        7,
    ],
]
def build_interface() -> gr.Blocks:
    with gr.Blocks(title="MuseControlLite") as demo:
        gr.Markdown(
            """
## MuseControlLite demo
UI for MuseControlLite (47.5 s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
"""
        )
        with gr.Row():
            prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
            use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
            no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)
        condition_audio = gr.Audio(
            label="Condition audio (required for MuseControlLite)", type="filepath", sources=["upload", "microphone"]
        )
        condition_type = gr.CheckboxGroup(
            CONDITION_CHOICES, label="Condition types", value=DEFAULT_CONFIG.get("condition_type", [])
        )
        with gr.Accordion("Advanced controls", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt", lines=2, value=DEFAULT_CONFIG.get("negative_text_prompt", "")
            )
            with gr.Row():
                guidance_scale_text = gr.Slider(
                    minimum=0.0,
                    maximum=12.0,
                    value=DEFAULT_CONFIG["guidance_scale_text"],
                    step=0.1,
                    label="Guidance scale (text)",
                )
                guidance_scale_con = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=DEFAULT_CONFIG["guidance_scale_con"],
                    step=0.1,
                    label="Guidance scale (conditions)",
                )
                guidance_scale_audio = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=DEFAULT_CONFIG["guidance_scale_audio"],
                    step=0.1,
                    label="Guidance scale (audio)",
                )
            with gr.Row():
                denoise_step = gr.Slider(
                    minimum=10, maximum=100, value=DEFAULT_CONFIG["denoise_step"], step=1, label="Denoising steps"
                )
                weight_dtype = gr.Radio(["fp16", "fp32"], value=DEFAULT_CONFIG["weight_dtype"], label="Weight dtype")
                ap_scale = gr.Slider(
                    minimum=0.5, maximum=2.0, value=DEFAULT_CONFIG["ap_scale"], step=0.05, label="AP scale"
                )
            with gr.Row():
                sigma_min = gr.Slider(
                    minimum=0.1, maximum=5.0, value=DEFAULT_CONFIG["sigma_min"], step=0.05, label="Scheduler sigma min"
                )
                sigma_max = gr.Slider(
                    minimum=50, maximum=700, value=DEFAULT_CONFIG["sigma_max"], step=1, label="Scheduler sigma max"
                )
                seed = gr.Number(label="Seed (optional)", precision=0)
            with gr.Row():
                audio_mask_start = gr.Number(
                    label="Audio mask start (s)", value=DEFAULT_CONFIG["audio_mask_start_seconds"]
                )
                audio_mask_end = gr.Number(label="Audio mask end (s)", value=DEFAULT_CONFIG["audio_mask_end_seconds"])
            with gr.Row():
                musical_mask_start = gr.Number(
                    label="Musical attribute mask start (s)",
                    value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"],
                )
                musical_mask_end = gr.Number(
                    label="Musical attribute mask end (s)",
                    value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"],
                )
        generate_btn = gr.Button("Generate", variant="primary")
        generated_audio = gr.Audio(label="Generated audio", type="filepath")
        status = gr.Markdown(label="Run details")
        generate_btn.click(
            fn=run_inference,
            inputs=[
                prompt,
                condition_audio,
                condition_type,
                use_musecontrol,
                no_text,
                negative_prompt,
                guidance_scale_text,
                guidance_scale_con,
                guidance_scale_audio,
                denoise_step,
                weight_dtype,
                ap_scale,
                sigma_min,
                sigma_max,
                audio_mask_start,
                audio_mask_end,
                musical_mask_start,
                musical_mask_end,
                seed,
            ],
            outputs=[generated_audio, status],
        )
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                prompt,
                condition_audio,
                condition_type,
                use_musecontrol,
                no_text,
                negative_prompt,
                guidance_scale_text,
                guidance_scale_con,
                guidance_scale_audio,
                denoise_step,
                weight_dtype,
                ap_scale,
                sigma_min,
                sigma_max,
                audio_mask_start,
                audio_mask_end,
                musical_mask_start,
                musical_mask_end,
                seed,
            ],
            label="Quick start examples (click to populate the form)",
        )
    return demo
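
# Hugging Face Spaces execute this module directly, so the same guard serves
# both the hosted demo and local runs.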
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()