Linoy Tsaban committed
Commit · 6255790
1 Parent(s): 6908973
Update app.py
app.py CHANGED
```diff
@@ -4,11 +4,88 @@ import requests
 from io import BytesIO
 from diffusers import StableDiffusionPipeline
 from diffusers import DDIMScheduler
-from utils import
+from utils import *
+from inversion_utils import *
 
-
-
-
+model_id = "CompVis/stable-diffusion-v1-4"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sd_pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
+sd_pipe.scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+from torch import autocast, inference_mode
 
-
-
```
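The setup above loads the scheduler with `DDIMScheduler.from_config(model_id, subfolder="scheduler")`, a shortcut that later diffusers releases deprecate (`from_config` now expects a config object rather than a repo id). A minimal sketch of the two supported loading paths under a recent diffusers version; nothing here is part of the commit:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sd_pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

# Load the scheduler config straight from the repo's scheduler/ subfolder...
sd_pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# ...or derive a DDIM scheduler from the one the pipeline already carries.
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
```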
```diff
+def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
+
+    # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
+    # based on the code in https://github.com/inbarhub/DDPM_inversion
+
+    # returns wt, zs, wts:
+    #   wt  - inverted latent
+    #   wts - intermediate inverted latents
+    #   zs  - noise maps
+
+    sd_pipe.scheduler.set_timesteps(num_diffusion_steps)
+
+    # VAE-encode the input image (0.18215 is the SD v1 latent scaling factor)
+    with autocast("cuda"), inference_mode():
+        w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
+
+    # find zs and wts - forward process
+    wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
+    return wt, zs, wts
+
+
```
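For context, `invert` follows the edit-friendly DDPM inversion of arXiv 2304.06140: with `eta=1` the process is fully stochastic, and the extracted noise maps `zs` are constructed so that replaying them reconstructs the input image. A hypothetical call, assuming the `load_512` helper from the Space's utils and a placeholder cat.png:

```python
# Hypothetical usage, not part of the commit: invert a local image so it
# can later be re-sampled under a new prompt.
x0 = load_512("cat.png", 0, 0, 0, 0, device)   # expected 1x3x512x512 tensor in [-1, 1]
wt, zs, wts = invert(x0,
                     prompt_src="a photo of a cat",
                     num_diffusion_steps=100)
# wts holds the intermediate inverted latents; zs the per-step noise maps.
```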
```diff
+def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
+
+    # reverse process (via zs and wT)
+    w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
+
+    # VAE-decode the resulting latent back to image space
+    with autocast("cuda"), inference_mode():
+        x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
+    if x0_dec.dim() < 4:
+        x0_dec = x0_dec[None, :, :, :]
+    img = image_grid(x0_dec)
+    return img
+
+
+
+
```
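The `skip` argument sets the fidelity/edit-strength trade-off: sampling restarts from the intermediate latent `wts[skip]` and replays only the remaining noise maps `zs[skip:]`, so larger values leave fewer denoising steps for the target prompt and tend to keep the result closer to the source image. A hypothetical sweep (prompt and values illustrative only):

```python
# Hypothetical sweep, not part of the commit: stronger to weaker edits.
outputs = {
    s: sample(wt, zs, wts,
              prompt_tar="a bronze sculpture of a cat",
              cfg_scale_tar=15, skip=s)
    for s in (12, 24, 36)   # low skip = more editing, high skip = more faithful
}
```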
```diff
+def edit(input_image, input_image_prompt, target_prompt, guidance_scale=15, skip=36, num_diffusion_steps=100):
+    offsets = (0, 0, 0, 0)
+    x0 = load_512(input_image, *offsets, device)
+
+
+    # invert the input image under the source prompt
+    wt, zs, wts = invert(x0=x0, prompt_src=input_image_prompt, num_diffusion_steps=num_diffusion_steps)
+    latents = wts[skip].expand(1, -1, -1, -1)  # currently unused
+
+    eta = 1
+    # pure DDPM output
+    pure_ddpm_out = sample(wt, zs, wts, prompt_tar=target_prompt,
+                           cfg_scale_tar=guidance_scale, skip=skip,
+                           eta=eta)
+    return pure_ddpm_out
+
+
```
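`edit` chains the two stages, so it can also be called directly without the UI. A hypothetical invocation; the file name is a placeholder and `image_grid` (from the Space's utils) is assumed to return a PIL image:

```python
# Hypothetical direct call, not part of the commit.
result = edit("cat.png",
              input_image_prompt="a photo of a cat",
              target_prompt="a photo of a tiger",
              guidance_scale=15, skip=36, num_diffusion_steps=100)
result.save("tiger.png")   # assumes image_grid() returned a PIL image
```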
```diff
+# See the gradio docs for the types of inputs and outputs available
+inputs = [
+    gr.Image(label="input image", shape=(512, 512)),
+    gr.Textbox(label="input prompt"),
+    gr.Textbox(label="target prompt"),
+    gr.Slider(label="guidance_scale", minimum=7, maximum=18, value=15),
+    gr.Slider(label="skip", minimum=0, maximum=40, value=36),
+    gr.Slider(label="num_diffusion_steps", minimum=0, maximum=300, value=100),
+]
+outputs = gr.Image(label="result")
+
+# And the minimal interface
+demo = gr.Interface(
+    fn=edit,
+    inputs=inputs,
+    outputs=outputs,
+)
+
+demo.launch()
```
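One portability note: `gr.Image(shape=...)` belongs to the Gradio 3.x API and was removed in 4.x. A sketch of the same interface against Gradio 4.x, assuming the rest of the app is unchanged; `height`/`width` only size the preview, with the actual 512x512 resize left to `load_512` inside `edit`:

```python
import gradio as gr

demo = gr.Interface(
    fn=edit,
    inputs=[
        gr.Image(label="input image", height=512, width=512),
        gr.Textbox(label="input prompt"),
        gr.Textbox(label="target prompt"),
        gr.Slider(label="guidance_scale", minimum=7, maximum=18, value=15),
        gr.Slider(label="skip", minimum=0, maximum=40, value=36),
        gr.Slider(label="num_diffusion_steps", minimum=0, maximum=300, value=100),
    ],
    outputs=gr.Image(label="result"),
)
demo.launch()
```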