import os
os.environ['GRADIO_TEMP_DIR'] = "tmp/"
import gradio as gr
import json
import random
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict, defaultdict
import numpy as np
import torch
import shutil
import argparse
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from coda import CODA
from coda.datasets import Dataset
from coda.options import LOSS_FNS
from coda.oracle import Oracle
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='Enable debug mode with delete button')
args_cli = parser.parse_args()
DEBUG_MODE = args_cli.debug
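# Usage: python app.py [--debug]
# (--debug adds a "Delete Current Image" button for curating the demo image set)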
if DEBUG_MODE:
print("Debug mode enabled - delete button will be available")
# Create deleted_in_app directory if it doesn't exist
os.makedirs('deleted_in_app', exist_ok=True)
with open('iwildcam_demo_annotations.json', 'r') as f:
data = json.load(f)
SPECIES_MAP = OrderedDict([
(24, "Jaguar"), # panthera onca
(10, "Ocelot"), # leopardus pardalis
(6, "Mountain Lion"), # puma concolor
(101, "Common Eland"), # tragelaphus oryx
(102, "Waterbuck"), # kobus ellipsiprymnus
])
NAME_TO_ID = {name: id for id, name in SPECIES_MAP.items()}
# Class names in order (0-4) from classes.txt
CLASS_NAMES = ["Jaguar", "Ocelot", "Mountain Lion", "Common Eland", "Waterbuck"]
NAME_TO_CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}
# Model information from models.txt
MODEL_INFO = [
{"org": "Facebook", "name": "PE-Core", "logo": "logos/meta.png"},
{"org": "Google", "name": "SigLIP2", "logo": "logos/google.png"},
{"org": "OpenAI", "name": "CLIPViT-L", "logo": "logos/openai.png"},
{"org": "Imageomics", "name": "BioCLIP", "logo": "logos/imageomics.png"},
{"org": "LAION", "name": "CLIP-L", "logo": "logos/laion.png"}
]
DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
DEMO_ALPHA = 0.25
def create_species_guide_content():
"""Create the species identification guide content"""
with gr.Column():
gr.Markdown("""
# Species Classification Guide
### Learn to identify the five wildlife species in this demo.
## Jaguar
""")
gr.Image("species_id/jaguar.jpg", label="Jaguar example image", show_label=False)
gr.Markdown("""
The largest cat in the Americas, with a stocky, muscular build and a broad head; its golden coat is patterned with rosettes that often have central spots inside.
----
## Ocelot
""")
gr.Image("species_id/ocelot.jpg", label="Ocelot example image", show_label=False)
gr.Markdown("""
A medium-sized spotted cat about twice the size of a domestic cat, with a slender body, large eyes, and striking chain-link or stripe-like rosettes. It differs from jaguars by its smaller size, leaner build, and more elongated markings.
----
## Mountain Lion
""")
gr.Image("species_id/mountainlion.jpg", label="Mountain lion example image", show_label=False)
gr.Markdown("""
Also called cougar or puma, this cat has a plain tawny or grayish coat without spots or rosettes. Its long tail and uniformly colored fur distinguish it from jaguars and ocelots.
----
## Common Eland
""")
gr.Image("species_id/commoneland.jpg", label="Eland example image", show_label=False)
gr.Markdown("""
The largest antelope species, with a robust body, spiraled horns on both sexes, and a characteristic dewlap hanging from the throat. It differs from waterbuck by its lighter tan coat, faint body stripes, and massive size.
----
## Waterbuck
""")
gr.Image("species_id/waterbuck.jpg", label="Waterbuck example image", show_label=False)
gr.Markdown("""
A shaggy, dark brown antelope recognized by its white rump ring and backward-curving horns in males. Smaller and darker than the common eland, waterbuck prefer wet habitats and lack the eland's throat dewlap.
----
""")
# load image metadata
images_data = []
for annotation in tqdm(data['annotations'], desc='Loading annotations'):
image_id = annotation['image_id']
category_id = annotation['category_id']
image_info = next((img for img in data['images'] if img['id'] == image_id), None)
if image_info:
images_data.append({
'filename': image_info['file_name'],
'species_id': category_id,
'species_name': SPECIES_MAP[category_id]
})
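# NOTE: the `next(...)` lookup above scans data['images'] once per annotation.
# For larger COCO-style files, a precomputed id -> image dict makes each lookup
# O(1); a minimal sketch (the same idea applies to the per-query scan in
# get_next_coda_image):
# id_to_image = {img['id']: img for img in data['images']}
# images_data = [
#     {'filename': id_to_image[ann['image_id']]['file_name'],
#      'species_id': ann['category_id'],
#      'species_name': SPECIES_MAP[ann['category_id']]}
#     for ann in data['annotations'] if ann['image_id'] in id_to_image
# ]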
print(f"Loaded {len(images_data)} images for the quiz")
# Load image filenames list
with open('images.txt', 'r') as f:
full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]
# Initialize full dataset (will be subsampled per-user)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load full dataset
full_preds = torch.load("iwildcam_demo.pt", map_location=device)  # map_location lets CPU-only machines load GPU-saved tensors
full_labels = torch.load("iwildcam_demo_labels.pt", map_location=device)
# Pre-compute class indices for subsampling
full_class_to_indices = defaultdict(list)
for idx, label in enumerate(full_labels):
class_idx = label.item()
full_class_to_indices[class_idx].append(idx)
# Find minimum class size
min_class_size = min(len(indices) for indices in full_class_to_indices.values())
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")
# Loss function for oracle
loss_fn = LOSS_FNS['acc']
# Global state (will be set per-user in start_demo)
current_image_info = None
coda_selector = None
oracle = None
dataset = None
image_filenames = None
iteration_count = 0
def get_model_predictions(chosen_idx):
"""Get model predictions and scores for a specific image"""
global dataset
if dataset is None or chosen_idx >= dataset.preds.shape[1]:
return "No predictions available"
# Get predictions for this image (shape: [num_models, num_classes])
image_preds = dataset.preds[:, chosen_idx, :].detach().cpu().numpy()
predictions_list = []
for model_idx in range(image_preds.shape[0]):
model_scores = image_preds[model_idx]
predicted_class_idx = model_scores.argmax()
predicted_class_name = CLASS_NAMES[predicted_class_idx]
confidence = model_scores[predicted_class_idx]
model_info = MODEL_INFO[model_idx]
predictions_list.append(f"**{model_info['name']}:** {predicted_class_name} *({confidence:.3f})*")
predictions_text = "### Model Predictions\n\n" + " | ".join(predictions_list)
return predictions_text
def add_logo_to_x_axis(ax, x_pos, logo_path, model_name, height_px=35):
"""Add a logo image to x-axis next to model name"""
try:
img = mpimg.imread(logo_path)
# Calculate zoom to achieve desired height in pixels
# Rough conversion: height_px / image_height / dpi * 72
zoom = height_px / img.shape[0] / ax.figure.dpi * 72
imagebox = OffsetImage(img, zoom=zoom)
# Position logo to the left of the x-tick
logo_offset = -0.28 # Adjust this to move logo left/right relative to tick
y_offset = -0.08
ab = AnnotationBbox(imagebox, (x_pos + logo_offset, y_offset),
xycoords=('data', 'axes fraction'), frameon=False)
ax.add_artist(ab)
except Exception as e:
print(f"Could not load logo {logo_path}: {e}")
def get_next_coda_image():
"""Get the next image that CODA wants labeled"""
global current_image_info, coda_selector, iteration_count
# Get next item from CODA
chosen_idx, selection_prob = coda_selector.get_next_item_to_label()
# Get the corresponding image filename
if chosen_idx < len(image_filenames):
filename = image_filenames[chosen_idx]
image_path = os.path.join('iwildcam_demo_images', filename)
# Find the corresponding annotation for this image
current_image_info = None
for annotation in data['annotations']:
image_id = annotation['image_id']
image_info = next((img for img in data['images'] if img['id'] == image_id), None)
if image_info and image_info['file_name'] == filename:
current_image_info = {
'filename': filename,
'species_id': annotation['category_id'],
'species_name': SPECIES_MAP[annotation['category_id']],
'chosen_idx': chosen_idx,
'selection_prob': selection_prob
}
break
try:
image = Image.open(image_path)
predictions = get_model_predictions(chosen_idx)
return image, f"Iteration {iteration_count}: CODA selected this image for labeling", predictions
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None, f"Error loading image: {e}", "No predictions available"
else:
return None, "Image index out of range", "No predictions available"
def delete_current_image():
"""Delete the current image by moving it to deleted_in_app directory"""
global current_image_info, coda_selector
if current_image_info is None:
return "No image to delete!", None, "No predictions", None, None, ""
filename = current_image_info['filename']
chosen_idx = current_image_info['chosen_idx']
source_path = os.path.join('iwildcam_demo_images', filename)
dest_path = os.path.join('deleted_in_app', filename)
try:
shutil.move(source_path, dest_path)
result = f"✓ Moved {filename} to deleted_in_app/"
print(f"Deleted image: {filename}")
# Remove from CODA's unlabeled indices without adding a label
if chosen_idx in coda_selector.unlabeled_idxs:
coda_selector.unlabeled_idxs.remove(chosen_idx)
except Exception as e:
result = f"Error deleting image: {e}"
print(f"Error deleting {filename}: {e}")
# Load next image
next_image, status, predictions = get_next_coda_image()
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
# Get updated plots
prob_plot = create_probability_chart()
accuracy_plot = create_accuracy_chart()
return result, next_image, predictions, prob_plot, accuracy_plot, status_html
def check_answer(user_choice):
"""Process user's label and update CODA"""
global current_image_info, coda_selector, iteration_count
if current_image_info is None:
return "Please load an image first!", "", None, "No predictions", None, None
correct_species = current_image_info['species_name']
chosen_idx = current_image_info['chosen_idx']
selection_prob = current_image_info['selection_prob']
# Convert user choice to class index (0-4)
if user_choice == "I don't know":
# For "I don't know", just remove from sampling without providing label
coda_selector.unlabeled_idxs.remove(chosen_idx)
result = f"The last image was skipped and will not be used for model selection. The correct species was {correct_species}. "
else:
user_class_idx = NAME_TO_CLASS_IDX.get(user_choice, NAME_TO_CLASS_IDX[correct_species])
if user_choice == correct_species:
result = f"🎉 Your last classification was correct! It was indeed a {correct_species}."
else:
result = f"❌ Your last classification was incorrect. It was a {correct_species}, not a {user_choice}. This may mislead the model selection process!"
# Update CODA with the label
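# (passing selection_prob presumably lets CODA correct for the non-uniform
# sampling, e.g. via importance weighting -- see the CODA package for details)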
coda_selector.add_label(chosen_idx, user_class_idx, selection_prob)
iteration_count += 1
# Get updated plots
prob_plot = create_probability_chart()
accuracy_plot = create_accuracy_chart()
# Load next image
next_image, status, predictions = get_next_coda_image()
# Create HTML with inline help button for status
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
return result, status_html, next_image, predictions, prob_plot, accuracy_plot
def create_probability_chart():
"""Create a bar chart showing probability each model is best"""
global coda_selector
if coda_selector is None:
# Fallback for initial state
model_labels = [info['name'] for info in MODEL_INFO]
probabilities = np.ones(len(MODEL_INFO)) / len(MODEL_INFO) # Uniform prior
else:
probs_tensor = coda_selector.get_pbest()
probabilities = probs_tensor.detach().cpu().numpy().flatten()
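# Left-pad the names so each tick label leaves room for the logo drawn to its left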
model_labels = [" "*(3 + len(info['name'])//4) + info['name'] for info in MODEL_INFO[:len(probabilities)]]
# Find the index of the highest probability
best_idx = np.argmax(probabilities)
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
# Create colors array - highlight the best model
colors = ['orange' if i == best_idx else 'steelblue' for i in range(len(model_labels))]
bars = ax.bar(range(len(model_labels)), probabilities, color=colors, alpha=0.7)
# Add text above the highest bar
ax.text(best_idx, probabilities[best_idx] + 0.0025, 'Current best guess',
ha='center', va='bottom', fontsize=12, fontweight='bold')
ax.set_ylabel('Probability model is best', fontsize=12)
ax.set_title(f'CODA Model Selection Probabilities (Iteration {iteration_count})', fontsize=12)
ax.set_ylim(np.min(probabilities) - 0.01, np.max(probabilities) + 0.02)
# Set x-axis labels and ticks
ax.set_xticks(range(len(model_labels)))
ax.set_xticklabels(model_labels, fontsize=12, ha='center')
# Add logos to x-axis
for i, model_info in enumerate(MODEL_INFO[:len(probabilities)]):
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
plt.yticks(fontsize=12)
plt.tight_layout()
# Detach the figure from pyplot so repeated calls don't leak memory;
# the Figure object itself stays alive for Gradio to render
plt.close(fig)
return fig
def create_accuracy_chart():
"""Create confusion matrix estimates for each model side by side"""
global coda_selector, iteration_count
if coda_selector is None:
# Fallback for initial state - return empty figure
fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
ha='center', va='center', fontsize=12)
ax.axis('off')
plt.tight_layout()
plt.close(fig)  # detach from pyplot; the Figure object remains for Gradio
return fig
# Get confusion matrix estimates from CODA's Dirichlet distributions
dirichlets = coda_selector.dirichlets # Shape: [num_models, num_classes, num_classes]
num_models = dirichlets.shape[0]
num_classes = dirichlets.shape[1]
# Convert Dirichlet parameters to expected confusion matrices
# The expected value of a Dirichlet is alpha / sum(alpha)
confusion_matrices = []
for model_idx in range(num_models):
alpha = dirichlets[model_idx].detach().cpu().numpy()
# Normalize each row to get probabilities
conf_matrix = alpha / alpha.sum(axis=1, keepdims=True)
confusion_matrices.append(conf_matrix)
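# Worked example: a Dirichlet row with alpha = [2, 1, 1] has expected value
# alpha / alpha.sum() = [0.5, 0.25, 0.25], i.e. the estimated distribution over
# predicted classes for that true class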
# Create subplots for each model
# Adjust width based on number of models (2.4 inches per model works well)
fig_width = num_models * 2.4
fig, axes = plt.subplots(1, num_models, figsize=(fig_width, 2.8), dpi=150)
if num_models == 1:
axes = [axes]
# Species abbreviations for axis labels
species_labels = ['Jag', 'Oce', 'M.L.', 'C.E.', 'Wat']
for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
# Root-scale the values (four nested square roots, i.e. conf_matrix ** (1/16))
# so small off-diagonal probabilities stay visible next to the large diagonal
sqrt_conf_matrix = np.sqrt(np.sqrt(np.sqrt(np.sqrt(conf_matrix))))
# Plot confusion matrix as heatmap with sqrt-scaled values
im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto')#, vmin=0, vmax=1)
# Add model name as title
model_info = MODEL_INFO[model_idx]
ax.set_title(f"{model_info['name']}", fontsize=10, pad=5)
# Set axis labels
if model_idx == 0:
ax.set_ylabel('True class', fontsize=9)
ax.set_xlabel('Predicted', fontsize=9)
# Set ticks with species abbreviations
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
ax.set_xticklabels(species_labels[:num_classes], fontsize=8)
ax.set_yticklabels(species_labels[:num_classes], fontsize=8)
plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98)
plt.tight_layout()
plt.close(fig)  # detach from pyplot; the Figure object remains for Gradio
return fig
# OLD CODE - True Model Accuracies Bar Chart (kept for easy reversion)
# def create_accuracy_chart():
# """Create a bar chart showing true accuracy of each model"""
# global oracle, dataset
#
# if oracle is None or dataset is None:
# # Fallback for initial state
# model_labels = [info['name'] for info in MODEL_INFO]
# accuracies = np.random.random(len(MODEL_INFO)) # Random accuracies for now
# else:
# true_losses = oracle.true_losses(dataset.preds)
# # Convert losses to accuracies (assuming loss is 1 - accuracy)
# accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
# model_labels = [" " + info['name'] for info in MODEL_INFO[:len(accuracies)]]
#
# # Find the index of the highest accuracy
# best_idx = np.argmax(accuracies)
#
# fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
#
# # Create colors array - highlight the best model
# colors = ['red' if i == best_idx else 'forestgreen' for i in range(len(model_labels))]
# bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.7)
#
# # Add text above the highest bar
# ax.text(best_idx, accuracies[best_idx] + 0.005, 'True best model',
# ha='center', va='bottom', fontsize=12, fontweight='bold')
#
# ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
# ax.set_title('True Model Accuracies', fontsize=12)
# ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
#
# # Set x-axis labels and ticks
# ax.set_xticks(range(len(model_labels)))
# ax.set_xticklabels(model_labels, fontsize=12, ha='center')
#
# # Add logos to x-axis
# for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
# add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
# plt.yticks(fontsize=12)
# plt.tight_layout()
#
# # Save the figure and close it to prevent memory leaks
# temp_fig = fig
# plt.close(fig)
# return temp_fig
# Create the Gradio interface
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
theme=gr.themes.Base(),
css="""
.subtle-outline {
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
border-radius: var(--radius-lg);
padding: 1rem;
}
.subtle-outline .flex {
background-color: var(--background-fill-secondary) !important;
}
/* Popup overlay styles */
.popup-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.5);
z-index: 1000;
display: flex;
justify-content: center;
align-items: center;
}
.popup-overlay > div {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.popup-content {
background: var(--background-fill-primary) !important;
padding: 2rem !important;
border-radius: 1rem !important;
max-width: 850px;
width: 90%;
max-height: 80vh;
overflow-y: auto;
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
border: none !important;
margin: 0 !important;
color: var(--body-text-color) !important;
}
.popup-content > div {
background: var(--background-fill-primary) !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
overflow-y: visible !important;
max-height: none !important;
}
.popup-content h1,
.popup-content h2,
.popup-content h3,
.popup-content p,
.popup-content li {
color: var(--body-text-color) !important;
}
/* Ensure gradio column components don't interfere with scrolling */
.popup-content .gradio-column {
overflow-y: visible !important;
max-height: none !important;
}
/* Ensure images in popup are responsive */
.popup-content img {
max-width: 100% !important;
height: auto !important;
}
/* Center title */
.text-center {
text-align: center !important;
}
/* Right align text */
.text-right {
text-align: right !important;
}
/* Subtitle styling */
.subtitle {
text-align: center !important;
font-weight: 300 !important;
color: #666 !important;
margin-top: -0.5rem !important;
}
/* Question mark icon styling */
.panel-container {
position: relative;
}
.help-icon {
position: absolute;
top: 5px;
right: 5px;
width: 25px;
height: 25px;
background-color: #f8f9fa;
color: #6c757d;
border: 1px solid #dee2e6;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
font-size: 13px;
font-weight: 600;
z-index: 10;
transition: all 0.2s ease;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.help-icon:hover {
background-color: #e9ecef;
color: #495057;
border-color: #adb5bd;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
}
/* Help popup styles */
.help-popup-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.5);
z-index: 1001;
display: flex;
justify-content: center;
align-items: center;
}
.help-popup-overlay > div {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.help-popup-content {
background: var(--background-fill-primary) !important;
padding: 1.5rem !important;
border-radius: 0.5rem !important;
max-width: 600px;
width: 90%;
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
border: none !important;
margin: 0 !important;
color: var(--body-text-color) !important;
}
.help-popup-content > div {
background: var(--background-fill-primary) !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
.help-popup-content h1,
.help-popup-content h2,
.help-popup-content h3,
.help-popup-content p,
.help-popup-content li {
color: var(--body-text-color) !important;
}
/* Inline help button */
.inline-help-btn {
display: inline-block;
width: 20px;
height: 20px;
background-color: #f8f9fa;
color: #6c757d;
border: 1px solid #dee2e6;
border-radius: 50%;
text-align: center;
line-height: 18px;
cursor: pointer;
font-size: 11px;
font-weight: 600;
margin-left: 8px;
vertical-align: middle;
transition: all 0.2s ease;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.inline-help-btn:hover {
background-color: #e9ecef;
color: #495057;
border-color: #adb5bd;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
}
#hidden-selection-help-btn {
display: none;
}
/* Reduce spacing around status text */
.status-text {
margin: 0 !important;
padding: 0 !important;
}
.status-text > div {
margin: 0 !important;
padding: 0 !important;
}
/* Compact model predictions panel */
.compact-predictions {
line-height: 1.1 !important;
margin: 0 !important;
padding: 0.1rem !important;
}
.compact-predictions p {
margin: 0.05rem 0 !important;
}
.compact-predictions h3 {
margin: 0 0 0.1rem 0 !important;
}
/* Target the subtle-outline group that contains predictions */
.subtle-outline {
padding: 0.3rem !important;
margin: 0.2rem 0 !important;
}
/* Target the column inside the outline */
.subtle-outline .flex {
padding: 0 !important;
margin: 0 !important;
}
/* Ensure text in predictions panel is visible in dark mode */
.subtle-outline * {
color: var(--body-text-color) !important;
}
""") as demo:
# Main page title
gr.Markdown("# CODA: Consensus-Driven Active Model Selection", elem_classes="text-center")
# Popup component
with gr.Group(visible=True, elem_classes="popup-overlay") as popup_overlay:
with gr.Group(elem_classes="popup-content"):
# Main intro content
intro_content = gr.Markdown("""
# CODA: Consensus-Driven Active Model Selection
## Wildlife Photo Classification Challenge
You are a wildlife ecologist who has just collected a season's worth of imagery from cameras
deployed in Africa and Central and South America. You want to know what species occur in this imagery,
and you hope to use a pre-trained classifier to give you answers quickly.
But which one should you use?
Instead of labeling a large validation set, our new method, **CODA**, enables you to perform **active model selection**.
That is, CODA uses predictions from candidate models to guide the labeling process, querying you (a species identification expert)
for labels on a select few images that will most efficiently differentiate between your candidate machine learning models.
This demo lets you try CODA yourself! First, become a species identification expert by reading our classification guide
so that you will be equipped to provide ground truth labels. Then, watch as CODA narrows down the best model over time
as you provide labels for the query images. You will see that, with your input, CODA can identify the best candidate model
with as few as ten (correctly) labeled images.
""")
# Species guide content (initially hidden)
with gr.Column(visible=False) as species_guide_content:
create_species_guide_content()
# Add spacing before buttons
gr.HTML("<div style='margin-top: 0.1em;'></div>")
with gr.Row():
back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
guide_button = gr.Button("View Species Classification Guide", variant="secondary", size="lg")
popup_start_button = gr.Button("Start Demo", variant="primary", size="lg")
# Help popups for panels
with gr.Group(visible=False, elem_classes="help-popup-overlay") as prob_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## CODA Model Selection Probabilities
This chart shows CODA's current confidence in each candidate model being the best performer.
**How to read this chart:**
- Each bar represents one of the candidate machine learning models
- The height of each bar shows the probability (0-100%) that this model is the best, according to CODA
- The orange bar indicates CODA's current best guess
- As you provide more labels, CODA updates these probabilities
**What you'll see:**
- CODA initializes these probabilities based on each model's agreement with the consensus, providing informative priors
- As you label images, some models will gain confidence while others lose it
- The goal is for one model to clearly emerge as the winner
""")
prob_help_close = gr.Button("Close", variant="secondary")
with gr.Group(visible=False, elem_classes="help-popup-overlay") as acc_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## True Model Accuracies
This chart shows the actual performance of each model on the complete dataset (only possible with oracle knowledge).
**How to read this chart:**
- Each bar represents the true accuracy of one model
- The red bar shows the actual best-performing model
- This information is hidden from CODA during the selection process
- You can compare this with CODA's estimates to see how well it's doing
**Why this matters:**
- This represents the "ground truth" that CODA is trying to discover
- In real scenarios, you wouldn't know these true accuracies beforehand
- The demo shows these to illustrate how CODA's estimates align with reality
""")
acc_help_close = gr.Button("Close", variant="secondary")
with gr.Group(visible=False, elem_classes="help-popup-overlay") as selection_help_popup:
with gr.Group(elem_classes="help-popup-content"):
gr.Markdown("""
## How CODA Selects Images for Labeling
[Placeholder]
""")
selection_help_close = gr.Button("Close", variant="secondary")
# Species guide popup during demo
with gr.Group(visible=False, elem_classes="popup-overlay") as species_guide_popup:
with gr.Group(elem_classes="popup-content"):
create_species_guide_content()
# Add spacing before button
gr.HTML("<div style='margin-top: 0.1em;'></div>")
species_guide_close = gr.Button("Go back to demo", variant="primary", size="lg")
# Two panels with bar charts
with gr.Row():
with gr.Column(scale=1):
with gr.Group(elem_classes="panel-container"):
prob_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
prob_plot = gr.Plot(
value=None,
show_label=False
)
with gr.Column(scale=1):
with gr.Group(elem_classes="panel-container"):
acc_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
accuracy_plot = gr.Plot(
value=create_accuracy_chart(),
show_label=False
)
# Status display with help button
status_with_help = gr.HTML("", visible=True, elem_classes="status-text")
selection_help_button = gr.Button("", visible=False, elem_id="hidden-selection-help-btn")
with gr.Row():
image_display = gr.Image(
label="Identify this animal:",
value=None,
height=400,
width=550
)
# Model predictions panel (full width, single line)
with gr.Group(elem_classes="subtle-outline"):
with gr.Column(elem_classes="flex items-center justify-center h-full"):
model_predictions_display = gr.Markdown(
"### Model Predictions\n\n*Start the demo to see model votes!*",
show_label=False,
elem_classes="text-center compact-predictions"
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Which species is this?")
with gr.Column(scale=5):
result_display = gr.Markdown("", visible=True, elem_classes="text-right")
with gr.Row():
# Create buttons for each species
species_buttons = []
for species_name in SPECIES_MAP.values():
btn = gr.Button(species_name, variant="secondary", size="lg")
species_buttons.append(btn)
# Add "I don't know" button
idk_button = gr.Button("I don't know", variant="primary", size="lg")
# Add debug delete button (only visible in debug mode)
if DEBUG_MODE:
delete_button = gr.Button("🗑️ Delete Current Image", variant="stop", size="lg")
# Add buttons row
with gr.Row():
view_guide_button = gr.Button("📖 View Species Guide", variant="secondary", size="lg")
start_over_button = gr.Button("Start Over", variant="secondary", size="lg")
# Set up button interactions
def start_demo():
global iteration_count, coda_selector, dataset, oracle, image_filenames
# Reset the demo state
iteration_count = 0
# Subsample dataset for this user
subsampled_indices = []
for class_idx in sorted(full_class_to_indices.keys()):
indices = full_class_to_indices[class_idx]
sampled = np.random.choice(indices, size=min_class_size, replace=False)
subsampled_indices.extend(sampled.tolist())
# Sort indices to maintain order
subsampled_indices.sort()
# Create subsampled dataset for this user
subsampled_preds = full_preds[:, subsampled_indices, :]
subsampled_labels = full_labels[subsampled_indices]
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
# Create Dataset object with subsampled data
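# (__new__ skips Dataset.__init__ so the fields can be populated directly
# from the precomputed tensors)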
dataset = Dataset.__new__(Dataset)
dataset.preds = subsampled_preds
dataset.labels = subsampled_labels
dataset.device = device
# Create oracle and CODA selector for this user
oracle = Oracle(dataset, loss_fn=loss_fn)
coda_selector = CODA(dataset,
learning_rate=DEMO_LEARNING_RATE,
alpha=DEMO_ALPHA)
image, status, predictions = get_next_coda_image()
prob_plot = create_probability_chart()
acc_plot = create_accuracy_chart()
# Create HTML with inline help button
status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)
def start_over():
global iteration_count, coda_selector, dataset, oracle, image_filenames
# Reset the demo state
iteration_count = 0
# Subsample dataset for this user (new random subsample)
subsampled_indices = []
for class_idx in sorted(full_class_to_indices.keys()):
indices = full_class_to_indices[class_idx]
sampled = np.random.choice(indices, size=min_class_size, replace=False)
subsampled_indices.extend(sampled.tolist())
# Sort indices to maintain order
subsampled_indices.sort()
# Create subsampled dataset for this user
subsampled_preds = full_preds[:, subsampled_indices, :]
subsampled_labels = full_labels[subsampled_indices]
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
# Create Dataset object with subsampled data
dataset = Dataset.__new__(Dataset)
dataset.preds = subsampled_preds
dataset.labels = subsampled_labels
dataset.device = device
# Create oracle and CODA selector for this user
oracle = Oracle(dataset, loss_fn=loss_fn)
coda_selector = CODA(dataset,
learning_rate=DEMO_LEARNING_RATE,
alpha=DEMO_ALPHA)
# Reset all displays
prob_plot = create_probability_chart()
acc_plot = create_accuracy_chart()
return None, "Demo reset. Click 'Start CODA Demo' to begin.", "### Model Predictions\n\n*Start the demo to see model votes!*", prob_plot, acc_plot, "", gr.update(visible=True), gr.update(visible=False)
def show_species_guide():
# Show species guide, hide intro content, show back button, hide guide button
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
def show_intro():
# Show intro content, hide species guide, hide back button, show guide button
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def show_prob_help():
return gr.update(visible=True)
def hide_prob_help():
return gr.update(visible=False)
def show_acc_help():
return gr.update(visible=True)
def hide_acc_help():
return gr.update(visible=False)
def show_selection_help():
return gr.update(visible=True)
def hide_selection_help():
return gr.update(visible=False)
def show_species_guide_popup():
return gr.update(visible=True)
def hide_species_guide_popup():
return gr.update(visible=False)
popup_start_button.click(
fn=start_demo,
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, popup_overlay, result_display, selection_help_button]
)
start_over_button.click(
fn=start_over,
outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, result_display, popup_overlay, selection_help_button]
)
guide_button.click(
fn=show_species_guide,
outputs=[intro_content, species_guide_content, back_button, guide_button]
)
back_button.click(
fn=show_intro,
outputs=[intro_content, species_guide_content, back_button, guide_button]
)
# Help popup handlers
prob_help_button.click(
fn=show_prob_help,
outputs=[prob_help_popup]
)
prob_help_close.click(
fn=hide_prob_help,
outputs=[prob_help_popup]
)
acc_help_button.click(
fn=show_acc_help,
outputs=[acc_help_popup]
)
acc_help_close.click(
fn=hide_acc_help,
outputs=[acc_help_popup]
)
selection_help_button.click(
fn=show_selection_help,
outputs=[selection_help_popup]
)
selection_help_close.click(
fn=hide_selection_help,
outputs=[selection_help_popup]
)
# Species guide popup handlers
view_guide_button.click(
fn=show_species_guide_popup,
outputs=[species_guide_popup]
)
species_guide_close.click(
fn=hide_species_guide_popup,
outputs=[species_guide_popup]
)
for btn in species_buttons:
btn.click(
fn=check_answer,
inputs=[gr.State(btn.value)],
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
)
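# gr.State(btn.value) snapshots each button's label at wiring time, so every
# handler receives its own species name (a closure over `btn` would late-bind
# to the final loop variable instead)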
idk_button.click(
fn=check_answer,
inputs=[gr.State("I don't know")],
outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
)
# Wire up delete button in debug mode
if DEBUG_MODE:
delete_button.click(
fn=delete_current_image,
outputs=[result_display, image_display, model_predictions_display, prob_plot, accuracy_plot, status_with_help]
)
# Add JavaScript to handle inline help button clicks
demo.load(
lambda: None,
outputs=[],
js="""
() => {
setTimeout(() => {
document.addEventListener('click', function(e) {
if (e.target && e.target.classList.contains('inline-help-btn')) {
e.preventDefault();
e.stopPropagation();
const hiddenBtn = document.getElementById('hidden-selection-help-btn');
if (hiddenBtn) {
hiddenBtn.click();
}
}
});
}, 100);
}
"""
)
if __name__ == "__main__":
demo.launch(
# share=True,
# server_port=7861,
allowed_paths=["/"]
)