import os
os.environ['GRADIO_TEMP_DIR'] = "tmp/"

import gradio as gr
import json
import random
from PIL import Image
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
import torch
import shutil
import argparse

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from coda import CODA
from coda.datasets import Dataset
from coda.options import LOSS_FNS
from coda.oracle import Oracle

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='Enable debug mode with delete button')
args_cli = parser.parse_args()
DEBUG_MODE = args_cli.debug

if DEBUG_MODE:
    print("Debug mode enabled - delete button will be available")
    # Create deleted_in_app directory if it doesn't exist
    os.makedirs('deleted_in_app', exist_ok=True)


with open('iwildcam_demo_annotations.json', 'r') as f:
    data = json.load(f)

SPECIES_MAP = OrderedDict([
    (24, "Jaguar"),           # panthera onca
    (10, "Ocelot"),           # leopardus pardalis
    (6, "Mountain Lion"),     # puma concolor
    (101, "Common Eland"),    # tragelaphus oryx
    (102, "Waterbuck"),       # kobus ellipsiprymnus
])
NAME_TO_ID = {name: id for id, name in SPECIES_MAP.items()}

# Class names in order (0-4) from classes.txt
CLASS_NAMES = ["Jaguar", "Ocelot", "Mountain Lion", "Common Eland", "Waterbuck"]
NAME_TO_CLASS_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}

# Model information from models.txt
MODEL_INFO = [
    {"org": "Facebook", "name": "PE-Core", "logo": "logos/meta.png"},
    {"org": "Google", "name": "SigLIP2", "logo": "logos/google.png"},
    {"org": "OpenAI", "name": "CLIPViT-L", "logo": "logos/openai.png"},
    {"org": "Imageomics", "name": "BioCLIP2", "logo": "logos/imageomics.png"},
    {"org": "LAION", "name": "LAION CLIP", "logo": "logos/laion.png"}
]

DEMO_LEARNING_RATE = 0.05  # don't use the default; use something more fun
DEMO_ALPHA = 0.9  # 0.25 is more fun if showing the confusion matrices

# Toggle between confusion matrix and accuracy chart
USE_CONFUSION_MATRIX = False  # Set to True for confusion matrices, False for accuracy bars

def create_species_guide_content():
    """Create the species identification guide content"""
    with gr.Column():
        gr.Markdown("""
        # Species Classification Guide

        ### Learn to identify the five wildlife species in this demo.

        ## Jaguar
        """)

        gr.Image("species_id/jaguar.jpg", label="Jaguar example image", show_label=False)

        gr.Markdown("""
        #### The largest cat in the Americas, with a stocky, muscular build and a broad head. Coat is patterned with rosettes that often have central spots inside.

        ----

        ## Ocelot

        """)

        gr.Image("species_id/ocelot.jpg", label="Ocelot example image", show_label=False)

        gr.Markdown("""
        #### Smaller and leaner than a jaguar, with more elongated markings and rounder ears.

        ----

        ## Mountain Lion
        """)

        gr.Image("species_id/mountainlion.jpg", label="Mountain lion example image", show_label=False)

        gr.Markdown("""
        #### Also called cougar or puma, this cat has a plain tawny or grayish coat without spots. Its long tail and uniformly colored fur distinguish it from jaguars and ocelots.

        ----

        ## Common Eland

        """)

        gr.Image("species_id/commoneland.jpg", label="Eland example image", show_label=False)

        gr.Markdown("""
        #### The largest antelope species. Both sexes carry spiraled horns. Lighter tan coat than a waterbuck.

        ----

        ## Waterbuck
        """)

        gr.Image("species_id/waterbuck.jpg", label="Waterbuck example image", show_label=False)

        gr.Markdown("""
        #### A shaggy, dark brown antelope. Identifiable by backward-curving horns in males, no horns on females. Larger, rounder ears and darker coat than the common eland.

        ----

        """)

# load image metadata
images_data = []
for annotation in tqdm(data['annotations'], desc='Loading annotations'):
    image_id = annotation['image_id']
    category_id = annotation['category_id']
    image_info = next((img for img in data['images'] if img['id'] == image_id), None)
    if image_info:
        images_data.append({
            'filename': image_info['file_name'],
            'species_id': category_id,
            'species_name': SPECIES_MAP[category_id]
        })
print(f"Loaded {len(images_data)} images for the quiz")

# Load image filenames list
with open('images.txt', 'r') as f:
    full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]

# Initialize full dataset (will be subsampled per-user)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load full dataset
full_preds = torch.load("iwildcam_demo.pt").to(device)
full_labels = torch.load("iwildcam_demo_labels.pt").to(device)

# Pre-compute class indices for subsampling
from collections import defaultdict
full_class_to_indices = defaultdict(list)
for idx, label in enumerate(full_labels):
    class_idx = label.item()
    full_class_to_indices[class_idx].append(idx)

# Find minimum class size
min_class_size = min(len(indices) for indices in full_class_to_indices.values())
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")

# Loss function for oracle
loss_fn = LOSS_FNS['acc']

# Global state (will be set per-user in start_demo)
current_image_info = None
coda_selector = None
oracle = None
dataset = None
image_filenames = None
iteration_count = 0
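
# For reference, a minimal headless sketch of the active-selection loop that the Gradio
# callbacks below drive interactively (a sketch only: it assumes coda_selector and dataset
# have been set up as in start_demo(), and uses the stored label in place of a human answer):
#
#     for _ in range(10):
#         idx, prob = coda_selector.get_next_item_to_label()
#         coda_selector.add_label(idx, int(dataset.labels[idx]), prob)
#     best_model_idx = int(coda_selector.get_pbest().argmax())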

def get_model_predictions(chosen_idx):
    """Get model predictions and scores for a specific image"""
    global dataset

    if dataset is None or chosen_idx >= dataset.preds.shape[1]:
        return "No predictions available"

    # Get predictions for this image (shape: [num_models, num_classes])
    image_preds = dataset.preds[:, chosen_idx, :].detach().cpu().numpy()

    predictions_list = []

    for model_idx in range(image_preds.shape[0]):
        model_scores = image_preds[model_idx]
        predicted_class_idx = model_scores.argmax()
        predicted_class_name = CLASS_NAMES[predicted_class_idx]
        confidence = model_scores[predicted_class_idx]

        model_info = MODEL_INFO[model_idx]
        predictions_list.append(f"**{model_info['name']}:** {predicted_class_name} *({confidence:.3f})*")

    predictions_text = "### Model Predictions\n\n" + " | ".join(predictions_list)

    return predictions_text

def add_logo_to_x_axis(ax, x_pos, logo_path, model_name, height_px=35):
    """Add a logo image to x-axis next to model name"""
    try:
        img = mpimg.imread(logo_path)
        # Calculate zoom to achieve desired height in pixels
        # Rough conversion: height_px / min(image height, image width) / dpi * 72
        zoom = height_px / min(img.shape[0], img.shape[1]) / ax.figure.dpi * 72
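        # e.g. with height_px=35, a 100x100-px logo, and dpi=150: 35 / 100 / 150 * 72 ≈ 0.17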
        imagebox = OffsetImage(img, zoom=zoom)

        # Position logo to the left of the x-tick
        logo_offset = -0.28  # Adjust this to move logo left/right relative to tick
        y_offset = -0.08
        ab = AnnotationBbox(imagebox, (x_pos + logo_offset, y_offset),
                           xycoords=('data', 'axes fraction'), frameon=False)
        ax.add_artist(ab)
    except Exception as e:
        print(f"Could not load logo {logo_path}: {e}")

def get_next_coda_image():
    """Get the next image that CODA wants labeled"""
    global current_image_info, coda_selector, iteration_count

    # Get next item from CODA
    chosen_idx, selection_prob = coda_selector.get_next_item_to_label()
    print("CODA chosen_idx, selection prob:", chosen_idx, selection_prob)

    # Get the corresponding image filename
    if chosen_idx < len(image_filenames):
        filename = image_filenames[chosen_idx]
        image_path = os.path.join('iwildcam_demo_images', filename)
        print("Next image is", filename)

        # Find the corresponding annotation for this image
        current_image_info = None
        for annotation in data['annotations']:
            image_id = annotation['image_id']
            image_info = next((img for img in data['images'] if img['id'] == image_id), None)
            if image_info and image_info['file_name'] == filename:
                current_image_info = {
                    'filename': filename,
                    'species_id': annotation['category_id'],
                    'species_name': SPECIES_MAP[annotation['category_id']],
                    'chosen_idx': chosen_idx,
                    'selection_prob': selection_prob
                }
                break

        try:
            image = Image.open(image_path)
            predictions = get_model_predictions(chosen_idx)
            return image, f"Iteration {iteration_count}: CODA selected this image for labeling", predictions
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None, f"Error loading image: {e}", "No predictions available"
    else:
        return None, "Image index out of range", "No predictions available"

def delete_current_image():
    """Delete the current image by moving it to deleted_in_app directory"""
    global current_image_info, coda_selector

    if current_image_info is None:
        return "No image to delete!", None, "No predictions", None, None, ""

    filename = current_image_info['filename']
    chosen_idx = current_image_info['chosen_idx']
    source_path = os.path.join('iwildcam_demo_images', filename)
    dest_path = os.path.join('deleted_in_app', filename)

    try:
        shutil.move(source_path, dest_path)
        result = f"βœ“ Moved {filename} to deleted_in_app/"
        print(f"Deleted image: {filename}")

        # Remove from CODA's unlabeled indices without adding a label
        if chosen_idx in coda_selector.unlabeled_idxs:
            coda_selector.unlabeled_idxs.remove(chosen_idx)
    except Exception as e:
        result = f"Error deleting image: {e}"
        print(f"Error deleting {filename}: {e}")

    # Load next image
    next_image, status, predictions = get_next_coda_image()
    status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'

    # Get updated plots
    prob_plot = create_probability_chart()
    accuracy_plot = create_accuracy_chart()

    return result, next_image, predictions, prob_plot, accuracy_plot, status_html

def check_answer(user_choice):
    """Process user's label and update CODA"""
    global current_image_info, coda_selector, iteration_count

    if current_image_info is None:
        return "Please load an image first!", "", None, "No predictions", None, None

    correct_species = current_image_info['species_name']
    chosen_idx = current_image_info['chosen_idx']
    selection_prob = current_image_info['selection_prob']

    # Convert user choice to class index (0-4)
    if user_choice == "I don't know":
        # For "I don't know", just remove from sampling without providing label
        coda_selector.unlabeled_idxs.remove(chosen_idx)
        result = f"The last image was skipped and will not be used for model selection. The correct species was {correct_species}. "
    else:
        user_class_idx = NAME_TO_CLASS_IDX.get(user_choice, NAME_TO_CLASS_IDX[correct_species])
        if user_choice == correct_species:
            result = f"πŸŽ‰ Your last classification was correct! It was indeed a {correct_species}."
        else:
            result = f"❌ Your last classification was incorrect. It was a {correct_species}, not a {user_choice}. This may mislead the model selection process!"

        # Update CODA with the label
        coda_selector.add_label(chosen_idx, user_class_idx, selection_prob)

    iteration_count += 1

    # Get updated plots
    prob_plot = create_probability_chart()
    accuracy_plot = create_accuracy_chart()

    # Load next image
    next_image, status, predictions = get_next_coda_image()
    # Create HTML with inline help button for status
    status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
    return result, status_html, next_image, predictions, prob_plot, accuracy_plot

def create_probability_chart():
    """Create a bar chart showing probability each model is best"""
    global coda_selector

    if coda_selector is None:
        # Fallback for initial state
        model_labels = [info['name'] for info in MODEL_INFO]
        probabilities = np.ones(len(MODEL_INFO)) / len(MODEL_INFO)  # Uniform prior
    else:
        probs_tensor = coda_selector.get_pbest()
        probabilities = probs_tensor.detach().cpu().numpy().flatten()
        # Pad model names with spaces so the logo drawn to the left of each tick doesn't overlap the text
        model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(probabilities)]]

    # Find the index of the highest probability
    best_idx = np.argmax(probabilities)

    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)

    # Create colors array - highlight the best model
    colors = ['orange' if i == best_idx else 'steelblue' for i in range(len(model_labels))]
    bars = ax.bar(range(len(model_labels)), probabilities, color=colors, alpha=0.7)

    # Add text above the highest bar
    ax.text(best_idx, probabilities[best_idx] + 0.0025, 'Current best guess',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_ylabel('Probability model is best', fontsize=12)
    ax.set_title(f'CODA Model Selection Probabilities (Iteration {iteration_count})', fontsize=12)
    ax.set_ylim(np.min(probabilities) - 0.01, np.max(probabilities) + 0.02)

    # Set x-axis labels and ticks
    ax.set_xticks(range(len(model_labels)))
    ax.set_xticklabels(model_labels, fontsize=12, ha='center')

    # Add logos to x-axis
    for i, model_info in enumerate(MODEL_INFO[:len(probabilities)]):
        add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
    plt.yticks(fontsize=12)
    plt.tight_layout()

    # Close the figure to prevent memory leaks; Gradio renders the returned figure object
    temp_fig = fig
    plt.close(fig)
    return temp_fig

def create_accuracy_chart():
    """Create either confusion matrices or accuracy bar chart based on USE_CONFUSION_MATRIX toggle"""
    global coda_selector, oracle, dataset, iteration_count

    if USE_CONFUSION_MATRIX:
        return create_confusion_matrix_chart()
    else:
        return create_accuracy_bar_chart()

def create_confusion_matrix_chart():
    """Create confusion matrix estimates for each model side by side"""
    global coda_selector, iteration_count

    if coda_selector is None:
        # Fallback for initial state - return empty figure
        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
        ax.text(0.5, 0.5, 'Start demo to see confusion matrices',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        plt.tight_layout()
        temp_fig = fig
        plt.close(fig)
        return temp_fig

    # Get confusion matrix estimates from CODA's Dirichlet distributions
    dirichlets = coda_selector.dirichlets  # Shape: [num_models, num_classes, num_classes]
    num_models = dirichlets.shape[0]
    num_classes = dirichlets.shape[1]

    # Convert Dirichlet parameters to expected confusion matrices
    # The expected value of a Dirichlet is alpha / sum(alpha)
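    # e.g. a row with alpha = [8, 1, 1, 1, 1] gives expected row probabilities
    # [0.67, 0.08, 0.08, 0.08, 0.08] after the per-row normalization below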
    confusion_matrices = []
    for model_idx in range(num_models):
        alpha = dirichlets[model_idx].detach().cpu().numpy()
        # Normalize each row to get probabilities
        conf_matrix = alpha / alpha.sum(axis=1, keepdims=True)
        confusion_matrices.append(conf_matrix)

    # Create subplots for each model
    # Adjust width based on number of models (2.4 inches per model works well)
    fig_width = num_models * 2.4
    fig, axes = plt.subplots(1, num_models, figsize=(fig_width, 2.8), dpi=150)
    if num_models == 1:
        axes = [axes]

    # Species abbreviations for axis labels
    species_labels = ['Jag', 'Oce', 'M.L.', 'C.E.', 'Wat']

    for model_idx, (ax, conf_matrix) in enumerate(zip(axes, confusion_matrices)):
        # Compress the dynamic range (sqrt applied four times, i.e. a 16th root)
        # so that small off-diagonal values stay visible next to the large diagonal ones
        sqrt_conf_matrix = np.sqrt(np.sqrt(np.sqrt(np.sqrt(conf_matrix))))

        # Plot confusion matrix as heatmap with sqrt-scaled values
        im = ax.imshow(sqrt_conf_matrix, cmap='Blues', aspect='auto')#, vmin=0, vmax=1)

        # Add model name as title
        model_info = MODEL_INFO[model_idx]
        ax.set_title(f"{model_info['name']}", fontsize=10, pad=5)

        # Set axis labels
        if model_idx == 0:
            ax.set_ylabel('True class', fontsize=9)
        ax.set_xlabel('Predicted', fontsize=9)

        # Set ticks with species abbreviations
        ax.set_xticks(range(num_classes))
        ax.set_yticks(range(num_classes))
        ax.set_xticklabels(species_labels[:num_classes], fontsize=8)
        ax.set_yticklabels(species_labels[:num_classes], fontsize=8)

    plt.suptitle(f"CODA's Confusion Matrix Estimates (Iteration {iteration_count})", fontsize=12, y=0.98)
    plt.tight_layout()

    temp_fig = fig
    plt.close(fig)
    return temp_fig

def create_accuracy_bar_chart():
    """Create a bar chart showing true accuracy of each model (with muted colors)"""
    global oracle, dataset

    if oracle is None or dataset is None:
        # Fallback for initial state
        fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)
        ax.text(0.5, 0.5, 'Start demo to see model accuracies',
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        plt.tight_layout()
        temp_fig = fig
        plt.close(fig)
        return temp_fig

    true_losses = oracle.true_losses(dataset.preds)
    # Convert losses to accuracies (assuming loss is 1 - accuracy)
    accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
    # Pad model names with spaces so the logo drawn to the left of each tick doesn't overlap the text
    model_labels = [" "*(9 if info['name']=='LAION CLIP' else 4 if info['name']=='SigLIP2' else 6) + info['name'] for info in MODEL_INFO[:len(accuracies)]]

    # Find the index of the highest accuracy
    best_idx = np.argmax(accuracies)

    fig, ax = plt.subplots(figsize=(8, 2.8), dpi=150)

    # Create colors array - highlight the best model with dark reddish orange, others soft pink
    colors = ['#F8481C' if i == best_idx else '#F8BBD0' for i in range(len(model_labels))]
    bars = ax.bar(range(len(model_labels)), accuracies, color=colors, alpha=0.85)

    # Add text above the highest bar
    ax.text(best_idx, accuracies[best_idx] + 0.0025, 'True best model',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_ylabel('True (oracle) \naccuracy of model', fontsize=12)
    ax.set_title('True Model Accuracies', fontsize=12)
    y_min = np.min(accuracies) - 0.025
    y_max = np.max(accuracies) + 0.05
    ax.set_ylim(y_min, y_max)

    # Add accuracy values in the middle of the visible portion of each bar
    for i, (bar, acc) in enumerate(zip(bars, accuracies)):
        # Position text in the middle of the visible part of the bar
        text_y = (y_min + acc) / 2
        # Use black text for all bars
        text_color = '#000000'
        ax.text(i, text_y, f'{acc:.3f}',
                ha='center', va='center', fontsize=10, fontweight='bold', color=text_color)

    # Set x-axis labels and ticks
    ax.set_xticks(range(len(model_labels)))
    ax.set_xticklabels(model_labels, fontsize=12, ha='center')

    # Add logos to x-axis
    for i, model_info in enumerate(MODEL_INFO[:len(accuracies)]):
        add_logo_to_x_axis(ax, i, model_info['logo'], model_info['name'])
    plt.yticks(fontsize=12)
    plt.tight_layout()

    # Close the figure to prevent memory leaks; Gradio renders the returned figure object
    temp_fig = fig
    plt.close(fig)
    return temp_fig

# Create the Gradio interface
with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge", 
               theme=gr.themes.Base(),
               css="""
               .subtle-outline {
                   border: 1px solid var(--border-color-primary) !important;
                   background: var(--background-fill-secondary) !important;
                   border-radius: var(--radius-lg);
                   padding: 1rem;
               }
               .subtle-outline .flex {
                   background-color: var(--background-fill-secondary) !important;
               }

               /* Light blue background for model predictions panel */
               .model-predictions-panel {
                   border: 1px solid #6B8CBF !important;
                   background: #D6E4F5 !important;
                   border-radius: var(--radius-lg);
                   padding: 0.3rem !important;
                   margin: 0.2rem 0 !important;
               }
               .model-predictions-panel .flex {
                   background-color: #D6E4F5 !important;
                   padding: 0 !important;
                   margin: 0 !important;
               }
               .model-predictions-panel * {
                   color: #1a1a1a !important;
               }

               /* Popup overlay styles */
               .popup-overlay {
                   position: fixed;
                   top: 0;
                   left: 0;
                   width: 100%;
                   height: 100%;
                   background-color: rgba(0, 0, 0, 0.5);
                   z-index: 1000;
                   display: flex;
                   justify-content: center;
                   align-items: center;
               }

               .popup-overlay > div {
                   background: transparent !important;
                   border: none !important;
                   padding: 0 !important;
                   margin: 0 !important;
               }

               .popup-content {
                   background: var(--background-fill-primary) !important;
                   padding: 2rem !important;
                   border-radius: 1rem !important;
                   max-width: 850px;
                   width: 90%;
                   max-height: 80vh;
                   overflow-y: auto;
                   box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
                   border: none !important;
                   margin: 0 !important;
                   color: var(--body-text-color) !important;
               }

               .popup-content > div {
                   background: var(--background-fill-primary) !important;
                   border: none !important;
                   padding: 0 !important;
                   margin: 0 !important;
                   overflow-y: visible !important;
                   max-height: none !important;
               }

               .popup-content h1,
               .popup-content h2,
               .popup-content h3,
               .popup-content p,
               .popup-content li {
                   color: var(--body-text-color) !important;
               }

               /* Ensure gradio column components don't interfere with scrolling */
               .popup-content .gradio-column {
                   overflow-y: visible !important;
                   max-height: none !important;
               }

               /* Ensure images in popup are responsive */
               .popup-content img {
                   max-width: 100% !important;
                   height: auto !important;
               }

               /* Center title */
               .text-center {
                   text-align: center !important;
               }

               /* Right align text */
               .text-right {
                   text-align: right !important;
               }

               /* Subtitle styling */
               .subtitle {
                   text-align: center !important;
                   font-weight: 300 !important;
                   color: #666 !important;
                   margin-top: -0.5rem !important;
               }

               /* Question mark icon styling */
               .panel-container {
                   position: relative;
               }

               .help-icon {
                   position: absolute;
                   top: 5px;
                   right: 5px;
                   width: 25px;
                   height: 25px;
                   background-color: #f8f9fa;
                   color: #6c757d;
                   border: 1px solid #dee2e6;
                   border-radius: 50%;
                   display: flex;
                   align-items: center;
                   justify-content: center;
                   cursor: pointer;
                   font-size: 13px;
                   font-weight: 600;
                   z-index: 10;
                   transition: all 0.2s ease;
                   box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
               }

               .help-icon:hover {
                   background-color: #e9ecef;
                   color: #495057;
                   border-color: #adb5bd;
                   box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
               }

               /* Help popup styles */
               .help-popup-overlay {
                   position: fixed;
                   top: 0;
                   left: 0;
                   width: 100%;
                   height: 100%;
                   background-color: rgba(0, 0, 0, 0.5);
                   z-index: 1001;
                   display: flex;
                   justify-content: center;
                   align-items: center;
               }

               .help-popup-overlay > div {
                   background: transparent !important;
                   border: none !important;
                   padding: 0 !important;
                   margin: 0 !important;
               }

               .help-popup-content {
                   background: var(--background-fill-primary) !important;
                   padding: 1.5rem !important;
                   border-radius: 0.5rem !important;
                   max-width: 600px;
                   width: 90%;
                   box-shadow: 0 10px 25px rgba(0, 0, 0, 0.3);
                   border: none !important;
                   margin: 0 !important;
                   color: var(--body-text-color) !important;
               }

               .help-popup-content > div {
                   background: var(--background-fill-primary) !important;
                   border: none !important;
                   padding: 0 !important;
                   margin: 0 !important;
               }

               .help-popup-content h1,
               .help-popup-content h2,
               .help-popup-content h3,
               .help-popup-content p,
               .help-popup-content li {
                   color: var(--body-text-color) !important;
               }

               /* Inline help button */
               .inline-help-btn {
                   display: inline-block;
                   width: 20px;
                   height: 20px;
                   background-color: #f8f9fa;
                   color: #6c757d;
                   border: 1px solid #dee2e6;
                   border-radius: 50%;
                   text-align: center;
                   line-height: 18px;
                   cursor: pointer;
                   font-size: 11px;
                   font-weight: 600;
                   margin-left: 8px;
                   vertical-align: middle;
                   transition: all 0.2s ease;
                   box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
               }

               .inline-help-btn:hover {
                   background-color: #e9ecef;
                   color: #495057;
                   border-color: #adb5bd;
                   box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
               }

               #hidden-selection-help-btn {
                   display: none;
               }

               /* Reduce spacing around status text */
               .status-text {
                   margin: 0 !important;
                   padding: 0 !important;
               }

               .status-text > div {
                   margin: 0 !important;
                   padding: 0 !important;
               }

               /* Compact model predictions panel */
               .compact-predictions {
                   line-height: 1.1 !important;
                   margin: 0 !important;
                   padding: 0.1rem !important;
               }

               .compact-predictions p {
                   margin: 0.05rem 0 !important;
               }

               .compact-predictions h3 {
                   margin: 0 0 0.1rem 0 !important;
               }

               /* Target the subtle-outline group that contains predictions */
               .subtle-outline {
                   padding: 0.3rem !important;
                   margin: 0.2rem 0 !important;
               }

               /* Target the column inside the outline */
               .subtle-outline .flex {
                   padding: 0 !important;
                   margin: 0 !important;
               }

               /* Ensure text in predictions panel is visible in dark mode */
               .subtle-outline * {
                   color: var(--body-text-color) !important;
               }
               """,
               delete_cache=(3600, 3600) # once per hour - clear stale cached files
               ) as demo:
    # Main page title
    gr.Markdown("# CODA: Consensus-Driven Active Model Selection", elem_classes="text-center")
    gr.Markdown("*Figure out which model is best by actively annotating data. See <a href='https://www.arxiv.org/abs/2507.23771'>the paper</a> for more details!*", elem_classes="text-center")

    # Add buttons row
    with gr.Row():
        view_guide_button = gr.Button("πŸ“– View Species Guide", variant="secondary", size="lg")
        start_over_button = gr.Button("Start Over", variant="secondary", size="lg")

    # Popup component
    with gr.Group(visible=True, elem_classes="popup-overlay") as popup_overlay:
        with gr.Group(elem_classes="popup-content"):
            # Main intro content
            intro_content = gr.Markdown("""
            # CODA: Consensus-Driven Active Model Selection

            ## Wildlife Photo Classification Challenge

            You are a wildlife ecologist who has just collected a season's worth of imagery from cameras
            deployed in Africa and Central and South America. You want to know what species occur in this imagery,
            and you hope to use a pre-trained classifier to give you answers quickly.
            But which one should you use?

            Instead of labeling a large validation set, our new method, **CODA**, enables you to perform **active model selection**.
            That is, CODA uses predictions from candidate models to guide the labeling process, querying you (a species identification expert)
            for labels on a select few images that will most efficiently differentiate between your candidate machine learning models.

            This demo lets you try CODA yourself! First, **become a species identification expert by reading our classification guide**
            so that you will be equipped to provide ground truth labels. Then, watch as CODA narrows down the best model over time
            as you provide labels for the query images. You will see that, with your input, CODA can identify the best candidate model
            with as few as ten (correctly) labeled images.
                 
            """)

            # Species guide content (initially hidden)
            with gr.Column(visible=False) as species_guide_content:
                create_species_guide_content()

            # Add spacing before buttons
            gr.HTML("<div style='margin-top: 0.1em;'></div>")

            with gr.Row():
                back_button = gr.Button("← Back to Intro", variant="secondary", size="lg", visible=False)
                guide_button = gr.Button("View Species Classification Guide", variant="primary", size="lg")
                popup_start_button = gr.Button("Start Demo", variant="secondary", size="lg")

    # Help popups for panels
    with gr.Group(visible=False, elem_classes="help-popup-overlay") as prob_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
            ## CODA Model Selection Probabilities

            This chart shows CODA's current confidence in each candidate classifier being the best performer.

            **How to read this chart:**
            - Each bar represents one of the candidate machine learning classifiers
            - The height of each bar shows the probability (0-100%) that this model is the best, according to CODA
            - The orange bar indicates CODA's current best guess
            - As you provide more labels, CODA updates these probabilities

            **What you'll see:**
            - CODA initializes these probabilities based on each classifier's agreement with the consensus votes of *all* classifiers, 
                        providing informative priors
            - As you label images, some models will gain confidence while others lose it
            - The goal is for one model to clearly emerge as the winner
                        
            More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!
            
            **What models are these?**
                        
            For this demo, we selected 5 zero-shot classifiers that would be reasonable choices for someone who 
                        wanted to classify wildlife imagery. The models are: facebook/PE-Core-L14-336, 
                        google/siglip2-so400m-patch16-naflex, openai/clip-vit-large-patch14, imageomics/bioclip-2, and
                        laion/CLIP-ViT-L-14-laion2B-s32B-b82K. Our goal is not to make any general claims about the performance
                        of these models but rather to provide a realistic set of candidates for demonstrating CODA.
                        
            """)
            gr.HTML("<div style='margin-top: 0.1em;'></div>")
            prob_help_close = gr.Button("Close", variant="secondary")

    with gr.Group(visible=False, elem_classes="help-popup-overlay") as acc_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
            ## True Model Accuracies

            This chart shows the actual performance of each model on the complete dataset (only possible with oracle knowledge).

            **How to read this chart:**
            - Each bar represents the true accuracy of one model
            - The red bar shows the actual best-performing model
            - This information is hidden from CODA during the selection process
            - You can compare this with CODA's estimates to see how well it's doing

            **Why this matters:**
            - This represents the "ground truth" that CODA is trying to discover
            - In real scenarios, you wouldn't know these true accuracies beforehand
            - The demo shows these to illustrate how CODA's estimates align with reality

            """)
            acc_help_close = gr.Button("Close", variant="secondary")

    with gr.Group(visible=False, elem_classes="help-popup-overlay") as selection_help_popup:
        with gr.Group(elem_classes="help-popup-content"):
            gr.Markdown("""
            ## How CODA selects images for labeling

            CODA selects images that best differentiate top-performing classifiers from each other. It
                        does this by constructing a probabilistic model of which classifier is best (see
                        the plot at the bottom-left). Each iteration, CODA selects an image to be labeled
                        based on how much a label for that image is expected to affect the probabilistic model. 

            Intuitively, CODA will select images where the top classifiers disagree, since knowing the ground truth for these images will provide 
                        the most information about which classifier is best overall.
                        
            More details can be found in [the paper](https://www.arxiv.org/abs/2507.23771)!
                        
            **What data is this?**
                        
            We selected a subset of 5 species from the iWildCam dataset and subsampled ~500 images for this demo.
                        Each refresh will generate a slightly different subset, leading to slightly different model selection 
                        performance.

            """)
            gr.HTML("<div style='margin-top: 0.1em;'></div>")

            selection_help_close = gr.Button("Close", variant="secondary")

    # Species guide popup during demo
    with gr.Group(visible=False, elem_classes="popup-overlay") as species_guide_popup:
        with gr.Group(elem_classes="popup-content"):
            create_species_guide_content()

            # Add spacing before button
            gr.HTML("<div style='margin-top: 0.1em;'></div>")

            species_guide_close = gr.Button("Go back to demo", variant="primary", size="lg")

    # Status display with help button and result on same row
    selection_help_button = gr.Button("", visible=False, elem_id="hidden-selection-help-btn")

    with gr.Row():
        with gr.Column(scale=3):
            status_with_help = gr.HTML("", visible=True, elem_classes="status-text")
        with gr.Column(scale=2):
            result_display = gr.Markdown("", visible=True, elem_classes="text-right")

    with gr.Row():
        image_display = gr.Image(
            label="Identify this animal:",
            value=None,
            height=400,
            width=550,
            elem_id="main-image-display"
        )

    gr.Markdown("### What species is this?")

    with gr.Row():
        # Create buttons for each species
        species_buttons = []
        for species_name in SPECIES_MAP.values():
            btn = gr.Button(species_name, variant="primary", size="lg")
            species_buttons.append(btn)

        # Add "I don't know" button
        idk_button = gr.Button("I don't know", variant="primary", size="lg")

    # Model predictions panel (full width, single line)
    with gr.Group(elem_classes="model-predictions-panel"):
        with gr.Column(elem_classes="flex items-center justify-center h-full"):
            model_predictions_display = gr.Markdown(
                "### Model Predictions\n\n*Start the demo to see model votes!*",
                show_label=False,
                elem_classes="text-center compact-predictions"
            )

    # Two panels with bar charts
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-container"):
                prob_help_button = gr.Button("?", elem_classes="help-icon", size="sm")
                prob_plot = gr.Plot(
                    value=None,
                    show_label=False
                )
        with gr.Column(scale=1):
            # with gr.Group(elem_classes="panel-container"):
                # acc_help_button = gr.Button("?", elem_classes="help-icon", size="sm")

                # with gr.Row(elem_classes="flex-grow") as accuracy_title_row:
                #     gr.Markdown("""
                #     ## True Model Accuracies
                #
                #     Click below to view the true model accuracies.
                #
                #     Note you wouldn't be able to do this in the real model selection setting!
                #
                #     """,
                #     elem_classes="text-center compact-predictions flex-grow")

                # # Centered reveal button (initially visible)
                # with gr.Group(visible=True) as accuracy_hidden_group:
                #     with gr.Column(elem_classes="flex items-center justify-center h-full"):
                #         reveal_accuracy_button = gr.Button(
                #             "Reveal model accuracies",
                #             variant="primary"
                #         )

                # # Accuracy plot (initially hidden)
                # accuracy_plot = gr.Plot(
                #     value=create_accuracy_chart(),
                #     show_label=False,
                #     visible=False
                # )
            accuracy_plot = gr.Plot(
                value=create_accuracy_chart(),
                show_label=False,
                visible=False
            )
            with gr.Group(visible=True, elem_classes="subtle-outline accuracy-hidden-panel") as hidden_group:
                with gr.Column(elem_classes="flex items-center justify-center h-full"):

                    # example of how to add spacing:
                    # gr.HTML("<div style='margin-top: 2.9em;'></div>")

                    hidden_text0 = gr.Markdown("""
                        ## True model performance is hidden
                        """,
                        elem_classes="text-center",)

                    gr.HTML("<div style='margin-top: 0.25em;'></div>")

                    hidden_text1 = gr.Markdown("""
                        In this problem setting the true model performance is assumed to be unknown (that is why we want to perform model selection!)

                        However, for this demo, we have computed the actual accuracies of each model in order to evaluate CODA's performance.

                        """,
                        elem_classes="text-center",
                    )
                    gr.HTML("<div style='margin-top: 0.25em;'></div>")

                    # with gr.Row():
                    #     with gr.Column(scale=2):
                    #         pass
                    #     with gr.Column(scale=1, min_width=100):
                    #         reveal_accuracy_button = gr.Button(
                    #             "πŸ” Reveal",
                    #             variant="secondary",
                    #             size="lg"
                    #         )
                    #     with gr.Column(scale=2):
                    #         pass
                    with gr.Row():
                        reveal_accuracy_button = gr.Button(
                                "πŸ” Reveal True Model Accuracies",
                                variant="secondary",
                                size="lg"
                            )

                    # example of how to add spacing:
                    # gr.HTML("<div style='margin-top: 2.9em;'></div>")

    # Add debug delete button (only visible in debug mode)
    if DEBUG_MODE:
        delete_button = gr.Button("πŸ—‘οΈ Delete Current Image", variant="stop", size="lg")

    # Set up button interactions
    def start_demo():
        global iteration_count, coda_selector, dataset, oracle, image_filenames

        # Reset the demo state
        iteration_count = 0

        # Keep resampling until we get a subset where the initial best model (by CODA) is NOT the true best model
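        # (so the demo starts from a "wrong" initial guess and the user can watch CODA revise it)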
        while True:
            # Subsample dataset for this user
            subsampled_indices = []
            for class_idx in sorted(full_class_to_indices.keys()):
                indices = full_class_to_indices[class_idx]
                sampled = np.random.choice(indices, size=min_class_size, replace=False)
                subsampled_indices.extend(sampled.tolist())

            # Sort indices to maintain order
            subsampled_indices.sort()

            # Create subsampled dataset for this user
            subsampled_preds = full_preds[:, subsampled_indices, :]
            subsampled_labels = full_labels[subsampled_indices]
            image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]

            # Create Dataset object with subsampled data
            dataset = Dataset.__new__(Dataset)
            dataset.preds = subsampled_preds
            dataset.labels = subsampled_labels
            dataset.device = device

            # Create oracle and CODA selector for this user
            oracle = Oracle(dataset, loss_fn=loss_fn)
            coda_selector = CODA(dataset,
                                 learning_rate=DEMO_LEARNING_RATE,
                                 alpha=DEMO_ALPHA)

            # Check which model is initially best according to CODA
            probs_tensor = coda_selector.get_pbest()
            probabilities = probs_tensor.detach().cpu().numpy().flatten()
            coda_best_idx = np.argmax(probabilities)

            # Get true best model according to oracle
            true_losses = oracle.true_losses(dataset.preds)
            true_accuracies = (1 - true_losses).detach().cpu().numpy().flatten()
            true_best_idx = np.argmax(true_accuracies)

            # Accept this subset if CODA's initial best is NOT the true best
            if coda_best_idx != true_best_idx:
                break
            # Otherwise, loop and resample

        image, status, predictions = get_next_coda_image()
        prob_plot = create_probability_chart()
        acc_plot = create_accuracy_chart()
        # Create HTML with inline help button
        status_html = f'{status} <span class="inline-help-btn" title="What is this?">?</span>'
        return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)

    def start_over():
        global iteration_count

        # Reset the demo state and draw a new random subsample for this user
        iteration_count = 0
        resample_user_dataset()

        # Reset all displays
        prob_plot = create_probability_chart()
        acc_plot = create_accuracy_chart()
        return None, "Demo reset. Click 'Start CODA Demo' to begin.", "### Model Predictions\n\n*Start the demo to see model votes!*", prob_plot, acc_plot, "", gr.update(visible=True), gr.update(visible=False)

    def show_species_guide():
        # Show species guide, hide intro content, show back button, hide guide button
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)

    def show_intro():
        # Show intro content, hide species guide, hide back button, show guide button
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    def show_prob_help():
        return gr.update(visible=True)

    def hide_prob_help():
        return gr.update(visible=False)

    def show_acc_help():
        return gr.update(visible=True)

    def hide_acc_help():
        return gr.update(visible=False)

    def show_selection_help():
        return gr.update(visible=True)

    def hide_selection_help():
        return gr.update(visible=False)

    def show_species_guide_popup():
        return gr.update(visible=True)

    def hide_species_guide_popup():
        return gr.update(visible=False)

    def reveal_accuracies():
        """Reveal accuracy plot and hide the hidden group"""
        return gr.update(visible=True), gr.update(visible=False)

    # Shared client-side snippet for the start and start-over handlers: after the
    # layout updates, repeatedly try (for ~3 seconds) to stretch the hidden accuracy
    # panel to the height of the probability-chart panel so the two columns stay aligned.
    match_panel_heights_js = """
    () => {
        console.log('=== Panel Height Matching (Dynamic) ===');

        function matchPanelHeights() {
            const panels = document.querySelectorAll('.panel-container');
            console.log('Found .panel-container elements:', panels.length);
            const leftPanel = panels[0]; // prob_plot panel
            const rightPanel = document.querySelector('.accuracy-hidden-panel'); // hidden_group panel

            console.log('Left panel (prob):', leftPanel);
            console.log('Right panel (hidden):', rightPanel);

            if (leftPanel && rightPanel) {
                const leftHeight = leftPanel.offsetHeight;
                const rightHeight = rightPanel.offsetHeight;
                const diff = leftHeight - rightHeight;

                console.log('Left panel height:', leftHeight);
                console.log('Right panel height:', rightHeight);
                console.log('Height difference:', diff);

                if (diff > 0) {
                    console.log('Setting right panel min-height to:', leftHeight + 'px');

                    rightPanel.style.minHeight = leftHeight + 'px';
                    rightPanel.style.display = 'flex';
                    rightPanel.style.flexDirection = 'column';
                    rightPanel.style.justifyContent = 'center';

                    console.log('Applied min-height and flex centering');
                    return true; // Success
                } else {
                    console.log('No height adjustment needed (diff <= 0)');
                    return true; // Success
                }
            } else {
                console.log('Panels not ready yet');
                return false; // Not ready
            }
        }

        // Check every 50ms for 3 seconds to catch multiple height changes
        let attempts = 0;
        const maxAttempts = 60; // 60 * 50ms = 3 seconds to catch both height changes
        const checkInterval = setInterval(() => {
            attempts++;
            console.log('Attempt', attempts, 'to match heights');

            matchPanelHeights(); // Always try, don't stop early

            if (attempts >= maxAttempts) {
                console.log('Finished checking after 3 seconds');
                clearInterval(checkInterval);
            }
        }, 50); // Check every 50ms
    }
    """

    popup_start_button.click(
        fn=start_demo,
        outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, popup_overlay, result_display, selection_help_button],
        js=match_panel_heights_js,
    )

    start_over_button.click(
        fn=start_over,
        outputs=[image_display, status_with_help, model_predictions_display, prob_plot, accuracy_plot, result_display, popup_overlay, selection_help_button],
        js=match_panel_heights_js,
    )

    guide_button.click(
        fn=show_species_guide,
        outputs=[intro_content, species_guide_content, back_button, guide_button]
    )

    back_button.click(
        fn=show_intro,
        outputs=[intro_content, species_guide_content, back_button, guide_button]
    )

    # Help popup handlers
    prob_help_button.click(
        fn=show_prob_help,
        outputs=[prob_help_popup]
    )

    prob_help_close.click(
        fn=hide_prob_help,
        outputs=[prob_help_popup]
    )

    # acc_help_button.click(
    #     fn=show_acc_help,
    #     outputs=[acc_help_popup]
    # )

    acc_help_close.click(
        fn=hide_acc_help,
        outputs=[acc_help_popup]
    )

    selection_help_button.click(
        fn=show_selection_help,
        outputs=[selection_help_popup]
    )

    selection_help_close.click(
        fn=hide_selection_help,
        outputs=[selection_help_popup]
    )

    # Reveal accuracy button handler
    reveal_accuracy_button.click(
        fn=reveal_accuracies,
        outputs=[accuracy_plot, hidden_group]
    )

    # Species guide popup handlers
    view_guide_button.click(
        fn=show_species_guide_popup,
        outputs=[species_guide_popup]
    )

    species_guide_close.click(
        fn=hide_species_guide_popup,
        outputs=[species_guide_popup]
    )

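    # Each species button passes its own label to check_answer as a constant input
    # via gr.State, so a single callback handles every button.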
    for btn in species_buttons:
        btn.click(
            fn=check_answer,
            inputs=[gr.State(btn.value)],
            outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
        )

    idk_button.click(
        fn=check_answer,
        inputs=[gr.State("I don't know")],
        outputs=[result_display, status_with_help, image_display, model_predictions_display, prob_plot, accuracy_plot]
    )

    # Wire up delete button in debug mode
    if DEBUG_MODE:
        delete_button.click(
            fn=delete_current_image,
            outputs=[result_display, image_display, model_predictions_display, prob_plot, accuracy_plot, status_with_help]
        )

    # Add JavaScript to handle inline help button clicks and dynamic image sizing
    demo.load(
        lambda: None,
        outputs=[],
        js="""
        () => {
            // Handle inline help button clicks
            setTimeout(() => {
                document.addEventListener('click', function(e) {
                    if (e.target && e.target.classList.contains('inline-help-btn')) {
                        e.preventDefault();
                        e.stopPropagation();
                        const hiddenBtn = document.getElementById('hidden-selection-help-btn');
                        if (hiddenBtn) {
                            hiddenBtn.click();
                        }
                    }
                });
            }, 100);

            // Dynamic image sizing (NEW VERSION)
            console.log('=== IMAGE SIZING V2 LOADED ===');

            function adjustImageSize() {
                const imageContainer = document.getElementById('main-image-display');
                if (!imageContainer) {
                    console.log('[V2] Image container not found');
                    return false;
                }

                const viewportHeight = window.innerHeight;
                const docHeight = document.documentElement.scrollHeight;
                const currentImageHeight = imageContainer.offsetHeight;

                // Calculate how much we're overflowing
                const overflow = docHeight - viewportHeight;

                // If we're not overflowing, increase image size
                // If we are overflowing, decrease image size by the overflow amount
                const adjustment = -overflow - 30; // Keep padding below bottom button
                const targetHeight = currentImageHeight + adjustment;

                console.log('[V2] viewport:', viewportHeight, 'docHeight:', docHeight, 'currentImg:', currentImageHeight, 'overflow:', overflow, 'target:', targetHeight);

                // Only apply if reasonable
                if (targetHeight > 300 && targetHeight < viewportHeight - 100) {
                    imageContainer.style.height = targetHeight + 'px';
                    imageContainer.style.maxHeight = targetHeight + 'px';
                    console.log('[V2] Set image height to:', targetHeight + 'px');
                    return true;
                }

                return false;
            }

            // Run after initial load
            setTimeout(adjustImageSize, 500);

            // Run periodically for first 5 seconds to catch layout changes
            let attempts = 0;
            const interval = setInterval(() => {
                attempts++;
                adjustImageSize();
                if (attempts >= 50) { // 50 * 100ms = 5 seconds
                    clearInterval(interval);
                }
            }, 100);

            // Re-adjust on window resize
            window.addEventListener('resize', adjustImageSize);
        }
        """,
    )

if __name__ == "__main__":
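    # Note: allowed_paths=["/"] lets Gradio serve files from anywhere on the host
    # filesystem; consider narrowing this to the demo's asset directory if the app
    # is deployed publicly.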
    demo.launch(
        # share=True,
        # server_port=7861,
        allowed_paths=["/"],
    )