import os
from datetime import datetime

import torch
import numpy as np
import pandas as pd
import cv2
from skimage import measure
from tqdm import tqdm
import mlflow

from mwm import logger
from mwm.constants import *
from mwm.utils.common import read_yaml, load_json
from mwm.config.configuration import get_params
from mwm.components.model_architecture import *
from mwm.components.dataset import *
from mwm.components.image_processing import (
    read_image_png, 
    post_processing_watershed_2ch, 
    post_processing_denoise_2ch
)
from mwm.components.metrics import iou_object_labels, measures_at


class EvaluationProcessor2Channel:
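    """
    Accumulates per-sample instance segmentation metrics for 2-channel predictions.

    Ground truth and prediction are converted to labeled object maps, then F1,
    precision, recall and Jaccard are computed at IoU thresholds 0.50-0.95
    (step 0.05). Aggregated results can be logged to MLflow or saved to CSV.
    """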
    def __init__(self):
        self.results = []
        self.thresholds = np.round(np.arange(0.5, 1.0, 0.05), 2)


    def prep_evaluation(self, prediction, mask_path):
        """
        Args:
            - prediction: after denoise operations
            - mask_path: path to the original
        """

        # Mask to label
        self.sample_name = os.path.basename(mask_path).split(".")[0]
        mask_raw = read_image_png(mask_path)
        self.labels_gt = measure.label(mask_raw[:,:,0], background=0)

        # Prediction to label
        self.labels_pred = post_processing_watershed_2ch(prediction) # key post-processing logic
            

    def update_metrics(self):
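        # Pairwise IoU between ground-truth and predicted object labels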
        iou_matrix = iou_object_labels(self.labels_gt, self.labels_pred)
        if iou_matrix.size == 0:
            mean_object_iou = 0.0
        else:
            mean_object_iou = np.max(iou_matrix, axis=0).mean()
        
        # Calculate F1 score at all thresholds
        for t in self.thresholds:
            f1, precision, recall, jaccard, tp, fp, fn = measures_at(t, iou_matrix)
            res = {
                "Sample": self.sample_name, 
                "Threshold": t, 
                "F1": f1, 
                "Precision": precision, 
                "Recall": recall, 
                "Jaccard": jaccard, 
                "MeanObjectIoU": mean_object_iou,
                "TP": tp, 
                "FP": fp, 
                "FN": fn
            }
            self.results.append(res)
    

    def log_key_metrics_to_mlflow(self):
        df = pd.DataFrame(self.results)
        # Average each metric over all samples, per threshold
        df_agg = (
            df.drop(columns=["Sample"])
            .groupby("Threshold")
            .mean()
            .reset_index()
            .sort_values("Threshold", ascending=True)
        )
        for row_dict in df_agg.to_dict("records"):
            metrics = {k: v for k, v in row_dict.items() if k != "Threshold"}
            # MLflow steps must be integers, so encode the threshold as a percentage (0.55 -> 55)
            mlflow.log_metrics(metrics, step=int(round(row_dict["Threshold"] * 100)))
        # Mean-average metrics across thresholds
        mlflow.log_metric("MAF1", df_agg["F1"].mean())
        mlflow.log_metric("MAPrecision", df_agg["Precision"].mean())
        mlflow.log_metric("MARecall", df_agg["Recall"].mean())
        mlflow.log_metric("MAJaccard", df_agg["Jaccard"].mean())
        mlflow.log_param("thresholds", self.thresholds)


    def save_results(self, output_path):
        df = pd.DataFrame(self.results)
        df.to_csv(output_path, index=False)


class Evaluator:
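    """
    Runs evaluation of a trained 2-channel segmentation model on the test split.

    Loads the model checkpoint given in the evaluation params, builds the test
    dataset from the ingested data, computes per-sample metrics with
    EvaluationProcessor2Channel, and logs results to MLflow. Optionally saves
    the stitched predictions as images.
    """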
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH,
    ):
        self.config = read_yaml(config_filepath)
        self.params = get_params(params_filepath, "evaluation")

        # Make & load model
        self.model_path = self.params.model_file_path
        if not os.path.exists(self.model_path):
            logger.info("Please enter a valid model file path in params.json.\nPath can be found in mlfow trainning logs.\nExample: 'artifacts/models/model_epoch1_20250224_112303.pth'")
            logger.error(f"Model file not found: {self.model_path}")
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        elif not self.model_path.endswith(".pth"):
            logger.error(f"Check if model file ends with '.pth'. Invalid model file: {self.model_path}")
            raise ValueError(f"Invalid model file: {self.model_path}")
        
        self.model = make_model(self.params.network, encoder_weights=None)
        self.model.load_state_dict(torch.load(self.model_path, map_location=torch.device("cpu")))
        logger.info(f"Model loaded from: {self.model_path}")

        # Make dataset
        self.image_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.image_dir)
        self.mask_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.mask_dir)

        with open(os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.test_set_file), "r") as f:
            self.image_list_test = f.read().splitlines()
        self.test_dataset = make_dataset(
            self.params.dataset, 
            self.image_dir, 
            self.mask_dir,
            None, 
            self.image_list_test,
            "test",
            self.params.image_size
        )

        # Make save path (optional)
        if self.params.save_predictions:
            model_name = os.path.basename(self.model_path).split(".")[0]
            self.save_dir = os.path.join(self.config.evaluation.evaluation_dir, f"{model_name}_predictions")
            os.makedirs(self.save_dir, exist_ok=True)


    def handle_device(self):
        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)


    def evaluate(self):
        self.evaluate_processor = EvaluationProcessor2Channel()
        # Set model to evaluation mode
        self.model.eval()

        # Evaluate one sample at a time (the tiles of each sample form the batch)
        batch_progress_bar = tqdm(self.test_dataset, desc="Evaluation", leave=True)
        with torch.no_grad():
            for image, _, _ in batch_progress_bar: # e.g. image shape: torch.Size([3, 3, 3, 256, 256])

                mask_path = self.test_dataset.get_mask_path()
                
                # Handle device & batching
                _, _, c, h, w = image.shape
                image = image.reshape(-1, c, h, w).to(self.device) # torch.Size([9, 3, 256, 256]) 

                # Get prediction
                output = self.model(image).squeeze()

                # Move to CPU and to numpy
                output = output.cpu().numpy() # <class 'numpy.ndarray'>, shape: (9, 2, 256, 256)

                # <class 'numpy.ndarray'>, probabilities, 2 channels, cut to original image size
                output_stitched = self.test_dataset.transform.reconstruct_full_frame(output)
                output_stitched = post_processing_denoise_2ch(output_stitched)

                # Evaluate
                self.evaluate_processor.prep_evaluation(output_stitched, mask_path)
                self.evaluate_processor.update_metrics()

                if self.params.save_predictions:
                    save_path = os.path.join(self.save_dir, os.path.basename(mask_path))

                    # Compose a BGR image (cv2 channel order): red = channel 0, blue = channel 1
                    empty_channel = np.zeros_like(output_stitched[:, :, 0])
                    mask_pred_uint8 = (np.stack(
                        [output_stitched[:, :, 1], empty_channel, output_stitched[:, :, 0]],
                        axis=-1,
                    ) * 255).astype(np.uint8)
                    cv2.imwrite(save_path, mask_pred_uint8)

        mlflow.set_experiment("Evaluation")
        with mlflow.start_run():

            self.evaluate_processor.log_key_metrics_to_mlflow()

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(
                self.config.evaluation.evaluation_dir,
                f"evaluation_{timestamp}_on_{os.path.basename(self.model_path).split('.')[0]}.csv"
            )
            self.evaluate_processor.save_results(save_path)

            mlflow.log_param("evaluation_save_path", save_path)
            mlflow.log_param("model_path", self.model_path)