justinkay committed on
Commit
eaf554e
·
1 Parent(s): f6adf18

Confusion matrix plot

Browse files
Files changed (1) hide show
  1. app.py +109 -38
app.py CHANGED
@@ -59,6 +59,7 @@ MODEL_INFO = [
59
  ]
60
 
61
  DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
 
62
 
63
  # load image metadata
64
  images_data = []
@@ -305,51 +306,116 @@ def create_probability_chart():
305
  return temp_fig
306
 
307
  def create_accuracy_chart():
308
- """Create a bar chart showing true accuracy of each model"""
309
- global oracle, dataset
310
 
311
- if oracle is None or dataset is None:
312
- # Fallback for initial state
313
- model_labels = [info['name'] for info in MODEL_INFO]
314
- accuracies = np.random.random(len(MODEL_INFO)) # Random accuracies for now
315
- else:
316
- true_losses = oracle.true_losses(dataset.preds)
317
- # Convert losses to accuracies (assuming loss is 1 - accuracy)
318
- accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
319
- model_labels = [" " + info['name'] for info in MODEL_INFO[:len(accuracies)]]
320
-
321
- # Find the index of the highest accuracy
322
- best_idx = np.argmax(accuracies)
323
-
324
- fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
325
-
326
- # Create colors array - highlight the best model
327
- colors = ['red' if i == best_idx else 'forestgreen' for i in range(len(model_labels))]
328
- bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.7)
329
-
330
- # Add text above the highest bar
331
- ax.text(best_idx, accuracies[best_idx] + 0.005, 'True best model',
332
- ha='center', va='bottom', fontsize=12, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
- ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
335
- ax.set_title('True Model Accuracies', fontsize=12)
336
- ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
 
337
 
338
- # Set x-axis labels and ticks
339
- ax.set_xticks(range(len(model_labels)))
340
- ax.set_xticklabels(model_labels, fontsize=12, ha='center')
 
 
341
 
342
- # Add logos to x-axis
343
- for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
344
- add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
345
- plt.yticks(fontsize=12)
346
  plt.tight_layout()
347
 
348
- # Save the figure and close it to prevent memory leaks
349
  temp_fig = fig
350
  plt.close(fig)
351
  return temp_fig
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  # Create the Gradio interface
354
  with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
355
  theme=gr.themes.Base(),
@@ -659,9 +725,12 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
659
  A shaggy, dark brown antelope recognized by its white rump ring and backward-curving horns in males. Smaller and darker than the common eland, waterbuck prefer wet habitats and lack the eland's throat dewlap.
660
 
661
  ----
662
-
663
  """)
664
 
 
 
 
665
  with gr.Row():
666
  back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
667
  guide_button = gr.Button("View Species Classification Guide", variant="secondary", size="lg")
@@ -810,7 +879,8 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
810
  # Create oracle and CODA selector for this user
811
  oracle = Oracle(dataset, loss_fn=loss_fn)
812
  coda_selector = CODA(dataset,
813
- learning_rate=DEMO_LEARNING_RATE)
 
814
 
815
  image, status, predictions = get_next_coda_image()
816
  prob_plot = create_probability_chart()
@@ -849,7 +919,8 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
849
  # Create oracle and CODA selector for this user
850
  oracle = Oracle(dataset, loss_fn=loss_fn)
851
  coda_selector = CODA(dataset,
852
- learning_rate=DEMO_LEARNING_RATE)
 
853
 
854
  # Reset all displays
855
  prob_plot = create_probability_chart()
 
59
  ]
60
 
61
  DEMO_LEARNING_RATE = 0.05 # don't use default; use something more fun
62
+ DEMO_ALPHA = 0.25
63
 
64
  # load image metadata
65
  images_data = []
 
306
  return temp_fig
307
 
308
  def create_accuracy_chart():
309
+ """Create confusion matrix estimates for each model side by side"""
310
+ global coda_selector, iteration_count
311
 
312
+ if coda_selector is None:
313
+ # Fallback for initial state - return empty figure
314
+ fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
315
+ ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
316
+ ha='center', va='center', fontsize=12)
317
+ ax.axis('off')
318
+ plt.tight_layout()
319
+ temp_fig = fig
320
+ plt.close(fig)
321
+ return temp_fig
322
+
323
+ # Get confusion matrix estimates from CODA's Dirichlet distributions
324
+ dirichlets = coda_selector.dirichlets # Shape: [num_models, num_classes, num_classes]
325
+ num_models = dirichlets.shape[0]
326
+ num_classes = dirichlets.shape[1]
327
+
328
+ # Convert Dirichlet parameters to expected confusion matrices
329
+ # The expected value of a Dirichlet is alpha / sum(alpha)
330
+ confusion_matrices = []
331
+ for model_idx in range(num_models):
332
+ alpha = dirichlets[model_idx].detach().cpu().numpy()
333
+ # Normalize each row to get probabilities
334
+ conf_matrix = alpha / alpha.sum(axis=1, keepdims=True)
335
+ confusion_matrices.append(conf_matrix)
336
+
337
+ # Create subplots for each model
338
+ fig, axes = plt.subplots(1, num_models, figsize=(8, 2.8), dpi=150)
339
+ if num_models == 1:
340
+ axes = [axes]
341
+
342
+ for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
343
+ # Apply 16th-root scaling (four nested square roots) to make small values more visible
344
+ # This expands small values while still showing large values
345
+ sqrt_conf_matrix = np.sqrt(np.sqrt(np.sqrt(np.sqrt(conf_matrix))))
346
+
347
+ # Plot confusion matrix as heatmap with the root-scaled (x ** (1/16)) values
348
+ im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto')#, vmin=0, vmax=1)
349
+
350
+ # Add model name as title
351
+ model_info = MODEL_INFO[model_idx]
352
+ ax.set_title(f"{model_info['name']}", fontsize=10, pad=5)
353
 
354
+ # Set axis labels
355
+ if model_idx == 0:
356
+ ax.set_ylabel('True class', fontsize=9)
357
+ ax.set_xlabel('Predicted', fontsize=9)
358
 
359
+ # Set ticks
360
+ ax.set_xticks(range(num_classes))
361
+ ax.set_yticks(range(num_classes))
362
+ ax.set_xticklabels(range(num_classes), fontsize=8)
363
+ ax.set_yticklabels(range(num_classes), fontsize=8)
364
 
365
+ plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98)
 
 
 
366
  plt.tight_layout()
367
 
 
368
  temp_fig = fig
369
  plt.close(fig)
370
  return temp_fig
371
 
372
+ # OLD CODE - True Model Accuracies Bar Chart (kept for easy reversion)
373
+ # def create_accuracy_chart():
374
+ # """Create a bar chart showing true accuracy of each model"""
375
+ # global oracle, dataset
376
+ #
377
+ # if oracle is None or dataset is None:
378
+ # # Fallback for initial state
379
+ # model_labels = [info['name'] for info in MODEL_INFO]
380
+ # accuracies = np.random.random(len(MODEL_INFO)) # Random accuracies for now
381
+ # else:
382
+ # true_losses = oracle.true_losses(dataset.preds)
383
+ # # Convert losses to accuracies (assuming loss is 1 - accuracy)
384
+ # accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
385
+ # model_labels = [" " + info['name'] for info in MODEL_INFO[:len(accuracies)]]
386
+ #
387
+ # # Find the index of the highest accuracy
388
+ # best_idx = np.argmax(accuracies)
389
+ #
390
+ # fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
391
+ #
392
+ # # Create colors array - highlight the best model
393
+ # colors = ['red' if i == best_idx else 'forestgreen' for i in range(len(model_labels))]
394
+ # bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.7)
395
+ #
396
+ # # Add text above the highest bar
397
+ # ax.text(best_idx, accuracies[best_idx] + 0.005, 'True best model',
398
+ # ha='center', va='bottom', fontsize=12, fontweight='bold')
399
+ #
400
+ # ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
401
+ # ax.set_title('True Model Accuracies', fontsize=12)
402
+ # ax.set_ylim(np.min(accuracies) - 0.025, np.max(accuracies) + 0.05)
403
+ #
404
+ # # Set x-axis labels and ticks
405
+ # ax.set_xticks(range(len(model_labels)))
406
+ # ax.set_xticklabels(model_labels, fontsize=12, ha='center')
407
+ #
408
+ # # Add logos to x-axis
409
+ # for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
410
+ # add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
411
+ # plt.yticks(fontsize=12)
412
+ # plt.tight_layout()
413
+ #
414
+ # # Save the figure and close it to prevent memory leaks
415
+ # temp_fig = fig
416
+ # plt.close(fig)
417
+ # return temp_fig
418
+
419
  # Create the Gradio interface
420
  with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
421
  theme=gr.themes.Base(),
 
725
  A shaggy, dark brown antelope recognized by its white rump ring and backward-curving horns in males. Smaller and darker than the common eland, waterbuck prefer wet habitats and lack the eland's throat dewlap.
726
 
727
  ----
728
+
729
  """)
730
 
731
+ # Add spacing before buttons
732
+ gr.HTML("<div style='margin-top: 0.1em;'></div>")
733
+
734
  with gr.Row():
735
  back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
736
  guide_button = gr.Button("View Species Classification Guide", variant="secondary", size="lg")
 
879
  # Create oracle and CODA selector for this user
880
  oracle = Oracle(dataset, loss_fn=loss_fn)
881
  coda_selector = CODA(dataset,
882
+ learning_rate=DEMO_LEARNING_RATE,
883
+ alpha=DEMO_ALPHA)
884
 
885
  image, status, predictions = get_next_coda_image()
886
  prob_plot = create_probability_chart()
 
919
  # Create oracle and CODA selector for this user
920
  oracle = Oracle(dataset, loss_fn=loss_fn)
921
  coda_selector = CODA(dataset,
922
+ learning_rate=DEMO_LEARNING_RATE,
923
+ alpha=DEMO_ALPHA)
924
 
925
  # Reset all displays
926
  prob_plot = create_probability_chart()