justinkay committed · Commit 0cb5bc4 · 1 Parent(s): ce6dd47
Go back to accuracy plot
app.py CHANGED
@@ -61,7 +61,10 @@ MODEL_INFO = [
 ]
 
 DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
-DEMO_ALPHA = 0.25
+DEMO_ALPHA = 0.9 # 0.25 # this is more fun if showing the confusion matrices
+
+# Toggle between confusion matrix and accuracy chart
+USE_CONFUSION_MATRIX = False # Set to True for confusion matrices, False for accuracy bars
 
 def create_species_guide_content():
     """Create the species identification guide content"""
@@ -370,6 +373,15 @@ def create_probability_chart():
     return temp_fig
 
 def create_accuracy_chart():
+    """Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle"""
+    global coda_selector, oracle, dataset, iteration_count
+
+    if USE_CONFUSION_MATRIX:
+        return create_confusion_matrix_chart()
+    else:
+        return create_accuracy_bar_chart()
+
+def create_confusion_matrix_chart():
     """Create confusion matrix estimates for each model side by side"""
     global coda_selector, iteration_count
 
@@ -438,52 +450,52 @@ def create_accuracy_chart():
     plt.close(fig)
     return temp_fig
 
-
-
-
-
-
-#
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-#
-
-
-
-
-#
-
-
-
-
-
-
-#
-
-
-
-#
-
-
-
-
-
-#
-
-
-
+def create_accuracy_bar_chart():
+    """Create a bar chart showing true accuracy of each model (with muted colors)"""
+    global oracle, dataset
+
+    if oracle is None or dataset is None:
+        # Fallback for initial state
+        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
+        ax.text(0.5, 0.5, 'Start demo to see model accuracies',
+                ha='center', va='center', fontsize=12)
+        ax.axis('off')
+        plt.tight_layout()
+        temp_fig = fig
+        plt.close(fig)
+        return temp_fig
+
+    true_losses = oracle.true_losses(dataset.preds)
+    # Convert losses to accuracies (assuming loss is 1 - accuracy)
+    accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
+    model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(accuracies)]]
+
+    # Find the index of the highest accuracy
+    best_idx = np.argmax(accuracies)
+
+    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
+
+    # Use muted, uniform color for all bars (no highlighting)
+    bars = ax.bar(range(len(model_labels)), accuracies, color='#8B9DC3', alpha=0.6)
+
+    ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
+    ax.set_title('True Model Accuracies', fontsize=12)
+    ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
+
+    # Set x-axis labels and ticks
+    ax.set_xticks(range(len(model_labels)))
+    ax.set_xticklabels(model_labels, fontsize=12, ha='center')
+
+    # Add logos to x-axis
+    for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
+        add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
+    plt.yticks(fontsize=12)
+    plt.tight_layout()
+
+    # Save the figure and close it to prevent memory leaks
+    temp_fig = fig
+    plt.close(fig)
+    return temp_fig
 
 # Create the Gradio interface
 with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
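For context on how the new toggle is wired, here is a minimal, self-contained sketch of the dispatch pattern introduced above. The chart builders are simplified stand-ins for the real create_confusion_matrix_chart / create_accuracy_bar_chart in app.py, and the model names and accuracy values are made up for illustration.

import matplotlib
matplotlib.use("Agg")  # headless backend, typical for a hosted Gradio app
import matplotlib.pyplot as plt

# Mirrors USE_CONFUSION_MATRIX in app.py: False selects the accuracy bar chart.
USE_CONFUSION_MATRIX = False

def create_confusion_matrix_chart():
    # Placeholder for the real confusion-matrix figure.
    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
    ax.set_title("Confusion matrix estimates (stub)")
    plt.close(fig)
    return fig

def create_accuracy_bar_chart():
    # Hypothetical labels and accuracies; the real app derives these from
    # oracle.true_losses(dataset.preds) and MODEL_INFO.
    labels = ["Model A", "LAION CLIP", "SigLIP2"]
    accuracies = [0.62, 0.71, 0.68]
    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
    ax.bar(range(len(labels)), accuracies, color="#8B9DC3", alpha=0.6)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_title("True Model Accuracies")
    plt.close(fig)
    return fig

def create_accuracy_chart():
    # Same routing as the new create_accuracy_chart above.
    if USE_CONFUSION_MATRIX:
        return create_confusion_matrix_chart()
    return create_accuracy_bar_chart()

fig = create_accuracy_chart()  # with the default toggle this returns the bar chart

Flipping USE_CONFUSION_MATRIX to True routes the same create_accuracy_chart() call to the confusion-matrix view without touching the rest of the Gradio wiring.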
|