File size: 5,175 Bytes
133d5e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import faiss
import numpy as np
import json
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
import argparse
import requests,os
from dotenv import load_dotenv
import tempfile
# ------------------------------
# 1️⃣ Model and device setup
# ------------------------------
# Runs at import time: loads the SigLIP processor and model once, so that the
# search functions below can reuse them as module-level globals.
print("[INFO] Loading model...")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# eval() and to() both return the module itself, so the calls can be chained:
# switch to inference mode, then move the weights to the chosen device.
model = AutoModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
print(f"[INFO] Using device: {device}")
# ------------------------------
# 2️⃣ Load Manifest
# ------------------------------
# All settings arrive in ONE environment variable, "find_it_or_create", whose
# value is a block of KEY=VALUE lines. load_dotenv() lets a local .env file
# supply that variable.
load_dotenv()  # Load .env file
env_blob = os.environ.get("find_it_or_create", "")
secrets = {}
for line in env_blob.splitlines():
    if "=" in line:
        # Split on the first "=" only, so values (e.g. URLs with query
        # strings) may themselves contain "=".
        k, v = line.split("=", 1)
        # Trim whitespace and surrounding double quotes from the value.
        secrets[k.strip()] = v.strip().strip('"')
json_url = secrets.get("JSON_URL")
faiss_image_url = secrets.get("faiss_image")
faiss_text_url = secrets.get("faiss_text")
# Fail early with an actionable message rather than announcing success and
# then letting requests reject a missing URL.
if not json_url:
    raise ValueError(
        'JSON_URL is missing from the "find_it_or_create" environment variable'
    )
print("✅ Extracted JSON_URL:")
# A timeout keeps an unresponsive server from blocking the import forever.
response = requests.get(json_url, timeout=60)
response.raise_for_status()
manifest_data = response.json()
def download_faiss_index(url, timeout=60):
    """Download a FAISS index file and store it in a new temporary file.

    Args:
        url: HTTP(S) location of the serialized index.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller indefinitely.

    Returns:
        Path of the temporary ``.faiss`` file holding the downloaded bytes.
        The file is not deleted here; the caller owns it.

    Raises:
        ValueError: If ``url`` is empty or ``None``.
        requests.HTTPError: If the server answers with an error status.
    """
    if not url:
        raise ValueError("FAISS index URL is missing or empty")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # mkstemp() creates and opens the file atomically. The deprecated mktemp()
    # only returns a name, which another process could claim before we open it.
    fd, tmp_path = tempfile.mkstemp(suffix=".faiss")
    with os.fdopen(fd, "wb") as f:
        f.write(response.content)
    return tmp_path
# Fetch both serialized indexes to local temporary files (image first, then text).
faiss_image = download_faiss_index(faiss_image_url)
faiss_text = download_faiss_index(faiss_text_url)
print("✅ Extracted index:")
print("[DEBUG] manifest_data type:", type(manifest_data))
# The catalogue may be wrapped in a top-level "products" key; unwrap it if so.
manifest = manifest_data["products"] if "products" in manifest_data else manifest_data
# Reverse lookups: FAISS row position -> product id. Entries lacking the
# relevant position key are left out of that map.
image_pos_to_id = {
    entry["image_pos"]: product_id
    for product_id, entry in manifest.items()
    if "image_pos" in entry
}
text_pos_to_id = {
    entry["text_pos"]: product_id
    for product_id, entry in manifest.items()
    if "text_pos" in entry
}
# ------------------------------
# 3️⃣ Load FAISS indexes
# ------------------------------
print("[INFO] Loading FAISS indexes...")
# Read the image index first, then the text index, from the downloaded files.
image_index, text_index = map(faiss.read_index, (faiss_image, faiss_text))
# ------------------------------
# 4️⃣ Search Functions
# ------------------------------
def search_image(image_path, topk=5):
    """Find the manifest entries whose image embeddings best match an image.

    The image is embedded with the module-level SigLIP model, L2-normalised,
    and looked up in the module-level FAISS image index. Each hit's manifest
    entry is also printed as JSON before returning.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a file path).
        topk: Number of neighbours requested from the FAISS index.

    Returns:
        A list of ``{"score": float, "data": <manifest entry>}`` dicts in the
        order FAISS returned them. Positions with no entry in
        ``image_pos_to_id`` are skipped.
    """
    print(f"[DEBUG] Running image search on: {image_path}")
    # The context manager releases the file handle once the pixels have been
    # copied into the converted image.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs)
    # FAISS expects C-contiguous float32 rows.
    vec = np.ascontiguousarray(image_embedding.cpu().numpy(), dtype=np.float32)
    # Normalise each row separately, and leave an all-zero row untouched
    # instead of turning it into NaNs through a division by zero.
    norms = np.linalg.norm(vec, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    vec = vec / norms
    scores, positions = image_index.search(vec, topk)
    results = []
    for pos, score in zip(positions[0], scores[0]):
        pid = image_pos_to_id.get(int(pos))
        # Compare against None so a falsy but valid product id is not dropped.
        if pid is not None:
            results.append({
                "score": float(score),
                "data": manifest[pid],
            })
    for r in results:
        print(json.dumps(r["data"], indent=2, ensure_ascii=False))
    return results
def search_text(query, topk=5):
    """Find the manifest entries whose text embeddings best match a query.

    The query is embedded with the module-level SigLIP model, L2-normalised,
    and looked up in the module-level FAISS text index.

    Args:
        query: Text to search for.
        topk: Number of neighbours requested from the FAISS index.

    Returns:
        A list of ``{"score": float, "data": <manifest entry>}`` dicts in the
        order FAISS returned them for the first query row. Positions with no
        entry in ``text_pos_to_id`` are skipped.
    """
    print(f"[DEBUG] Running text search for query: '{query}'")
    # NOTE(review): the SigLIP documentation recommends padding="max_length"
    # for text inputs. It is not added here because the query must be encoded
    # the same way as the vectors stored in the text index - confirm how the
    # index was built before changing this.
    inputs = processor(text=query, return_tensors="pt").to(device)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    # FAISS expects C-contiguous float32 rows.
    vec = np.ascontiguousarray(text_embedding.cpu().numpy(), dtype=np.float32)
    # Normalise each row separately, and leave an all-zero row untouched
    # instead of turning it into NaNs through a division by zero.
    norms = np.linalg.norm(vec, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    vec = vec / norms
    scores, positions = text_index.search(vec, topk)
    results = []
    for pos, score in zip(positions[0], scores[0]):
        pid = text_pos_to_id.get(int(pos))
        # Compare against None so a falsy but valid product id is not dropped.
        if pid is not None:
            results.append({
                "score": float(score),
                "data": manifest[pid],
            })
    return results
# ------------------------------
# 5️⃣ Command-Line Interface
# ------------------------------
if __name__ == "__main__":
    # Usage: pass --image PATH and/or --text QUERY; --topk sets how many
    # neighbours are requested for each search.
    parser = argparse.ArgumentParser(description="SigLIP Image/Text Search")
    parser.add_argument("--image", type=str, help="Path to query image")
    parser.add_argument("--text", type=str, help="Query text")
    parser.add_argument("--topk", type=int, default=5, help="Number of results")
    args = parser.parse_args()
    if args.image:
        # search_image() prints each hit's manifest entry itself, so only the
        # summary line is emitted here.
        results = search_image(args.image, topk=args.topk)
        print(f"\n🔎 Image Search Results ({len(results)}):")
    if args.text:
        results = search_text(args.text, topk=args.topk)
        # search_text() returns its hits without printing them; print them
        # here so that --text actually shows its results.
        print(f"\n🔎 Text Search Results ({len(results)}):")
        for r in results:
            print(json.dumps(r["data"], indent=2, ensure_ascii=False))
    if not args.image and not args.text:
        print("Please provide --image or --text for search.")