justinkay commited on
Commit
0cb5bc4
·
1 Parent(s): ce6dd47

Go back to accuracy plot

Browse files
Files changed (1) hide show
  1. app.py +59 -47
app.py CHANGED
@@ -61,7 +61,10 @@ MODEL_INFO = [
61
  ]
62
 
63
  DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
64
- DEMO_ALPHA = 0.25
 
 
 
65
 
66
  def create_species_guide_content():
67
  """Create the species identification guide content"""
@@ -370,6 +373,15 @@ def create_probability_chart():
370
  return temp_fig
371
 
372
  def create_accuracy_chart():
 
 
 
 
 
 
 
 
 
373
  """Create confusion matrix estimates for each model side by side"""
374
  global coda_selector, iteration_count
375
 
@@ -438,52 +450,52 @@ def create_accuracy_chart():
438
  plt.close(fig)
439
  return temp_fig
440
 
441
- # OLD CODE - True Model Accuracies Bar Chart (kept for easy reversion)
442
- # def create_accuracy_chart():
443
- # """Create a bar chart showing true accuracy of each model"""
444
- # global oracle, dataset
445
- #
446
- # if oracle is None or dataset is None:
447
- # # Fallback for initial state
448
- # model_labels = [info['name'] for info in MODEL_INFO]
449
- # accuracies = np.random.random(len(MODEL_INFO)) # Random accuracies for now
450
- # else:
451
- # true_losses = oracle.true_losses(dataset.preds)
452
- # # Convert losses to accuracies (assuming loss is 1 - accuracy)
453
- # accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
454
- # model_labels = [" " + info['name'] for info in MODEL_INFO[:len(accuracies)]]
455
- #
456
- # # Find the index of the highest accuracy
457
- # best_idx = np.argmax(accuracies)
458
- #
459
- # fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
460
- #
461
- # # Create colors array - highlight the best model
462
- # colors = ['red' if i == best_idx else 'forestgreen' for i in range(len(model_labels))]
463
- # bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.7)
464
- #
465
- # # Add text above the highest bar
466
- # ax.text(best_idx, accuracies[best_idx] + 0.005, 'True best model',
467
- # ha='center', va='bottom', fontsize=12, fontweight='bold')
468
- #
469
- # ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
470
- # ax.set_title('True Model Accuracies', fontsize=12)
471
- # ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
472
- #
473
- # # Set x-axis labels and ticks
474
- # ax.set_xticks(range(len(model_labels)))
475
- # ax.set_xticklabels(model_labels, fontsize=12, ha='center')
476
- #
477
- # # Add logos to x-axis
478
- # for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
479
- # add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
480
- # plt.yticks(fontsize=12)
481
- # plt.tight_layout()
482
- #
483
- # # Save the figure and close it to prevent memory leaks
484
- # temp_fig = fig
485
- # plt.close(fig)
486
- # return temp_fig
487
 
488
  # Create the Gradio interface
489
  with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
 
61
  ]
62
 
63
  DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
64
+ DEMO_ALPHA = 0.9 # 0.25 # this is more fun if showing the confusion matrices
65
+
66
+ # Toggle between confusion matrix and accuracy chart
67
+ USE_CONFUSION_MATRIX = False # Set to True for confusion matrices, False for accuracy bars
68
 
69
  def create_species_guide_content():
70
  """Create the species identification guide content"""
 
373
  return temp_fig
374
 
375
  def create_accuracy_chart():
376
+ """Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle"""
377
+ global coda_selector, oracle, dataset, iteration_count
378
+
379
+ if USE_CONFUSION_MATRIX:
380
+ return create_confusion_matrix_chart()
381
+ else:
382
+ return create_accuracy_bar_chart()
383
+
384
+ def create_confusion_matrix_chart():
385
  """Create confusion matrix estimates for each model side by side"""
386
  global coda_selector, iteration_count
387
 
 
450
  plt.close(fig)
451
  return temp_fig
452
 
453
+ def create_accuracy_bar_chart():
454
+ """Create a bar chart showing true accuracy of each model (with muted colors)"""
455
+ global oracle, dataset
456
+
457
+ if oracle is None or dataset is None:
458
+ # Fallback for initial state
459
+ fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
460
+ ax.text(0.5, 0.5, 'Start demo to see model accuracies',
461
+ ha='center', va='center', fontsize=12)
462
+ ax.axis('off')
463
+ plt.tight_layout()
464
+ temp_fig = fig
465
+ plt.close(fig)
466
+ return temp_fig
467
+
468
+ true_losses = oracle.true_losses(dataset.preds)
469
+ # Convert losses to accuracies (assuming loss is 1 - accuracy)
470
+ accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
471
+ model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(accuracies)]]
472
+
473
+ # Find the index of the highest accuracy
474
+ best_idx = np.argmax(accuracies)
475
+
476
+ fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
477
+
478
+ # Use muted, uniform color for all bars (no highlighting)
479
+ bars = ax.bar(range(len(model_labels)), accuracies, color='#8B9DC3', alpha=0.6)
480
+
481
+ ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
482
+ ax.set_title('True Model Accuracies', fontsize=12)
483
+ ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
484
+
485
+ # Set x-axis labels and ticks
486
+ ax.set_xticks(range(len(model_labels)))
487
+ ax.set_xticklabels(model_labels, fontsize=12, ha='center')
488
+
489
+ # Add logos to x-axis
490
+ for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
491
+ add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
492
+ plt.yticks(fontsize=12)
493
+ plt.tight_layout()
494
+
495
+ # Save the figure and close it to prevent memory leaks
496
+ temp_fig = fig
497
+ plt.close(fig)
498
+ return temp_fig
499
 
500
  # Create the Gradio interface
501
  with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",