import gc
import json
import os
import random
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path

import librosa
import numpy as np
import pandas as pd
import torch

from audio import (
    loudness_normalize,
    compute_speaker_activity_masks,
)
from config import *
from distortions import apply_pm_distortions, apply_ps_distortions
from metrics import (
    compute_pm,
    compute_ps,
    diffusion_map_torch,
    pm_ci_components_full,
    ps_ci_components_full,
)
from models import cleanup_all_models, embed_batch, load_model
from utils import *


def compute_mapss_measures(
    models,
    mixtures,
    *,
    systems=None,
    algos=None,
    experiment_id=None,
    layer=DEFAULT_LAYER,
    add_ci=DEFAULT_ADD_CI,
    alpha=DEFAULT_ALPHA,
    seed=42,
    on_missing="skip",
    verbose=False,
    max_gpus=None,
):
| """ | |
| Compute MAPSS measures (PM, PS, and their errors). Data is saved to csv files. | |
| :param models: backbone self-supervised models. | |
| :param mixtures: data to process from _read_manifest | |
| :param systems: specific systems (algos and data) | |
| :param algos: specific algorithms to use | |
| :param experiment_id: user-specified name for experiment | |
| :param layer: transformer layer of model to consider | |
| :param add_ci: True will compute error radius and tail bounds. False will not. | |
| :param alpha: normalization factor of the diffusion maps. Lives in [0, 1]. | |
| :param seed: random seed number. | |
| :param on_missing: "skip" when missing values or throw an "error". | |
| :param verbose: True will print process info to console during runtime. False will minimize it. | |
| :param max_gpus: maximal amount of GPUs the program tries to utilize in parallel. | |
| """ | |
    gpu_distributor = GPUWorkDistributor(max_gpus)
    ngpu = get_gpu_count(max_gpus)
    if on_missing not in {"skip", "error"}:
        raise ValueError("on_missing must be 'skip' or 'error'.")
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
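    # Canonicalize the manifest and build one entry per (mixture, speaker).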
    canon_mix = canonicalize_mixtures(mixtures, systems=systems)
    mixture_entries = []
    for m in canon_mix:
        entries = []
        for i, refp in enumerate(m.refs):
            sid = m.speaker_ids[i]
            entries.append(
                {"id": sid, "ref": Path(refp), "mixture": m.mixture_id, "outs": {}}
            )
        mixture_entries.append(entries)
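    # Attach each algorithm's outputs to the matching speaker entries; outputs
    # must be listed in the same order as the references.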
    for m, mix_entries in zip(canon_mix, mixture_entries):
        for algo, out_list in (m.systems or {}).items():
            if len(out_list) != len(mix_entries):
                msg = (
                    f"[{algo}] Number of outputs ({len(out_list)}) does not match "
                    f"number of references ({len(mix_entries)}) for mixture {m.mixture_id}"
                )
                if on_missing == "error":
                    raise ValueError(msg)
                if verbose:
                    warnings.warn(msg + " Skipping this algorithm.")
                continue
            for idx, e in enumerate(mix_entries):
                e["outs"][algo] = out_list[idx]
    if algos is None:
        # Collect every algorithm that appears in any mixture's systems, not
        # just the first mixture's.
        algos_to_run = sorted({a for m in canon_mix for a in (m.systems or {})})
    else:
        algos_to_run = list(algos)
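    # Create the experiment directory and persist the run parameters.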
    exp_id = experiment_id or datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_root = os.path.join(RESULTS_ROOT, f"experiment_{exp_id}")
    os.makedirs(exp_root, exist_ok=True)
    params = {
        "models": models,
        "layer": layer,
        "add_ci": add_ci,
        "alpha": alpha,
        "seed": seed,
        "batch_size": BATCH_SIZE,
        "ngpu": ngpu,
        "max_gpus": max_gpus,
    }
    with open(os.path.join(exp_root, "params.json"), "w") as f:
        json.dump(params, f, indent=2)
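    # Save the canonicalized manifest alongside the results.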
    canon_struct = [
        {
            "mixture_id": m.mixture_id,
            "references": [str(p) for p in m.refs],
            "systems": {
                a: [str(p) for p in outs] for a, outs in (m.systems or {}).items()
            },
            "speaker_ids": m.speaker_ids,
        }
        for m in canon_mix
    ]
    with open(os.path.join(exp_root, "manifest_canonical.json"), "w") as f:
        json.dump(canon_struct, f, indent=2)
    print(f"Starting experiment {exp_id} with {ngpu} GPUs")
    print(f"Results will be saved to: {exp_root}")
    print("NOTE: Output files must be provided in the same order as reference files.")
    clear_gpu_memory()
    get_gpu_memory_info(verbose)
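    # Load and loudness-normalize every reference signal up front.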
    flat_entries = [e for mix in mixture_entries for e in mix]
    all_refs = {}
    if verbose:
        print("Loading reference signals...")
    for e in flat_entries:
        wav, _ = librosa.load(str(e["ref"]), sr=SR)
        all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
    if verbose:
        print("Computing speaker activity masks...")
    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    multi_speaker_masks_mix = []
    individual_speaker_masks_mix = []
    total_frames_per_mix = []
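    # Masks are computed on GPU 0 when available and moved back to the CPU
    # immediately to keep GPU memory free.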
    for i, mix in enumerate(mixture_entries):
        if verbose:
            print(f" Computing masks for mixture {i + 1}/{len(mixture_entries)}")
        if ngpu > 0:
            with torch.cuda.device(0):
                refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
                multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
                multi_speaker_masks_mix.append(multi_mask.cpu())
                individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
                total_frames_per_mix.append(multi_mask.shape[0])
                for ref in refs_for_mix:
                    del ref
                torch.cuda.empty_cache()
        else:
            refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
            multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
            multi_speaker_masks_mix.append(multi_mask.cpu())
            individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
            total_frames_per_mix.append(multi_mask.shape[0])
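    # Main loop: evaluate every requested algorithm on every mixture.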
    ordered_speakers = [e["id"] for e in flat_entries]
    all_mixture_results = {}
    for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
        mixture_id = mix_canon.mixture_id
        all_mixture_results[mixture_id] = {}
        total_frames = total_frames_per_mix[mix_idx]
        mixture_speakers = [e["id"] for e in mix_entries]
        for algo_idx, algo in enumerate(algos_to_run):
            if verbose:
                print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
            all_outs = {}
            missing = []
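            # Load this algorithm's separated outputs, tracking speakers whose
            # output file is missing.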
            for e in mix_entries:
                assigned_path = e.get("outs", {}).get(algo)
                if assigned_path is None:
                    missing.append((e["mixture"], e["id"]))
                    continue
                wav, _ = librosa.load(str(assigned_path), sr=SR)
                all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
            if missing:
                msg = f"[{algo}] missing outputs for {len(missing)} speaker(s) in mixture {mixture_id}"
                if on_missing == "error":
                    raise FileNotFoundError(msg)
                if verbose:
                    warnings.warn(msg + " Skipping those speakers.")
            if not all_outs:
                if verbose:
                    warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
                continue
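            # Per-frame result buffers, NaN-initialized so that frames skipped
            # by the masks remain distinguishable from real scores.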
            if algo not in all_mixture_results[mixture_id]:
                all_mixture_results[mixture_id][algo] = {}
            ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            ps_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            for model_idx, mname in enumerate(models):
                if verbose:
                    print(f" Processing Model {model_idx + 1}/{len(models)}: {mname}")
                for metric_type in ["PS", "PM"]:
                    clear_gpu_memory()
                    gc.collect()
                    model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
                    get_gpu_memory_info(verbose)
                    speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
                    if not speakers_this_mix:
                        continue
                    if verbose:
                        print(f" Processing {metric_type} for mixture {mixture_id}")
                    multi_speaker_mask = multi_speaker_masks_mix[mix_idx]
                    individual_masks = individual_speaker_masks_mix[mix_idx]
                    valid_frame_indices = torch.where(multi_speaker_mask)[0].tolist()
                    speaker_signals = {}
                    speaker_labels = {}
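                    # Build each speaker's signal set: the clean reference, the
                    # system output, and metric-specific distortions of the reference.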
                    for e in speakers_this_mix:
                        s = e["id"]
                        if metric_type == "PS":
                            dists = [
                                loudness_normalize(d)
                                for d in apply_ps_distortions(all_refs[s].numpy(), "all")
                            ]
                        else:
                            dists = [
                                loudness_normalize(d)
                                for d in apply_pm_distortions(all_refs[s].numpy(), "all")
                            ]
                        sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
                        lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
                        speaker_signals[s] = sigs
                        speaker_labels[s] = [f"{s}-{l}" for l in lbls]
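                    # Embed each speaker's signals in small batches to bound
                    # peak GPU memory.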
                    all_embeddings = {}
                    for s in speaker_signals:
                        sigs = speaker_signals[s]
                        masks = [multi_speaker_mask] * len(sigs)
                        batch_size = min(2, BATCH_SIZE)
                        embeddings_list = []
                        for i in range(0, len(sigs), batch_size):
                            batch_sigs = sigs[i:i + batch_size]
                            batch_masks = masks[i:i + batch_size]
                            batch_embs = embed_batch(
                                batch_sigs,
                                batch_masks,
                                model_wrapper,
                                layer_eff,
                                use_mlm=False,
                            )
                            if batch_embs.numel() > 0:
                                embeddings_list.append(batch_embs.cpu())
                            torch.cuda.empty_cache()
                        if embeddings_list:
                            all_embeddings[s] = torch.cat(embeddings_list, dim=0)
                        else:
                            all_embeddings[s] = torch.empty(0, 0, 0)
                    L = next(iter(all_embeddings.values())).shape[1] if all_embeddings else 0
                    if L == 0:
                        if verbose:
                            print(f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
                        continue
                    if verbose:
                        print(f"Computing {metric_type} scores for {mname}...")
                    with ThreadPoolExecutor(
                        max_workers=min(2, ngpu if ngpu > 0 else 1)
                    ) as executor:
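                        # Worker: for one masked frame, gather the embeddings of
                        # all speakers active in that frame, build a diffusion
                        # map over them, and score the frame.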
                        def process_frame(f, frame_idx, all_embeddings_dict, speaker_labels_dict,
                                          individual_masks_list, speaker_indices):
                            try:
                                active_speakers = []
                                for spk_idx, spk_id in enumerate(speaker_indices):
                                    if individual_masks_list[spk_idx][frame_idx]:
                                        active_speakers.append(spk_id)
                                if len(active_speakers) < 2:
                                    return frame_idx, metric_type, {}, None, None
                                frame_embeddings = []
                                frame_labels = []
                                for spk_id in active_speakers:
                                    spk_embs = all_embeddings_dict[spk_id][:, f, :]
                                    frame_embeddings.append(spk_embs)
                                    frame_labels.extend(speaker_labels_dict[spk_id])
                                frame_emb = torch.cat(frame_embeddings, dim=0).detach().cpu().numpy()
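                                # Complement coordinates and eigenvalues are only
                                # needed when confidence intervals are requested.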
                                if add_ci:
                                    coords_d, coords_c, eigvals, k_sub_gauss = (
                                        gpu_distributor.execute_on_gpu(
                                            diffusion_map_torch,
                                            frame_emb,
                                            frame_labels,
                                            alpha=alpha,
                                            eig_solver="full",
                                            return_eigs=True,
                                            return_complement=True,
                                            return_cval=add_ci,
                                        )
                                    )
                                else:
                                    coords_d = gpu_distributor.execute_on_gpu(
                                        diffusion_map_torch,
                                        frame_emb,
                                        frame_labels,
                                        alpha=alpha,
                                        eig_solver="full",
                                        return_eigs=False,
                                        return_complement=False,
                                        return_cval=False,
                                    )
                                    coords_c = None
                                    eigvals = None
                                    k_sub_gauss = 1
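                                # Score the frame; with add_ci, also compute the
                                # bias and tail-probability CI components.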
                                if metric_type == "PS":
                                    score = compute_ps(coords_d, frame_labels, max_gpus)
                                    bias = prob = None
                                    if add_ci:
                                        bias, prob = ps_ci_components_full(
                                            coords_d,
                                            coords_c,
                                            eigvals,
                                            frame_labels,
                                            delta=DEFAULT_DELTA_CI,
                                        )
                                    return frame_idx, "PS", score, bias, prob
                                else:
                                    score = compute_pm(coords_d, frame_labels, "gamma", max_gpus)
                                    bias = prob = None
                                    if add_ci:
                                        bias, prob = pm_ci_components_full(
                                            coords_d,
                                            coords_c,
                                            eigvals,
                                            frame_labels,
                                            delta=DEFAULT_DELTA_CI,
                                            K=k_sub_gauss,
                                        )
                                    return frame_idx, "PM", score, bias, prob
                            except Exception as ex:
                                if verbose:
                                    print(f"ERROR frame {frame_idx}: {ex}")
                                return None
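                        # One future per masked frame: f indexes the embeddings,
                        # valid_frame_indices[f] maps back to the absolute frame.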
                        speaker_ids = [e["id"] for e in speakers_this_mix]
                        futures = [
                            executor.submit(
                                process_frame,
                                f,
                                valid_frame_indices[f],
                                all_embeddings,
                                speaker_labels,
                                individual_masks,
                                speaker_ids,
                            )
                            for f in range(L)
                        ]
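                        # Scatter per-frame scores (and optional CI components)
                        # back into the NaN-initialized buffers.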
                        for fut in futures:
                            result = fut.result()
                            if result is None:
                                continue
                            frame_idx, metric, score, bias, prob = result
                            if metric == "PS":
                                for sp in mixture_speakers:
                                    if sp in score:
                                        ps_frames[mname][sp][frame_idx] = score[sp]
                                        if add_ci and bias is not None and sp in bias:
                                            ps_bias_frames[mname][sp][frame_idx] = bias[sp]
                                            ps_prob_frames[mname][sp][frame_idx] = prob[sp]
                            else:
                                for sp in mixture_speakers:
                                    if sp in score:
                                        pm_frames[mname][sp][frame_idx] = score[sp]
                                        if add_ci and bias is not None and sp in bias:
                                            pm_bias_frames[mname][sp][frame_idx] = bias[sp]
                                            pm_prob_frames[mname][sp][frame_idx] = prob[sp]
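                    # Release the model between the PS and PM passes.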
                    clear_gpu_memory()
                    gc.collect()
                    del model_wrapper
                    clear_gpu_memory()
                    gc.collect()
                all_mixture_results[mixture_id][algo][mname] = {
                    'ps_frames': ps_frames[mname],
                    'pm_frames': pm_frames[mname],
                    'ps_bias_frames': ps_bias_frames[mname] if add_ci else None,
                    'ps_prob_frames': ps_prob_frames[mname] if add_ci else None,
                    'pm_bias_frames': pm_bias_frames[mname] if add_ci else None,
                    'pm_prob_frames': pm_prob_frames[mname] if add_ci else None,
                    'total_frames': total_frames,
                }
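        # Write one CSV per (mixture, model): PS scores, PM scores, and the
        # optional CI components.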
        if verbose:
            print(f"Saving results for mixture {mixture_id}...")
        timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
        for model in models:
            ps_data = {'timestamp_ms': timestamps_ms}
            pm_data = {'timestamp_ms': timestamps_ms}
            ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None
            for algo in algos_to_run:
                if algo not in all_mixture_results[mixture_id]:
                    continue
                if model not in all_mixture_results[mixture_id][algo]:
                    continue
                model_data = all_mixture_results[mixture_id][algo][model]
                for speaker in mixture_speakers:
                    col_name = f"{algo}_{speaker}"
                    ps_data[col_name] = model_data['ps_frames'][speaker]
                    pm_data[col_name] = model_data['pm_frames'][speaker]
                    if add_ci and ci_data is not None:
                        ci_data[f"{algo}_{speaker}_ps_bias"] = model_data['ps_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_ps_prob"] = model_data['ps_prob_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]
            mixture_dir = os.path.join(exp_root, mixture_id)
            os.makedirs(mixture_dir, exist_ok=True)
            pd.DataFrame(ps_data).to_csv(
                os.path.join(mixture_dir, f"ps_scores_{model}.csv"),
                index=False,
            )
            pd.DataFrame(pm_data).to_csv(
                os.path.join(mixture_dir, f"pm_scores_{model}.csv"),
                index=False,
            )
            if add_ci and ci_data is not None:
                pd.DataFrame(ci_data).to_csv(
                    os.path.join(mixture_dir, f"ci_{model}.csv"),
                    index=False,
                )
        del all_outs
        clear_gpu_memory()
        gc.collect()
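    # Final cleanup: drop cached references and masks, unload all models, and
    # release GPU memory.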
| print(f"\nEXPERIMENT COMPLETED") | |
| print(f"Results saved to: {exp_root}") | |
| del all_refs, multi_speaker_masks_mix, individual_speaker_masks_mix | |
| from models import cleanup_all_models | |
| cleanup_all_models() | |
| clear_gpu_memory() | |
| get_gpu_memory_info(verbose) | |
| gc.collect() | |
| return exp_root |
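

# Illustrative usage sketch -- the model name, algorithm key, and import
# location of _read_manifest below are hypothetical placeholders, not part of
# this module's documented API:
#
#     from manifest import _read_manifest  # hypothetical import location
#
#     mixtures = _read_manifest("manifest.json")
#     exp_root = compute_mapss_measures(
#         ["wavlm"], mixtures, algos=["sepformer"], add_ci=True, verbose=True
#     )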