Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import AutoProcessor, SiglipModel | |
| import faiss | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from datasets import load_dataset | |
| import pandas as pd | |
| import requests | |
| from io import BytesIO | |
| import spaces | |
# Fetch the FAISS index file and the WikiArt CSV from the Hub into the
# working directory.
for _artifact in ("siglip_10k_latest.index", "wikiart_10k_latest.csv"):
    hf_hub_download(
        repo_id="merve/siglip-faiss-wikiart",
        filename=_artifact,
        local_dir="./",
    )

# Load the downloaded CSV and the search index from disk.
df = pd.read_csv("./wikiart_10k_latest.csv")
index = faiss.read_index("./siglip_10k_latest.index")

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# SigLIP processor and model; the model is moved to the selected device.
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
model = model.to(device)
def read_image_from_url(url, timeout=10):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests.get`` can block indefinitely on a stalled
            connection.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here rather than handing an error page to the decoder.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# @spaces.GPU  (left disabled, exactly as in the original)
@torch.no_grad()
def extract_features_siglip(image):
    """Return the SigLIP image features for ``image``.

    The image is passed through the module-level ``processor`` to obtain
    PyTorch tensors, which are moved to ``device`` and fed to
    ``model.get_image_features``. Gradient tracking is disabled for the
    whole call.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    return model.get_image_features(**model_inputs)
def infer(input_image):
    """Return the artworks closest to the drawn or uploaded image.

    Args:
        input_image: Value produced by the ``gr.ImageEditor`` input: a dict
            whose ``"composite"`` entry is a PIL image. ``None``, or a
            ``None`` composite, is treated as "nothing to search for".

    Returns:
        A list of up to three RGB PIL images, downloaded from the ``"Link"``
        column of ``df`` in the order returned by the index search. The
        list is empty when no image was submitted.
    """
    # Gradio may pass None (or a None composite) when nothing has been drawn
    # or uploaded; answer with an empty gallery instead of raising.
    if input_image is None or input_image["composite"] is None:
        return []

    features = extract_features_siglip(input_image["composite"].convert("RGB"))
    query = np.float32(features.detach().cpu().numpy())
    # normalize_L2 rescales the query rows in place.
    faiss.normalize_L2(query)

    # Only the neighbour ids are needed; the distances are not used.
    _, neighbours = index.search(query, 3)

    gallery_output = []
    for row_id in neighbours[0]:
        # FAISS reports -1 when it has fewer hits than requested;
        # df.iloc[-1] would silently return the last row instead.
        if row_id < 0:
            continue
        gallery_output.append(read_image_from_url(df.iloc[row_id]["Link"]))
    return gallery_output
# Markdown blurb shown under the app title.
description="This is an application where you can draw or upload an image and find the closest artwork among 10k art from wikiart dataset. This is built on 🤗 transformers integration of [SigLIP](https://github.com/merveenoyan/siglip?tab=readme-ov-file#siglip-projects-) model by Google, and FAISS for indexing. In this [link](https://github.com/merveenoyan/siglip?tab=readme-ov-file#siglip-projects-) you can also find the notebook to index the dataset using SigLIP."

# Image editor input that hands PIL images to `infer`.
sketchpad = gr.ImageEditor(type="pil")

# Wire the editor to `infer`, show the results in a gallery, and start the app.
demo = gr.Interface(
    fn=infer,
    inputs=sketchpad,
    outputs="gallery",
    title="Draw to Search Art 🖼️",
    description=description,
)
demo.launch()