| import time | |
| import os | |
| import torch | |
| import numpy as np | |
| import torchvision | |
| import torch.nn.functional as F | |
| from torchvision.datasets import ImageFolder | |
| import torchvision.transforms as transforms | |
| from tqdm import tqdm | |
| import pickle | |
| import argparse | |
| from PIL import Image | |
def concat(x, axis=0):
    """Concatenate a sequence of NumPy arrays along ``axis`` (default: first axis)."""
    return np.concatenate(x, axis=axis)


def to_np(x):
    """Return tensor ``x`` as a NumPy array on the CPU, detached from autograd.

    ``detach()`` is the supported replacement for the legacy ``.data``
    attribute; both give a grad-free tensor sharing the same values.
    """
    return x.detach().cpu().numpy()
class Wrapper(torch.nn.Module):
    """Return the activation of ``model.avgpool`` instead of the model's output.

    A forward hook registered on ``model.avgpool`` stores that layer's output
    (with every size-1 axis squeezed out) on the wrapper. ``forward`` runs the
    whole wrapped model, discards its final output, and returns the captured
    activation.

    Args:
        model: a module exposing an ``avgpool`` submodule (for example a
            torchvision ResNet).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Latest activation captured by the hook; None until the first forward.
        self.avgpool_output = None
        # Not read anywhere in this class; presumably set/used by code
        # elsewhere in the file — verify before removing.
        self.query = None
        self.cossim_value = {}

        def fw_hook(_module, _inputs, output):
            # squeeze() drops *every* size-1 axis. For a ResNet avgpool output
            # of shape (N, C, 1, 1) this gives (N, C), but (C,) when N == 1.
            self.avgpool_output = output.squeeze()

        # Keep the handle so the hook can be detached later; otherwise wrapping
        # the same model repeatedly would keep stacking hooks on it.
        self._hook_handle = self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        """Run the wrapped model and return the captured ``avgpool`` activation."""
        _ = self.model(input)
        return self.avgpool_output

    def remove_hook(self):
        """Detach the forward hook this wrapper registered on ``model.avgpool``."""
        self._hook_handle.remove()

    def __repr__(self):
        return "Wrapper"
def QueryToEmbedding(query_path):
    """Embed one query image using the avgpool features of a pretrained ResNet-50.

    The image is resized (shorter side 256), center-cropped to 224x224,
    converted to a tensor and normalized with the ImageNet channel mean/std,
    then passed through ``torchvision.models.resnet50(pretrained=True)``
    wrapped in ``Wrapper`` so that the ``avgpool`` activation is returned.

    Args:
        query_path: anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        numpy.ndarray: the embedding with an extra leading axis added, i.e.
        ``np.asarray([embedding])``. For ResNet-50 this is presumably
        ``(1, 2048)`` — confirm against callers.
    """
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # NOTE(review): the weights are (re)built on every call; callers embedding
    # many queries may want to construct the wrapped model once and reuse it.
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    myw = Wrapper(model)
    # The `with` block closes the file handle deterministically. convert("RGB")
    # guarantees 3 channels: grayscale, RGBA or palette images would otherwise
    # fail in the 3-channel Normalize step. RGB inputs are unaffected.
    with Image.open(query_path) as query_pil:
        query_pt = dataset_transform(query_pil.convert("RGB"))
    with torch.no_grad():
        # unsqueeze(0) adds the batch axis the model expects.
        embedding = to_np(myw(query_pt.unsqueeze(0)))
    return np.asarray([embedding])