import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download
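# Load the BERT tokenizer used for captioning, then fetch the pretrained
# MUG checkpoint from the Hugging Face Hub. The checkpoint stores the image
# encoder's weights under an 'image_encoder.model.' prefix, which must be
# stripped before they can be loaded into the ViT.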
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
ckpt = torch.load(hf_hub_download('tennant/MUG', 'laion_mug_vit_base_5ep.pth'), map_location='cpu')
new_dict = OrderedDict()
for k, v in ckpt.items():
    # Strip the prefix so the keys match the ViT's state dict.
    k = k[len('image_encoder.model.'):]
    new_dict.update({k: v})
model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)
msg = model.load_state_dict(new_dict, strict=False)
print(msg)
if torch.cuda.is_available():
    model.cuda()
model.eval()
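# Masked reconstruction: compute per-patch target statistics, encode the
# visible patches, decode the full patch set, and un-normalize the decoder's
# prediction with the per-patch mean/std. Returns the masked input, the
# masked input with reconstructed patches pasted in, the raw reconstruction,
# and the encoder latent (reused below for captioning).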
def visual_recon(x, model, mask_ratio=0.75):
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    # Undo the normalized-pixel prediction using the target statistics.
    y = y * (var + 1.e-6) ** .5 + mean
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y)
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask)
    x = torch.einsum('nchw->nhwc', x)
    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent
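# Next-word prediction: tokenize the prefix, embed it, fuse it with the image
# latent through the multimodal cross-attention layers, and take the argmax
# over the vocabulary at the last position.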
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    # Drop the trailing [SEP] so the model predicts the token that follows the prefix.
    x_l = torch.tensor(tokenizer([prefix, ])['input_ids'])[:, :-1]
    seq = x_l.shape[1]
    if torch.cuda.is_available():
        x_l = x_l.cuda()
    cls_mask = rearrange(x_l != 0, 'b j -> b 1 j')
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    x_l = model.embed_text(x_l)
    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)
    pred = model.to_logits(x_l)
    # Suppress BERT special tokens during greedy decoding:
    # 103 = [MASK], 101 = [CLS], 100 = [UNK], 0 = [PAD].
    pred[:, :, 103] = -100
    pred[:, :, 101] = -100
    pred[:, :, 100] = -100
    pred[:, :, 0] = -100
    next_word = pred.argmax(dim=-1)[0, -1]
    next_word = tokenizer.decode(next_word)
    return next_word
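# Greedy decoding loop: append the predicted next word until the caption
# reaches max_len words or the model emits [SEP].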
def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        words.append(next_word)
        if next_word == '[SEP]':
            break
    return ' '.join(words)
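# Gradio entry point: normalize the input image with ImageNet statistics,
# run masked reconstruction and captioning, then un-normalize and tile
# (masked | masked + reconstruction | reconstruction) side by side.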
def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    x = np.array(x) / 255.
    x = x - imagenet_mean
    x = x / imagenet_std
    x = torch.tensor(x).float()
    x = x.unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()

    def unnorm_pix(img):
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
    caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)
    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
    return_img = np.concatenate([masked, masked_recon, recon], axis=1)
    return return_img, caption_from_model
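# A minimal usage sketch of calling gr_caption directly, without the Gradio
# UI below. It assumes a local 224x224 RGB image named 'cat.jpeg' (the same
# sample file the interface uses as its default value):
#
#   img = Image.open('cat.jpeg').convert('RGB').resize((224, 224))
#   grid, text = gr_caption(img, mask_ratio=0.75, max_len=20, prefix='a photo of a')
#   print(text)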
import gradio as gr

demo = gr.Interface(gr_caption,
                    inputs=[gr.Image(value='cat.jpeg', shape=(224, 224)),
                            gr.Number(value=0.75, label='mask ratio'),
                            gr.Number(value=20, label='max length'),
                            gr.Textbox(value='a photo of a', label='caption prefix')],
                    outputs=[gr.Image(shape=(224, 224 * 3)),
                             'text'],
                    # examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],
                    )
demo.launch()