support CPU
app.py CHANGED
@@ -162,7 +162,7 @@ class GraphitPipeline(StableDiffusionInstructPix2PixPipeline):
 
         # 2. Encode input prompt
         cond_embeds = torch.cat([image_cond_embeds, negative_image_cond_embeds])
-        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt)
+        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt) #.to(torch_dtype)
         prompt_embeds = cond_embeds
 
         # 3. Preprocess image
@@ -312,38 +312,43 @@ class CustomRealESRGAN(RealESRGAN):
 
 def build_models(args):
     # Load scheduler, tokenizer and models.
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+    torch_dtype = torch.float16 if 'cuda' in device else torch.float32
 
     model_path = 'navervision/Graphit-SD'
     unet = UNet2DConditionModel.from_pretrained(
-        model_path, torch_dtype=
+        model_path, torch_dtype=torch_dtype,
     )
 
     vae_name = 'stabilityai/sd-vae-ft-ema'
-    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=
+    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=torch_dtype)
 
     model_name = 'timbrooks/instruct-pix2pix'
-    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=
+    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, safety_checker=None,
                                            unet = unet,
                                            vae = vae,
                                            )
-    pipe = pipe.to(
+    pipe = pipe.to(device)
 
     ## load CompoDiff
     compodiff_model, clip_model, clip_preprocess, clip_tokenizer = compodiff.build_model()
-    compodiff_model, clip_model = compodiff_model.to(
+    compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
+
+    if device != 'cpu':
+        clip_model = clip_model.half()
 
     ## load third-party models
     model_name = 'Intel/dpt-large'
     depth_preprocess = DPTFeatureExtractor.from_pretrained(model_name)
-    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=
-    depth_predictor = depth_predictor.to(
+    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=torch_dtype)
+    depth_predictor = depth_predictor.to(device)
 
     if not os.path.exists('./third_party/remover_fast.pth'):
         model_file_url = hf_hub_url(repo_id='Geonmo/remover_fast', filename='remover_fast.pth')
         cached_download(model_file_url, cache_dir='./third_party', force_filename='remover_fast.pth')
-    remover = Remover(fast=True, jit=False, device=
+    remover = Remover(fast=True, jit=False, device=device, ckpt='./third_party/remover_fast.pth')
 
-    sr_model = CustomRealESRGAN(
+    sr_model = CustomRealESRGAN(device, scale=2)
     sr_model.load_weights('./third_party/RealESRGAN_x2.pth', download=True)
 
     dataset = datasets.load_dataset("FredZhang7/stable-diffusion-prompts-2.47M")
@@ -361,28 +366,31 @@ def build_models(args):
         'remover': remover,
         'sr_model': sr_model,
         'prompt_candidates': prompts,
+        'device': device,
+        'torch_dtype': torch_dtype,
     }
     return model_dict
 
 
 def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_text_scale, mask, random_seed):
+    device = model_dict['device']
     text_token_dict = model_dict['clip_tokenizer'](text=text_input, return_tensors='pt', padding='max_length', truncation=True)
-    text_tokens, text_attention_mask = text_token_dict['input_ids'].to(
+    text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     negative_text_token_dict = model_dict['clip_tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
-    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(
+    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     with torch.no_grad():
         if image is None:
-            image_cond = torch.zeros([1,1,768]).to(
-            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(
+            image_cond = torch.zeros([1,1,768]).to(device)
+            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0)
         else:
             image_source = image.resize((512, 512))
-            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(
+            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(device)
             mask = mask.resize((512, 512))
             mask = model_dict['clip_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values']
            mask = mask[:,:1,:,:]
-            mask = (mask > 0.5).float().to(
+            mask = (mask > 0.5).float().to(device)
             image_source = image_source * (1 - mask)
             image_cond = model_dict['clip_model'].encode_images(image_source)
             mask = transforms.Resize([64, 64])(mask)[:,0,:,:]
@@ -396,7 +404,9 @@ def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_tex
 
 
 def generate_depth_map(image, height, width):
-
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+    depth_inputs = {k: v.to(device, dtype=torch_dtype) for k, v in model_dict['depth_preprocess'](images=image, return_tensors='pt').items()}
     depth_map = model_dict['depth_predictor'](**depth_inputs).predicted_depth.unsqueeze(1)
     depth_min = torch.amin(depth_map, dim=[1,2,3], keepdim=True)
     depth_max = torch.amax(depth_map, dim=[1,2,3], keepdim=True)
@@ -421,6 +431,9 @@ def generate_color(image, compactness=30, n_segments=100, thresh=35, blur_kernel
 
 @torch.no_grad()
 def generate(image_source, image_reference, text_input, negative_prompt, steps, random_seed, cfg_image_scale, cfg_text_scale, cfg_image_space_scale, cfg_image_reference_mix_weight, cfg_image_source_mix_weight, mask_scale, use_edge, t2i_height, t2i_width, do_sr, mode):
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+
     text_input = text_input.lower()
     if negative_prompt == '':
         print('running without a negative prompt')
@@ -513,10 +526,10 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
     # do reference first
     if image_reference is not None:
         image_cond_reference = ImageOps.exif_transpose(image_reference)
-        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to(
+        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to(device)
         image_cond_reference = model_dict['clip_model'].encode_images(image_cond_reference)
     else:
-        image_cond_reference = torch.zeros([1, 1, 768]).to(
+        image_cond_reference = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     # do source or knn
     image_cond_source = None
@@ -530,14 +543,14 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
             image_cond, image_cond_source = predict_compodiff(None, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
         else:
             image_cond, image_cond_source = predict_compodiff(image_source, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        image_cond = image_cond.to(
-        image_cond_source = image_cond_source.to(
+        image_cond = image_cond.to(torch_dtype).to(device)
+        image_cond_source = image_cond_source.to(torch_dtype).to(device)
     else:
-        image_cond = torch.zeros([1, 1, 768]).to(
+        image_cond = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     if image_cond_source is None and mode != 't2i':
         image_cond_source = image_source.resize((512, 512))
-        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to(
+        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to(device)
         image_cond_source = model_dict['clip_model'].encode_images(image_cond_source)
 
     if cfg_image_reference_mix_weight > 0.0 and torch.sum(image_cond_reference).item() != 0.0:
@@ -551,7 +564,7 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
 
     if negative_prompt != '':
         negative_image_cond, _ = predict_compodiff(None, negative_prompt, '', cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        negative_image_cond = negative_image_cond.to(
+        negative_image_cond = negative_image_cond.to(torch_dtype).to(device)
     else:
         negative_image_cond = torch.zeros_like(image_cond)
 
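The whole commit follows one pattern: pick a device and dtype once at load time, then route every model and tensor through them so the Space can also run without a GPU (float16 stays CUDA-only because half precision is slow or unsupported on CPU). A minimal standalone sketch of that pattern follows; the helper name pick_device_and_dtype and the nn.Linear stand-in are illustrative, not part of app.py:

import torch
import torch.nn as nn

def pick_device_and_dtype():
    # Same fallback the commit introduces: fp16 on GPU, fp32 on CPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    torch_dtype = torch.float16 if 'cuda' in device else torch.float32
    return device, torch_dtype

device, torch_dtype = pick_device_and_dtype()

# Illustrative stand-in for the models loaded in build_models().
model = nn.Linear(768, 768).to(device)
if device != 'cpu':
    model = model.half()  # mirrors clip_model.half(), GPU only

# Tensors take the matching dtype before moving to the device,
# as the diff does with .to(torch_dtype).to(device).
x = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
with torch.no_grad():
    y = model(x)
print(y.dtype, y.device)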