justinkay commited on
Commit
3aaf8da
·
1 Parent(s): 47b7c6b

Dynamic subsampling

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -63,9 +63,45 @@ print(f"Loaded {len(images_data)} images for the quiz")
63
  with open('images.txt', 'r') as f:
64
  image_filenames = [line.strip() for line in f.readlines() if line.strip()]
65
 
66
- # Initialize CODA
67
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
- dataset = Dataset("iwildcam_demo.pt", device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  loss_fn = LOSS_FNS['acc']
70
  oracle = Oracle(dataset, loss_fn=loss_fn)
71
 
 
63
  with open('images.txt', 'r') as f:
64
  image_filenames = [line.strip() for line in f.readlines() if line.strip()]
65
 
66
+ # Initialize CODA with subsampled dataset
67
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+
69
+ # Load full dataset
70
+ full_preds = torch.load("iwildcam_demo.pt").to(device)
71
+ full_labels = torch.load("iwildcam_demo_labels.pt").to(device)
72
+
73
+ # Subsample to balance classes
74
+ from collections import defaultdict
75
+ class_to_indices = defaultdict(list)
76
+ for idx, label in enumerate(full_labels):
77
+ class_idx = label.item()
78
+ class_to_indices[class_idx].append(idx)
79
+
80
+ # Find minimum class size
81
+ min_class_size = min(len(indices) for indices in class_to_indices.values())
82
+ print(f"Subsampling to {min_class_size} images per class (total: {min_class_size * len(class_to_indices)} images)")
83
+
84
+ # Randomly subsample each class
85
+ subsampled_indices = []
86
+ for class_idx in sorted(class_to_indices.keys()):
87
+ indices = class_to_indices[class_idx]
88
+ sampled = np.random.choice(indices, size=min_class_size, replace=False)
89
+ subsampled_indices.extend(sampled.tolist())
90
+
91
+ # Sort indices to maintain order
92
+ subsampled_indices.sort()
93
+
94
+ # Create subsampled dataset
95
+ subsampled_preds = full_preds[:, subsampled_indices, :]
96
+ subsampled_labels = full_labels[subsampled_indices]
97
+ image_filenames = [image_filenames[idx] for idx in subsampled_indices]
98
+
99
+ # Create Dataset object with subsampled data
100
+ dataset = Dataset.__new__(Dataset)
101
+ dataset.preds = subsampled_preds
102
+ dataset.labels = subsampled_labels
103
+ dataset.device = device
104
+
105
  loss_fn = LOSS_FNS['acc']
106
  oracle = Oracle(dataset, loss_fn=loss_fn)
107