File size: 3,994 Bytes
ae522ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch

def iou_base(preds, masks, threshold=0.5, eps=1e-6):
    """
    IoU by definition regardless of shape.

    Args:
        - preds: Predictions from the model
        - masks: Ground truth masks
        - threshold: Threshold for binarization
        - eps: Small constant to prevent division by zero
    Output:
        - (Float): Intersection over Union (IoU) score
    """
    # Flatten everything
    preds = preds.reshape(-1)
    masks = masks.reshape(-1)

    preds = (preds > threshold).float()
    intersection = torch.sum(preds * masks)
    union = torch.sum(preds) + torch.sum(masks) - intersection
    return (intersection + eps) / (union + eps)


def iou_list(preds, masks, threshold=0.5, eps=1e-6):
    """
    IoU image-wise when preds and masks are batched.

    Args:
        - preds: Predictions from the model (B x H x W)
        - masks: Ground truth masks (B x H x W)
        - threshold: Threshold for binarization
        - eps: Small constant to prevent division by zero
    Output:
        - (List): List of IoU scores for each image
    """
    iou_list = []
    for i in range(preds.shape[0]):
        iou = iou_base(preds[i,:,:], masks[i,:,:], threshold, eps)
        iou_list.append(iou)
    return iou_list


def iou_object_labels(ground_truth, prediction):
    """
    Compute the pixel-level segmentation IoU for multiple objects in one image sample.
    Args:
        - ground_truth: 2D numpy array (int), ground truth mask, each object identified by pixels with the same unique value
        - prediction: 2D numpy array (int), predicted mask, each object identified by pixels with the same unique value
    Returns:
        - iou_matrix: 2D numpy array (float), IoU matrix for each object pair
    """
    
    # Count objects
    true_objects = len(np.unique(ground_truth))
    pred_objects = len(np.unique(prediction))
    
    # Compute intersection
    h = np.histogram2d(ground_truth.flatten(), prediction.flatten(), bins=(true_objects,pred_objects))
    intersection = h[0]
    
    # Area of objects
    area_true = np.histogram(ground_truth, bins=true_objects)[0]
    area_pred = np.histogram(prediction, bins=pred_objects)[0]
    
    # Calculate union
    area_true = np.expand_dims(area_true, -1)
    area_pred = np.expand_dims(area_pred, 0)
    union = area_true + area_pred - intersection
    
    # Exclude background from the analysis
    intersection = intersection[1:,1:]
    union = union[1:,1:]
    
    # Compute Intersection over Union
    union[union == 0] = 1e-9
    iou_matrix = intersection/union
    
    return iou_matrix


def measures_at(threshold, iou_matrix):
    """
    Object level evaluation metrics at a given IoU threshold.
    Args:
        - threshold: float, IoU threshold
        - iou_matrix: 2D numpy array (float), IoU matrix for each object pair
    Returns:
        - f1: float, F1 score
        - precision: float, precision
        - recall: float, recall
        - jaccard: float, Jaccard index
        - TP: int, number of true positives objects identified in a image sample
        - FP: int, number of false positives objects identified in a image sample
        - FN: int, number of false negatives objects identified in a image sample
    """
    
    matches = iou_matrix > threshold
    
    true_positives = np.sum(matches, axis=1) == 1   # Correct objects
    false_positives = np.sum(matches, axis=0) == 0  # Extra objects
    false_negatives = np.sum(matches, axis=1) == 0  # Missed objects
    
    assert np.all(np.less_equal(true_positives, 1))
    assert np.all(np.less_equal(false_positives, 1))
    assert np.all(np.less_equal(false_negatives, 1))
    
    TP, FP, FN = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
    
    f1 = 2*TP / (2*TP + FP + FN + 1e-9)
    precision = TP / (TP + FP + 1e-9)
    recall = TP / (TP + FN + 1e-9)
    jaccard = TP / (TP + FP + FN + 1e-9)
    
    return f1, precision, recall, jaccard, TP, FP, FN