Spaces:
Paused
Paused
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
| from pytorch_lightning import seed_everything | |
| from src.model.pipeline import AudioLDMPipeline, TangoPipeline | |
| from src.utils.utils import ( | |
| process_move, | |
| process_paste, | |
| process_remove, | |
| ) | |
| from src.utils.audio_processing import TacotronSTFT, wav_to_fbank, maybe_add_dimension | |
| from src.utils.factory import slerp, fill_with_neighbor, optimize_neighborhood_points | |
# Multiplier per UNet up-block index (indexed by ``max(up_ft_index)``) used to
# derive the ``scale`` argument of the ``process_*`` helpers as
# ``4 * SIZES[max(up_ft_index)] / up_scale``.
# NOTE(review): presumably the relative feature-map resolution of each
# up-block — confirm against the pipeline's UNet.
SIZES = dict(enumerate((4, 2, 1, 1)))
@dataclass
class SoundEditorOutput:
    """Result bundle returned by every ``AudioMorphix.run_*`` method.

    The ``@dataclass`` decorator is required: callers construct this class
    positionally as ``SoundEditorOutput(waveform, mel_spectrogram)``, which
    only works with the generated ``__init__``.
    """

    # Output of ``editor.mel_spectrogram_to_waveform``; NOTE(review): assumed
    # to be a tensor (could be a numpy array depending on the vocoder) — confirm.
    waveform: torch.Tensor
    # Output of ``editor.decode_latents`` (decoded mel spectrogram).
    mel_spectrogram: torch.Tensor
class AudioMorphix:
    """Training-free audio editor built on a latent-diffusion pipeline.

    Wraps an ``AudioLDMPipeline`` or ``TangoPipeline`` (chosen from the model
    path) and exposes one ``run_*`` method per editing operation: move, paste,
    mix, remove, generation, style transfer and plain DDIM inversion. Every
    ``run_*`` method seeds all RNGs with ``seed``, DDIM-inverts its input(s),
    runs ``self.editor.pipe.edit`` in the matching mode, and returns a
    ``SoundEditorOutput`` with the vocoded waveform and decoded spectrogram.

    Arguments shared by most ``run_*`` methods:
        fbank_*: Filterbank input(s); lifted to 4-D by ``maybe_add_dimension``
            and treated as ``(B, C, T, F)``.
        prompt / prompt_replace: Text prompts for the base and replacement
            audio used during inversion / editing.
        dx, dy: Offset in fbank units; converted to a latent roll along the
            last (dx) and second-to-last (dy) latent axes.
        resize_scale_x, resize_scale_y: Resize factors for the last and
            second-to-last axes of the edited content.
        w_*: Loss weights forwarded to the ``process_*`` helpers.
        seed: Passed to ``seed_everything`` for reproducibility.
        guidance_scale, SDE_strength: Forwarded to ``pipe.edit``.
        energy_scale: Multiplied by 1e3 before being forwarded to ``pipe.edit``.
        save_kv: If true, ``self.editor.load_adapter()`` is called first.
        disable_tangent_proj: Forwarded to ``pipe.edit``.
        scale_denoised: If true, the denoised latent is rescaled to a peak
            absolute value of 5 before decoding.
    """

    def __init__(
        self,
        pretrained_model_path,
        num_ddim_steps=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the diffusion pipeline that matches ``pretrained_model_path``.

        Args:
            pretrained_model_path: Model id/path; must contain ``"audioldm"``
                or ``"tango"`` to select the pipeline class.
            num_ddim_steps: Number of DDIM steps for inversion and sampling.
            device: Torch device string (CUDA when available at import time).

        Raises:
            ValueError: If the path names neither supported model family.
        """
        self.ip_scale = 0.1
        self.precision = torch.float32  # torch.float16
        if "audioldm" in pretrained_model_path:
            pipe_cls = AudioLDMPipeline
        elif "tango" in pretrained_model_path:
            pipe_cls = TangoPipeline
        else:
            # Without this branch the lookup below failed with an
            # UnboundLocalError that hid the real problem.
            raise ValueError(
                f"Unsupported model {pretrained_model_path!r}: expected an "
                "'audioldm' or 'tango' checkpoint."
            )
        self.editor = pipe_cls(
            sd_id=pretrained_model_path,
            NUM_DDIM_STEPS=num_ddim_steps,
            precision=self.precision,
            ip_scale=self.ip_scale,
            device=device,
        )
        self.up_ft_index = [2, 3]  # fixed in gradio demo
        self.up_scale = 2  # fixed in gradio demo
        self.device = device
        self.num_ddim_steps = num_ddim_steps

    def to(self, device):
        """Move the wrapped pipeline to ``device`` and record the new device.

        Unlike ``torch.nn.Module.to`` this mutates in place and returns None.
        """
        self.editor.pipe = self.editor.pipe.to(device)
        self.editor.pipe._device = device
        self.editor.device = device
        self.device = device

    # ------------------------------------------------------------------
    # Private helpers shared by the run_* methods
    # ------------------------------------------------------------------
    def _prepare_fbank(self, fbank):
        """Lift ``fbank`` to 4-D and move it to this editor's device/precision."""
        return maybe_add_dimension(fbank, 4).to(self.device, dtype=self.precision)

    def _feature_scale(self):
        """Return the ``scale`` argument expected by the ``process_*`` helpers."""
        return 4 * SIZES[max(self.up_ft_index)] / self.up_scale

    @staticmethod
    def _center_fit(src, dst):
        """Copy ``src`` into ``dst`` centre-aligned on the last two axes.

        ``src`` is centre-cropped where it is larger than ``dst`` and
        surrounded by ``dst``'s existing values where it is smaller.
        ``dst`` is modified in place and returned.
        """
        pad_x = (dst.shape[-1] - src.shape[-1]) // 2
        pad_y = (dst.shape[-2] - src.shape[-2]) // 2
        px_dst, py_dst = max(pad_x, 0), max(pad_y, 0)
        px_src, py_src = max(-pad_x, 0), max(-pad_y, 0)
        dst[:, :, py_dst:py_dst + src.shape[-2], px_dst:px_dst + src.shape[-1]] = src[
            :, :, py_src:py_src + dst.shape[-2], px_src:px_src + dst.shape[-1]
        ]
        return dst

    def _fit_resized(self, fbank, resize_scale_x, resize_scale_y):
        """Resize ``fbank`` on its last two axes, then crop/zero-pad back to its size."""
        height, width = fbank.shape[-2], fbank.shape[-1]
        canvas = torch.zeros_like(fbank)
        resized = F.interpolate(
            fbank, (int(height * resize_scale_y), int(width * resize_scale_x))
        )
        return self._center_fit(resized, canvas)

    @staticmethod
    def _latent_shift(dx, dy, t, latent_t, scale_x=1, scale_y=1):
        """Convert a (dx, dy) fbank offset into a ``torch.roll`` shift on dims (-2, -1).

        ``t / latent_t`` is the fbank-to-latent downsampling ratio along the
        time axis; the same ratio is applied to both offsets.
        """
        ratio = t / latent_t
        return (int(dy / ratio * scale_y), int(dx / ratio * scale_x))

    @staticmethod
    def _binary_latent_mask(mask, latent):
        """Nearest-resize ``mask`` to ``latent``'s last two axes and binarize (> 0)."""
        return (
            F.interpolate(mask.float(), (latent.shape[-2], latent.shape[-1])) > 0
        ).float()

    def _invert_pair(
        self, fbank_bg, fbank_fg, prompt, prompt_replace, resize_scale_x, resize_scale_y
    ):
        """DDIM-invert the background and the resized foreground as one batch.

        Batch index 0 of each returned latent is the background and index 1
        the foreground.
        """
        latent_base = self.editor.fbank2latent(fbank_bg)
        fbank_fg = self._fit_resized(fbank_fg, resize_scale_x, resize_scale_y)
        latent_replace = self.editor.fbank2latent(fbank_fg)
        return self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_replace]),
            prompt=[prompt, prompt_replace],
        )

    def _finalize(self, latent_rec, scale_denoised):
        """Optionally rescale the denoised latent, then decode and vocode it.

        With ``scale_denoised`` the latent is rescaled to a peak absolute
        value of 5; an all-zero latent is left as-is to avoid dividing by 0.
        """
        if scale_denoised:
            peak = torch.max(torch.abs(latent_rec))
            if peak > 0:
                latent_rec = latent_rec * 5 / peak
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rec = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rec, spec_rec)

    # ------------------------------------------------------------------
    # Editing operations
    # ------------------------------------------------------------------
    def run_move(
        self,
        fbank_org,
        mask,
        dx, dy,
        mask_ref,
        prompt,
        resize_scale_x,
        resize_scale_y,
        w_edit,
        w_content,
        w_contrast,
        w_inpaint,
        seed,
        guidance_scale,
        energy_scale,
        SDE_strength,
        mask_keep=None,
        ip_scale=None,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Move (and optionally resize) the masked region of ``fbank_org``.

        Args:
            fbank_org: Source filterbank.
            mask: Region to move; unsqueezed twice before interpolation, so
                presumably a 2-D tensor — confirm with callers.
            mask_ref: Optional numpy reference mask; treated as absent when
                ``None`` or all zeros.
            mask_keep: Optional keep-mask forwarded to ``process_move``.
            ip_scale: Accepted for API compatibility; currently unused.
            Remaining arguments: see the class docstring.

        Returns:
            SoundEditorOutput for the edited audio.
        """
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        input_scale = 1
        fbank_org = self._prepare_fbank(fbank_org)  # shape = (B,C,T,F)
        f, t = fbank_org.shape[-1], fbank_org.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        # FIXME: the reference mask is replicated to 3 channels (image-style).
        if mask_ref is not None and np.sum(mask_ref) != 0:
            mask_ref = np.repeat(mask_ref[:, :, None], 3, 2)
        else:
            mask_ref = None
        latent = self.editor.fbank2latent(fbank_org)
        ddim_latents = self.editor.ddim_inv(latent=latent, prompt=prompt)
        latent_in = ddim_latents[-1].squeeze(2)
        edit_kwargs = process_move(
            path_mask=mask,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            w_contrast=w_contrast,
            w_inpaint=w_inpaint,
            precision=self.precision,
            path_mask_ref=mask_ref,
            path_mask_keep=mask_keep,
        )

        # Pre-process z_T: resize + roll the masked latent content to its
        # destination, then fit it back onto a canvas of the latent's size.
        latent_h, latent_w = latent_in.shape[-2], latent_in.shape[-1]
        resized_hw = (int(latent_h * resize_scale_y), int(latent_w * resize_scale_x))
        shift = self._latent_shift(
            dx, dy, t, latent_h, scale_x=resize_scale_x, scale_y=resize_scale_y
        )
        mask_moved = (
            (F.interpolate(mask.unsqueeze(0).unsqueeze(0), resized_hw) > 0)
            .float()
            .to(latent_in.device, dtype=latent_in.dtype)  # was hard-coded 'cuda'
        )
        latent_moved = F.interpolate(latent_in, resized_hw)
        mask_moved = torch.roll(mask_moved, shift, (-2, -1))
        latent_moved = torch.roll(latent_moved, shift, (-2, -1))
        mask_canvas = torch.zeros(
            1, 1, latent_h, latent_w, device=latent_in.device, dtype=latent_in.dtype
        )
        mask_moved = (self._center_fit(mask_moved, mask_canvas) > 0.5).float()
        latent_moved = self._center_fit(latent_moved, torch.zeros_like(latent_in))

        if edit_kwargs["mask_keep"] is not None:
            mask_keep = self._binary_latent_mask(edit_kwargs["mask_keep"], latent_in).to(
                latent_in.device, dtype=latent_in.dtype  # was hard-coded 'cuda'
            )
        else:
            mask_keep = 1 - mask_moved
        latent_in = (latent_in * mask_keep + latent_moved * mask_moved).to(
            dtype=latent_in.dtype
        )

        latent_rec = self.editor.pipe.edit(
            mode='move',
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_paste(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Paste the (resized, shifted) foreground into the masked background region.

        The initial latent is the inverted background with the masked region
        replaced by the rolled foreground latent. See the class docstring for
        arguments. Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = self._prepare_fbank(fbank_fg)
        if save_kv:
            self.editor.load_adapter()
        ddim_latents = self._invert_pair(
            fbank_bg, fbank_fg, prompt, prompt_replace, resize_scale_x, resize_scale_y
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)  # background noise latent
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),  # == 8 * SIZES[...] / up_scale / 2
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = self._binary_latent_mask(edit_kwargs["mask_base_cur"], latent_in)
        # Foreground noise latent, rolled to the paste position.
        latent_fg = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            self._latent_shift(dx, dy, t, latent_in.shape[-2]),
            (-2, -1),
        )
        latent_in = (latent_in * (1 - mask_tmp) + latent_fg * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="paste",
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_mix(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        bg_to_fg_ratio=0.7,
        disable_tangent_proj=False,
        scale_denoised=False,
    ):
        """Mix the foreground into the masked background region via latent slerp.

        Inside the mask, the initial latent is ``slerp(bg_to_fg_ratio,
        background, shifted foreground)``. NOTE(review): which endpoint
        ``bg_to_fg_ratio`` weights depends on ``slerp``'s convention — confirm.
        See the class docstring for other arguments. Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = self._prepare_fbank(fbank_fg)
        if save_kv:
            self.editor.load_adapter()
        ddim_latents = self._invert_pair(
            fbank_bg, fbank_fg, prompt, prompt_replace, resize_scale_x, resize_scale_y
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)  # background noise latent
        # TODO: adapt it to different Gen models
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = self._binary_latent_mask(edit_kwargs["mask_base_cur"], latent_in)
        latent_fg = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            self._latent_shift(dx, dy, t, latent_in.shape[-2]),
            (-2, -1),
        )
        latent_mix = slerp(bg_to_fg_ratio, latent_in, latent_fg)
        latent_in = (latent_in * (1 - mask_tmp) + latent_mix * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="mix",
            latent=latent_in,
            prompt=prompt,  # NOTE: empirically, reconstructing with the base prompt works best
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_remove(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_contrast,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        bg_to_fg_ratio=0.5,
        iterations=50,
        enable_penalty=True,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Remove the foreground content from the masked background region.

        Inside the mask the initial latent is replaced by a random latent
        refined with ``optimize_neighborhood_points`` against the foreground
        and background latents (``t=bg_to_fg_ratio``, ``iterations``,
        ``enable_penalty`` are forwarded). See the class docstring for other
        arguments. Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = self._prepare_fbank(fbank_fg)
        if save_kv:
            self.editor.load_adapter()
        ddim_latents = self._invert_pair(
            fbank_bg, fbank_fg, prompt, prompt_replace, resize_scale_x, resize_scale_y
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        edit_kwargs = process_remove(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_contrast=w_contrast,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = self._binary_latent_mask(edit_kwargs["mask_base_cur"], latent_in)
        latent_fg = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            self._latent_shift(dx, dy, t, latent_in.shape[-2]),
            (-2, -1),
        )
        # Start the removed region from a random latent plus a tiny perturbation.
        latent_neighbor = torch.randn_like(latent_in.squeeze(0)) * 0.9
        latent_neighbor = latent_neighbor + torch.randn_like(latent_neighbor) * 1e-3
        latent_neighbor, _ = optimize_neighborhood_points(
            latent_neighbor * mask_tmp,
            latent_fg * mask_tmp,
            latent_in * mask_tmp,
            t=bg_to_fg_ratio,
            iterations=iterations,
            enable_penalty=enable_penalty,
            enable_tangent_proj=True,  # TODO: try to turn off tangent projection
        )
        latent_in = (latent_in * (1 - mask_tmp) + latent_neighbor * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="remove",
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_audio_generation(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Generate audio for ``prompt_replace`` from statistics-matched noise.

        The starting latent is Gaussian noise renormalised (along the last
        axis) to the mean/std of the inverted ``fbank_bg`` latent. Energy
        guidance is disabled (``energy_scale=0`` is passed to ``pipe.edit``),
        so the ``energy_scale`` argument is ignored. ``fbank_fg`` is accepted
        for signature parity but unused (it was previously resized and then
        discarded). Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_base]), prompt=[prompt, prompt]
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        noise = torch.randn_like(latent_in)
        mean = latent_in.mean(dim=-1, keepdim=True)
        std = latent_in.std(dim=-1, keepdim=True)
        noise_mean = noise.mean(dim=-1, keepdim=True)
        noise_std = noise.std(dim=-1, keepdim=True)
        latent_in = (noise - noise_mean) / noise_std * std + mean
        latent_rec = self.editor.pipe.edit(
            mode="generate",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=0,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_style_transferring(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=True,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Re-render ``fbank_bg`` in the style described by ``prompt_replace``.

        Inversion always runs with ``save_kv=True, mode="style_transfer"``
        regardless of the ``save_kv`` argument, which only controls
        ``load_adapter``. ``fbank_fg`` is accepted for signature parity but
        unused. Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        ddim_latents = self.editor.ddim_inv(
            latent=latent_base, prompt=prompt, save_kv=True, mode="style_transfer",
        )
        latent_in = ddim_latents[-1].squeeze(2)
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        latent_rec = self.editor.pipe.edit(
            mode="style_transfer",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)

    def run_ddim_inversion(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
        """Invert ``fbank_bg`` with DDIM and regenerate it under ``prompt_replace``.

        Energy guidance is disabled (``energy_scale=0`` is passed to
        ``pipe.edit``), so ``energy_scale`` is ignored. ``fbank_fg`` is
        accepted for signature parity but unused. The previous foreground
        resize referenced an undefined ``resize_scale`` (always a NameError)
        and its result was discarded, so it has been removed.
        Returns a SoundEditorOutput.
        """
        seed_everything(seed)
        input_scale = 1
        fbank_bg = self._prepare_fbank(fbank_bg)  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_base]), prompt=[prompt, prompt]
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=self._feature_scale(),
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        latent_rec = self.editor.pipe.edit(
            mode="generate",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=0,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        return self._finalize(latent_rec, scale_denoised)
if __name__ == "__main__":
    # Smoke test: build the editor and dump its attributes.
    # Other supported checkpoint: "declare-lab/tango".
    checkpoint = "cvssp/audioldm-l-full"
    morphix = AudioMorphix(checkpoint, num_ddim_steps=50)
    print(morphix.__dict__)