justinkay committed on
Commit
c72fcf7
·
1 Parent(s): 3c26f17

Update hf zeroshot

Browse files
Files changed (1) hide show
  1. hf_zeroshot.py +47 -32
hf_zeroshot.py CHANGED
@@ -71,51 +71,66 @@ def load_demo_annotations():
71
 
72
  return image_metadata
73
 
74
- def run_bioclip_inference(image_paths, class_names):
75
- """Run zero-shot inference using BioCLIP with pybioclip."""
76
- try:
77
- from bioclip import CustomLabelsClassifier
78
- print("Loading BioCLIP model...")
79
 
80
- # Create classifier with custom labels
 
81
  device = "cuda" if torch.cuda.is_available() else "cpu"
82
- classifier = CustomLabelsClassifier(
83
- cls_ary=class_names,
84
- device=device
85
- )
 
 
 
 
 
 
86
 
87
  results = {}
88
 
89
- for i, image_path in enumerate(image_paths):
90
- if i % 10 == 0:
91
- print(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
 
92
 
93
- try:
94
- predictions = classifier.predict(image_path, k=len(class_names))
95
- scores = {}
96
- for class_name in class_names:
97
- scores[class_name] = 0.0
98
 
99
- # Fill in the predictions - predictions is a list of dicts with format:
100
- # [{'file_name': '...', 'classification': 'Ocelot', 'score': 0.999}, ...]
101
- for pred in predictions:
102
- class_name = pred['classification']
103
- score = pred['score']
104
- if class_name in scores:
105
- scores[class_name] = score
106
 
107
- results[os.path.basename(image_path)] = scores
 
 
108
 
109
- except Exception as e:
110
- print(f"Error processing {image_path}: {e}")
111
- # Fill with uniform probabilities if processing fails
112
- uniform_prob = 1.0 / len(class_names)
113
- results[os.path.basename(image_path)] = {class_name: uniform_prob for class_name in class_names}
 
 
 
 
 
 
 
 
 
114
 
115
  return results
116
 
117
  except Exception as e:
118
  print(f"Error loading BioCLIP: {e}")
 
 
119
  return None
120
 
121
  def run_openclip_inference(model_name, image_paths, class_names):
@@ -317,7 +332,7 @@ def main():
317
 
318
  # Handle different models with appropriate methods
319
  if model_name in ["imageomics/bioclip", "imageomics/bioclip-2"]:
320
- results = run_bioclip_inference(image_paths, CLASS_NAMES)
321
  elif model_name == "google/siglip2-so400m-patch16-naflex":
322
  results = run_siglip_inference(image_paths, CLASS_NAMES)
323
  elif model_name in ["facebook/PE-Core-L14-336", "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"]:
 
71
 
72
  return image_metadata
73
 
74
+ def run_bioclip_inference(model_name, image_paths, class_names):
75
+ """Run zero-shot inference using BioCLIP via OpenCLIP."""
76
+ if not OPEN_CLIP_AVAILABLE:
77
+ print("open_clip is not available. Please install it with: pip install open_clip_torch")
78
+ return None
79
 
80
+ print(f"Loading BioCLIP model: {model_name}")
81
+ try:
82
  device = "cuda" if torch.cuda.is_available() else "cpu"
83
+
84
+ # Load model using OpenCLIP with hf-hub prefix
85
+ model, _, preprocess = open_clip.create_model_and_transforms(f'hf-hub:{model_name}')
86
+ model = model.to(device)
87
+ model.eval()
88
+ tokenizer = open_clip.get_tokenizer(f'hf-hub:{model_name}')
89
+
90
+ # Prepare text prompts
91
+ prompts = [f"a photo of a {class_name.lower()}" for class_name in class_names]
92
+ text_tokens = tokenizer(prompts).to(device)
93
 
94
  results = {}
95
 
96
+ with torch.no_grad():
97
+ # Encode text once
98
+ text_features = model.encode_text(text_tokens)
99
+ text_features /= text_features.norm(dim=-1, keepdim=True)
100
 
101
+ for i, image_path in enumerate(image_paths):
102
+ if i % 10 == 0:
103
+ print(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
 
 
104
 
105
+ try:
106
+ image = Image.open(image_path).convert("RGB")
107
+ image_tensor = preprocess(image).unsqueeze(0).to(device)
 
 
 
 
108
 
109
+ # Encode image
110
+ image_features = model.encode_image(image_tensor)
111
+ image_features /= image_features.norm(dim=-1, keepdim=True)
112
 
113
+ # Calculate similarity and convert to probabilities
114
+ similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
115
+ probabilities = similarity.squeeze(0).cpu().numpy()
116
+
117
+ scores = {}
118
+ for j, class_name in enumerate(class_names):
119
+ scores[class_name] = float(probabilities[j])
120
+
121
+ results[os.path.basename(image_path)] = scores
122
+
123
+ except Exception as e:
124
+ print(f"Error processing {image_path}: {e}")
125
+ uniform_prob = 1.0 / len(class_names)
126
+ results[os.path.basename(image_path)] = {class_name: uniform_prob for class_name in class_names}
127
 
128
  return results
129
 
130
  except Exception as e:
131
  print(f"Error loading BioCLIP: {e}")
132
+ import traceback
133
+ traceback.print_exc()
134
  return None
135
 
136
  def run_openclip_inference(model_name, image_paths, class_names):
 
332
 
333
  # Handle different models with appropriate methods
334
  if model_name in ["imageomics/bioclip", "imageomics/bioclip-2"]:
335
+ results = run_bioclip_inference(model_name, image_paths, CLASS_NAMES)
336
  elif model_name == "google/siglip2-so400m-patch16-naflex":
337
  results = run_siglip_inference(image_paths, CLASS_NAMES)
338
  elif model_name in ["facebook/PE-Core-L14-336", "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"]: