Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import librosa | |
| import pickle | |
| import os | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import pandas as pd | |
| import zipfile | |
| import json | |
| from transformers import ClapModel, ClapProcessor | |
| import torch | |
# Dataset locations (relative to the process working directory).
dataset_zip = "dataset/all_sounds.zip"
extracted_folder = "dataset/all_sounds"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Extract the sample archive only on the first run; if the target folder
# already exists it is reused as-is.
if os.path.exists(extracted_folder) is False:
    with zipfile.ZipFile(dataset_zip, mode="r") as archive:
        archive.extractall(extracted_folder)
# Hugging Face CLAP checkpoint shared by the processor (text/audio
# preprocessing) and the model (embedding extraction).
_CLAP_CHECKPOINT = "laion/clap-htsat-fused"
processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
model = ClapModel.from_pretrained(_CLAP_CHECKPOINT)
# Load dataset metadata: the file holds a JSON object despite its .txt
# extension. Read it explicitly as UTF-8 (the encoding JSON text uses) rather
# than relying on the platform's default locale encoding.
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)
# Each top-level key becomes a row label; appending '.wav' turns the label
# into the sample's filename, which is how rows are matched against the
# embedding keys and the extracted files.
metadata = pd.DataFrame.from_dict(data, orient="index")
metadata.index = metadata.index.astype(str) + '.wav'
# Filename keywords for each instrument category. categorize_instrument tests
# them as lower-case substrings, walking categories in insertion order and
# stopping at the first hit. The ORDER IS SIGNIFICANT: e.g. "808" (Kick) wins
# over "snare" (Snare) for a file containing both. Short keywords such as
# "sd" or "hh" can also match inside longer words.
instrument_categories = dict([
    ("Kick", ["kick", "bd", "bass", "808", "kd"]),
    ("Snare", ["snare", "sd", "sn"]),
    ("Hi-Hat", ["hihat", "hh", "hi_hat", "hi-hat"]),
    ("Tom", ["tom"]),
    ("Cymbal", ["crash", "ride", "splash", "cymbal"]),
    ("Clap", ["clap"]),
    ("Percussion", ["shaker", "perc", "tamb", "cowbell", "bongo", "conga", "egg"]),
])
# Function to categorize filenames based on keywords
def categorize_instrument(filename, categories=None):
    """Map a sample filename to an instrument category by keyword matching.

    The lower-cased filename is checked against each category's keywords in
    the mapping's iteration order; the first category with any keyword
    occurring as a substring wins, so earlier categories take priority.

    Args:
        filename: sample name to classify. Non-string values (for example
            NaN, which pandas produces when a metadata record has no name)
            cannot be matched and are classified as "Other".
        categories: optional mapping of category name -> list of lower-case
            keywords. Defaults to the module-level ``instrument_categories``.

    Returns:
        The first matching category name, or "Other" if nothing matches.
    """
    if categories is None:
        categories = instrument_categories
    # Missing names would otherwise crash on .lower() in the middle of a
    # DataFrame.apply; treat them as unclassifiable instead.
    if not isinstance(filename, str):
        return "Other"
    lower_filename = filename.lower()
    for category, keywords in categories.items():
        if any(keyword in lower_filename for keyword in keywords):
            return category
    return "Other"  # Default category if no match is found
# Tag every sample with an instrument category derived from its "name" field.
# (A bare `metadata["Instrument"].value_counts()` used to follow here; its
# result was discarded, so it only cost time at import.)
metadata["Instrument"] = metadata["name"].apply(categorize_instrument)
# Precomputed CLAP audio embeddings, read once at import time so individual
# requests never re-encode the sample library.
# NOTE(review): pickle.load can execute arbitrary code; this is only safe
# because the file is a locally produced dataset artifact, never user input.
with open(audio_embeddings_path, mode="rb") as f:
    audio_embeddings = pickle.load(f)
@torch.no_grad()
def get_clap_embeddings_from_text(text):
    """Encode a text prompt with the module's Hugging Face CLAP text encoder.

    Gradient tracking is disabled for the whole call, so the resulting tensor
    can be converted to NumPy directly.

    Args:
        text: the prompt passed to the CLAP processor.

    Returns:
        NumPy array of text features, with axis 0 squeezed away (it is
        removed only when it has size 1, i.e. a single prompt).
    """
    encoded = processor(text=text, return_tensors="pt")
    features = model.get_text_features(**encoded)
    return features.squeeze(0).numpy()
def get_clap_embeddings_from_audio(audio_path):
    """Encode an audio file with the module's Hugging Face CLAP audio encoder.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.

    Returns:
        NumPy array of audio features, with axis 0 squeezed away (it is
        removed only when it has size 1, i.e. a single clip).
    """
    # The processor is told the waveform is sampled at 48 kHz, so the file
    # must be loaded (resampled) at exactly that rate. librosa.load's default
    # of 22050 Hz would hand the model a mislabeled, time-stretched signal.
    clap_sr = 48000
    audio, _ = librosa.load(audio_path, sr=clap_sr)
    inputs = processor(audios=[audio], return_tensors="pt", sampling_rate=clap_sr)
    with torch.no_grad():
        return model.get_audio_features(**inputs).squeeze(0).numpy()
def find_top_sounds(text_embed, instrument, top_N=4):
    """Return paths of the top_N sounds of one instrument closest to a text embedding.

    Candidates are the metadata rows whose "Instrument" equals ``instrument``
    and that also have an entry in ``audio_embeddings``. They are ranked by
    cosine similarity to ``text_embed``, best match first.

    Args:
        text_embed: a single embedding vector, presumably 1-D as returned by
            get_clap_embeddings_from_text (it is wrapped in a list to form one
            query row).
        instrument: category label to search, e.g. "Kick" or "Other".
        top_N: maximum number of files to return.

    Returns:
        List of file paths under ``extracted_folder``, at most top_N long.
        Empty when the category has no embedded sounds or top_N <= 0.
    """
    # Set for O(1) membership instead of scanning a list for every embedding.
    wanted = set(metadata.index[metadata["Instrument"] == instrument])
    # Take names and vectors from the same iteration so position i in the
    # similarity array always refers to names[i]. Metadata order, and any
    # metadata rows lacking an embedding, must not shift the indices.
    names = [name for name in audio_embeddings if name in wanted]
    if not names or top_N <= 0:
        # cosine_similarity rejects an empty candidate matrix, and a
        # [-0:] slice would otherwise return every candidate.
        return []

    all_embeds = np.array([audio_embeddings[name] for name in names])
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # Highest similarity first, truncated to top_N.
    top_indices = np.argsort(similarities)[::-1][:top_N]
    return [os.path.join(extracted_folder, names[i]) for i in top_indices]
def generate_drum_kit(prompt, kit_size=4):
    """Build a drum kit for a text prompt.

    The prompt is embedded once with CLAP. Then, for every category in
    ``instrument_categories`` (in its order) plus the catch-all "Other", the
    kit_size closest samples are selected.

    Args:
        prompt: free-text description of the desired sound.
        kit_size: maximum number of samples per category.

    Returns:
        Dict mapping category name -> list of sample file paths, best match
        first. A category's list may be shorter than kit_size (or empty).
    """
    text_embed = get_clap_embeddings_from_text(prompt)
    # Derive the categories from instrument_categories so this list cannot
    # drift from the classifier. "Other" is the label categorize_instrument
    # assigns to files that match no keyword.
    instruments = [*instrument_categories, "Other"]
    return {
        instrument: find_top_sounds(text_embed, instrument, top_N=kit_size)
        for instrument in instruments
    }