remove all optimization, add safety checker
app.py CHANGED
@@ -12,12 +12,12 @@ import numpy as np
 from torchvision import transforms
 
 from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.models.clip.image_processing_clip import CLIPImageProcessor
 
 from diffusers import UniPCMultistepScheduler
 from diffusers import AutoencoderKL
 from diffusers import StableDiffusionPipeline
-from diffusers.
-import intel_extension_for_pytorch as ipex
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 from stablegarment.models import GarmentEncoderModel,ControlNetModel
 from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
@@ -38,27 +38,8 @@ garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
 pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype, use_safetensors=True,).to(device=device) # variant="fp16"
 # pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype).to(device=device)
 pipeline_t2i.scheduler = scheduler
-
-
-# speed up for cpu
-# to channels last
-pipeline_t2i.unet = pipeline_t2i.unet.to(memory_format=torch.channels_last)
-pipeline_t2i.vae = pipeline_t2i.vae.to(memory_format=torch.channels_last)
-pipeline_t2i.text_encoder = pipeline_t2i.text_encoder.to(memory_format=torch.channels_last)
-# pipeline_t2i.safety_checker = pipeline_t2i.safety_checker.to(memory_format=torch.channels_last)
-
-# Create random input to enable JIT compilation
-sample = torch.randn(2,4,64,48).type(torch_dtype)
-timestep = torch.rand(1)*999
-encoder_hidden_status = torch.randn(2,77,768).type(torch_dtype)
-input_example = (sample, timestep, encoder_hidden_status)
-
-# optimize with IPEX
-pipeline_t2i.unet = ipex.optimize(pipeline_t2i.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
-pipeline_t2i.vae = ipex.optimize(pipeline_t2i.vae.eval(), dtype=torch.bfloat16, inplace=True)
-pipeline_t2i.text_encoder = ipex.optimize(pipeline_t2i.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
-# pipeline_t2i.safety_checker = ipex.optimize(pipeline_t2i.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
-
+pipeline_t2i.safety_checker = StableDiffusionSafetyChecker.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch_dtype, subfolder="safety_checker").to(device=device)
+pipeline_t2i.feature_extractor = CLIPImageProcessor.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch_dtype, subfolder="feature_extractor")
 
 pipeline_tryon = None
 '''
@@ -77,7 +58,6 @@ pipeline_tryon = StableGarmentControlNetPipeline(
 ).to(device=device,dtype=torch_dtype)
 '''
 
-
 def prepare_controlnet_inputs(agn_mask_list,densepose_list):
     for i,agn_mask_img in enumerate(agn_mask_list):
         agn_mask_img = np.array(agn_mask_img.convert("L"))
@@ -101,7 +81,7 @@ def tryon(prompt,init_image,garment_top,garment_down,):
     garment_images = [garment_top,]
     prompt = [prompt,]
     cloth_prompt = ["",]
-    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image])
+    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image]).type(torch_dtype)
 
     images = pipeline_tryon(prompt, negative_prompt="",cloth_prompt=cloth_prompt, # negative_cloth_prompt = n_prompt,
                             height=height,width=width,num_inference_steps=25,guidance_scale=1.5,eta=0.0,
@@ -128,7 +108,7 @@ def text2image(prompt,init_image,garment_top,garment_down,style_fidelity=1.):
                             garment_encoder=garment_encoder,garment_image=garment_images,).images
     return images[0]
 
-# def text2image(prompt,init_image,garment_top,garment_down
+# def text2image(prompt,init_image,garment_top,garment_down,*args,**kwargs):
 #     return pipeline(prompt).images[0]
 
 def infer(prompt,init_image,garment_top,garment_down,t2i_only,style_fidelity):
@@ -166,6 +146,8 @@ model = opj(model_dir, "13987_00.jpg")
 all_person = [opj(model_dir,fname) for fname in os.listdir(model_dir) if fname.endswith(".jpg")]
 with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
     gr.Markdown("# StableGarment")
+    gr.Markdown("Demo for [StableGarment: Garment-Centric Generation via Stable Diffusion](https://arxiv.org/abs/2403.10783).")
+    gr.Markdown("*Running on cpu, so it is super slow. Feel free to duplicate the space or visit [StableGarment](https://github.com/logn-2024/StableGarment) for more info.*")
     with gr.Row():
         with gr.Column():
             init_image = gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
@@ -207,6 +189,7 @@ with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
             style_fidelity,
         ],
         outputs=[gallery],)
+    gr.Markdown("We borrow some code from [OutfitAnyone](https://huggingface.co/spaces/HumanAIGC/OutfitAnyone), thanks. This demo is not safe for all audiences, which may reflect implicit bias and other defects of base model.")
 
 if __name__ == "__main__":
     gradio_app.launch()
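For context on the two added components: in diffusers pipelines, the feature_extractor turns generated images into CLIP pixel inputs, and the StableDiffusionSafetyChecker uses those inputs to flag NSFW outputs (blacking out flagged images). Below is a minimal sketch of that interaction, assuming the runwayml/stable-diffusion-v1-5 checkpoint referenced in the commit is reachable; the variable names (images, np_images, checked_images, has_nsfw) are illustrative and not from app.py.

import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Load the same components the commit attaches to pipeline_t2i.
feature_extractor = CLIPImageProcessor.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="safety_checker")

# Placeholder for images produced by the pipeline.
images = [Image.new("RGB", (512, 512))]

# Preprocess into CLIP inputs, then run the checker on float image arrays.
clip_input = feature_extractor(images, return_tensors="pt").pixel_values
np_images = np.stack([np.array(img, dtype=np.float32) / 255.0 for img in images])
checked_images, has_nsfw = safety_checker(images=np_images, clip_input=clip_input)

With both attributes set on pipeline_t2i, the pipeline can perform this check on its own outputs instead of leaving safety_checker as None, which matches the commit message "add safety checker".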