justinkay
committed on
Commit
·
eaf554e
1
Parent(s):
f6adf18
Confusion matrix plot
Browse files
app.py
CHANGED
|
@@ -59,6 +59,7 @@ MODEL_INFO = [
|
|
| 59 |
]
|
| 60 |
|
| 61 |
DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
|
|
|
|
| 62 |
|
| 63 |
# load image metadata
|
| 64 |
images_data = []
|
|
@@ -305,51 +306,116 @@ def create_probability_chart():
|
|
| 305 |
return temp_fig
|
| 306 |
|
| 307 |
def create_accuracy_chart():
|
| 308 |
-
"""Create
|
| 309 |
-
global
|
| 310 |
|
| 311 |
-
if
|
| 312 |
-
# Fallback for initial state
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
|
| 343 |
-
for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
|
| 344 |
-
add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
|
| 345 |
-
plt.yticks(fontsize=12)
|
| 346 |
plt.tight_layout()
|
| 347 |
|
| 348 |
-
# Save the figure and close it to prevent memory leaks
|
| 349 |
temp_fig = fig
|
| 350 |
plt.close(fig)
|
| 351 |
return temp_fig
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
# Create the Gradio interface
|
| 354 |
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
| 355 |
theme=gr.themes.Base(),
|
|
@@ -659,9 +725,12 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
|
| 659 |
A shaggy, dark brown antelope recognized by its white rump ring and backward-curving horns in males. Smaller and darker than the common eland, waterbuck prefer wet habitats and lack the eland's throat dewlap.
|
| 660 |
|
| 661 |
----
|
| 662 |
-
|
| 663 |
""")
|
| 664 |
|
|
|
|
|
|
|
|
|
|
| 665 |
with gr.Row():
|
| 666 |
back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
|
| 667 |
guide_button = gr.Button("View Species Classification Guide", variant="secondary", size="lg")
|
|
@@ -810,7 +879,8 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
|
| 810 |
# Create oracle and CODA selector for this user
|
| 811 |
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 812 |
coda_selector = CODA(dataset,
|
| 813 |
-
learning_rate=DEMO_LEARNING_RATE
|
|
|
|
| 814 |
|
| 815 |
image, status, predictions = get_next_coda_image()
|
| 816 |
prob_plot = create_probability_chart()
|
|
@@ -849,7 +919,8 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
|
| 849 |
# Create oracle and CODA selector for this user
|
| 850 |
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 851 |
coda_selector = CODA(dataset,
|
| 852 |
-
learning_rate=DEMO_LEARNING_RATE
|
|
|
|
| 853 |
|
| 854 |
# Reset all displays
|
| 855 |
prob_plot = create_probability_chart()
|
|
|
|
| 59 |
]
|
| 60 |
|
| 61 |
DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
|
| 62 |
+
DEMO_ALPHA = 0.25  # forwarded as the `alpha` argument each time a CODA selector is constructed for the demo
|
| 63 |
|
| 64 |
# load image metadata
|
| 65 |
images_data = []
|
|
|
|
| 306 |
return temp_fig
|
| 307 |
|
| 308 |
def create_accuracy_chart():
    """Plot CODA's current confusion-matrix estimate for every model, side by side.

    Reads the module-level ``coda_selector`` and ``iteration_count``. If no
    selector has been created yet, a placeholder figure prompting the user to
    start the demo is returned instead.

    Otherwise ``coda_selector.dirichlets`` (indexed as
    ``[model, row, column]``) is row-normalised into the expected confusion
    matrix of each model and drawn as one heatmap per model.

    Returns:
        matplotlib.figure.Figure: the finished figure. It has already been
        closed in pyplot's figure registry, so repeated calls do not
        accumulate open figures.
    """
    global coda_selector, iteration_count

    if coda_selector is None:
        # No demo session yet: nothing to estimate, so show a prompt instead.
        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
        ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        fig.tight_layout()
        plt.close(fig)
        return fig

    dirichlets = coda_selector.dirichlets  # [num_models, num_classes, num_classes]
    num_models = dirichlets.shape[0]
    num_classes = dirichlets.shape[1]

    # The mean of a Dirichlet is alpha / sum(alpha). Each row is its own
    # distribution, so normalise along the last axis for all models at once.
    alphas = dirichlets.detach().cpu().numpy()
    confusion_matrices = alphas / alphas.sum(axis=-1, keepdims=True)

    # squeeze=False always yields a 2-D array of axes, so the single-model
    # case needs no special handling.
    fig, axes_grid = plt.subplots(1, num_models, figsize=(8, 2.8), dpi=150,
                                  squeeze=False)
    axes = axes_grid[0]

    class_ticks = range(num_classes)
    for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
        # Display-only compression: a 16th root (four nested square roots)
        # lifts small probabilities so they remain visible next to the
        # dominant entries, while preserving their ordering.
        scaled_matrix = np.power(conf_matrix, 1.0 / 16.0)
        ax.imshow(scaled_matrix, cmap='Blues', aspect='auto')

        # Title with the model name; fall back to a generic label rather than
        # raising if the selector tracks more models than MODEL_INFO lists.
        if model_idx < len(MODEL_INFO):
            title = f"{MODEL_INFO[model_idx]['name']}"
        else:
            title = f"Model {model_idx + 1}"
        ax.set_title(title, fontsize=10, pad=5)

        # Only the left-most panel carries the shared y-axis label.
        if model_idx == 0:
            ax.set_ylabel('True class', fontsize=9)
        ax.set_xlabel('Predicted', fontsize=9)

        ax.set_xticks(class_ticks)
        ax.set_yticks(class_ticks)
        ax.set_xticklabels(class_ticks, fontsize=8)
        ax.set_yticklabels(class_ticks, fontsize=8)

    # Use the figure's own methods instead of pyplot's implicit "current
    # figure", so the title and layout always apply to the figure built here.
    fig.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})",
                 fontsize=12, y=0.98)
    fig.tight_layout()

    # Deregister from pyplot to prevent figure build-up; the object stays usable.
    plt.close(fig)
    return fig
|
| 371 |
|
| 372 |
+
# OLD CODE - True Model Accuracies Bar Chart (kept for easy reversion)
|
| 373 |
+
# def create_accuracy_chart():
|
| 374 |
+
# """Create a bar chart showing true accuracy of each model"""
|
| 375 |
+
# global oracle, dataset
|
| 376 |
+
#
|
| 377 |
+
# if oracle is None or dataset is None:
|
| 378 |
+
# # Fallback for initial state
|
| 379 |
+
# model_labels = [info['name'] for info in MODEL_INFO]
|
| 380 |
+
# accuracies = np.random.random(len(MODEL_INFO)) # Random accuracies for now
|
| 381 |
+
# else:
|
| 382 |
+
# true_losses = oracle.true_losses(dataset.preds)
|
| 383 |
+
# # Convert losses to accuracies (assuming loss is 1 - accuracy)
|
| 384 |
+
# accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
|
| 385 |
+
# model_labels = [" " + info['name'] for info in MODEL_INFO[:len(accuracies)]]
|
| 386 |
+
#
|
| 387 |
+
# # Find the index of the highest accuracy
|
| 388 |
+
# best_idx = np.argmax(accuracies)
|
| 389 |
+
#
|
| 390 |
+
# fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
|
| 391 |
+
#
|
| 392 |
+
# # Create colors array - highlight the best model
|
| 393 |
+
# colors = ['red' if i == best_idx else 'forestgreen' for i in range(len(model_labels))]
|
| 394 |
+
# bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.7)
|
| 395 |
+
#
|
| 396 |
+
# # Add text above the highest bar
|
| 397 |
+
# ax.text(best_idx, accuracies[best_idx] + 0.005, 'True best model',
|
| 398 |
+
# ha='center', va='bottom', fontsize=12, fontweight='bold')
|
| 399 |
+
#
|
| 400 |
+
# ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
|
| 401 |
+
# ax.set_title('True Model Accuracies', fontsize=12)
|
| 402 |
+
# ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
|
| 403 |
+
#
|
| 404 |
+
# # Set x-axis labels and ticks
|
| 405 |
+
# ax.set_xticks(range(len(model_labels)))
|
| 406 |
+
# ax.set_xticklabels(model_labels, fontsize=12, ha='center')
|
| 407 |
+
#
|
| 408 |
+
# # Add logos to x-axis
|
| 409 |
+
# for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
|
| 410 |
+
# add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
|
| 411 |
+
# plt.yticks(fontsize=12)
|
| 412 |
+
# plt.tight_layout()
|
| 413 |
+
#
|
| 414 |
+
# # Save the figure and close it to prevent memory leaks
|
| 415 |
+
# temp_fig = fig
|
| 416 |
+
# plt.close(fig)
|
| 417 |
+
# return temp_fig
|
| 418 |
+
|
| 419 |
# Create the Gradio interface
|
| 420 |
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
| 421 |
theme=gr.themes.Base(),
|
|
|
|
| 725 |
A shaggy, dark brown antelope recognized by its white rump ring and backward-curving horns in males. Smaller and darker than the common eland, waterbuck prefer wet habitats and lack the eland's throat dewlap.
|
| 726 |
|
| 727 |
----
|
| 728 |
+
|
| 729 |
""")
|
| 730 |
|
| 731 |
+
# Add spacing before buttons
|
| 732 |
+
gr.HTML("<div style='margin-top: 0.1em;'></div>")
|
| 733 |
+
|
| 734 |
with gr.Row():
|
| 735 |
back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
|
| 736 |
guide_button = gr.Button("View Species Classification Guide", variant="secondary", size="lg")
|
|
|
|
| 879 |
# Create oracle and CODA selector for this user
|
| 880 |
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 881 |
coda_selector = CODA(dataset,
|
| 882 |
+
learning_rate=DEMO_LEARNING_RATE,
|
| 883 |
+
alpha=DEMO_ALPHA)
|
| 884 |
|
| 885 |
image, status, predictions = get_next_coda_image()
|
| 886 |
prob_plot = create_probability_chart()
|
|
|
|
| 919 |
# Create oracle and CODA selector for this user
|
| 920 |
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 921 |
coda_selector = CODA(dataset,
|
| 922 |
+
learning_rate=DEMO_LEARNING_RATE,
|
| 923 |
+
alpha=DEMO_ALPHA)
|
| 924 |
|
| 925 |
# Reset all displays
|
| 926 |
prob_plot = create_probability_chart()
|