"""Dataset construction utilities for the mwm segmentation models: transform pipelines and custom Dataset classes."""
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A
from mwm.components.image_processing import (
    normalize_image,
    get_gt_mask_png,
    read_image_png,
    TestTimeTransform
)
from mwm import logger


def _get_transform(image_size, mode, overlap=0.1):  # note: overlap is currently unused
    if mode == "train":
        return A.Compose([
            A.RandomCrop(width=image_size[0], height=image_size[1]),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5)
        ],
        additional_targets={'sdm': 'mask'}
        )
    elif mode == "val":
        return A.Compose([
            A.CenterCrop(width=image_size[0], height=image_size[1])
        ],
        additional_targets={'sdm': 'mask'}
        )
    elif mode == "test":
        return TestTimeTransform(width=image_size[0], height=image_size[1])
    else:
        logger.error(f"Invalid mode: {mode}")
        raise ValueError(f"Invalid mode: {mode}")


# Utils for custom datasets
def make_dataset(dataset_name, image_dir, mask_dir, sdm_dir, image_list, mode, image_size=(256, 256)):
    """Build the named dataset with the transform pipeline for the given mode ("train", "val" or "test")."""
    transform = _get_transform(image_size, mode)

    if dataset_name == "seg_2ch":
        dataset = Seg2ChannelDataset(image_dir, mask_dir, sdm_dir, image_list, transform)
        logger.info(f"Dataset '{dataset_name}' successfully created.")
        return dataset
    else:
        logger.error(f"Invalid dataset: {dataset_name}")
        raise ValueError(f"Invalid dataset: {dataset_name}")


# Classes for custom datasets
class Seg2ChannelDataset(Dataset):
    def __init__(self, image_dir, mask_dir, sdm_dir, image_list, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.sdm_dir = sdm_dir
        self.image_list = image_list  # File names pre-selected for this train/val/test split
        self.transform = transform

        # Paths of the most recently loaded sample, kept for retrieval where needed (e.g. at evaluation)
        self.this_image_path = ""
        self.this_mask_path = ""

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the image to retrieve
        Returns:
            image (torch.Tensor): Image tensor of shape (C, H, W).
                Each image only returns one "sample" during training and validation,
                but multiple patches during testing:
                (n_patches_h, n_patches_w, C, patch_h, patch_w)
            mask (torch.Tensor): Mask tensor of shape (C, H, W)
            sdm (torch.Tensor): SDM tensor of shape (1, H, W) (optional)
        """
        img_path = os.path.join(self.image_dir, self.image_list[idx])
        mask_path = os.path.join(self.mask_dir, self.image_list[idx]) # Assuming masks have the same name
        self.this_image_path = img_path
        self.this_mask_path = mask_path

        # Read image and mask
        image = read_image_png(img_path)
        if os.path.exists(mask_path):
            mask_raw = read_image_png(mask_path)
        else:  # no ground-truth mask at inference time
            mask_raw = np.zeros_like(image, dtype=np.uint8)  # dummy mask
        if self.sdm_dir:
            sdm_path = os.path.join(self.sdm_dir, self.image_list[idx].replace(".png", ".npy")) # Assuming sdms have the same name
            sdm = np.load(sdm_path) # load sdm as numpy array
        else:
            sdm = np.zeros(image.shape[:2], dtype=np.float32)  # dummy SDM; single-channel to match the assumed (H, W) layout of the saved .npy SDMs

        # Normalize & convert to tensors
        image = image / 255.0  # scale 8-bit pixel values to [0, 1] (when loading from the preprocessed /norm_images dir)
        mask = get_gt_mask_png(mask_raw[:, :, 0])[:, :, 1:]  # drop the 1st (empty) channel; remaining channels take values in [0, 1]
        # mask = get_gt_mask_png(mask_raw[:,:,0])[:,:,-1] # test with nuclei channel only
        # mask = np.expand_dims(mask, axis=-1)  # Add channel dimension
        # mask = mask / 255.0  # Normalize (Assuming mask values are 0 or 255)

        if self.transform:
            augmented = self.transform(image=image, mask=mask, sdm=sdm)
            image = augmented["image"]
            mask = augmented["mask"]
            sdm = augmented["sdm"]

        if len(image.shape) == 5:  # Test time: (n_h, n_w, patch_h, patch_w, C) -> (n_h, n_w, C, patch_h, patch_w)
            image = torch.tensor(image, dtype=torch.float32).permute(0, 1, -1, -3, -2)
        else:  # Train/val: (H, W, C) -> (C, H, W)
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
        if mask is not None:  # there may be no ground-truth mask at test time
            mask = torch.tensor(mask, dtype=torch.float32).permute(2, 0, 1)
        if sdm is not None:  # SDM is optional, depending on the loss in use
            sdm = torch.tensor(sdm, dtype=torch.float32).unsqueeze(0)

        return image, mask, sdm
    
    def get_mask_path(self):
        return self.this_mask_path
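

# Illustrative usage (a minimal sketch): the directory paths and file names below
# are hypothetical placeholders, not part of this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    image_dir = "data/norm_images"   # hypothetical
    mask_dir = "data/masks"          # hypothetical
    sdm_dir = None                   # falls back to dummy SDMs
    image_list = ["sample_001.png", "sample_002.png"]  # hypothetical

    # Train/val: each item is a single (C, 256, 256) crop, batchable with a DataLoader.
    train_ds = make_dataset("seg_2ch", image_dir, mask_dir, sdm_dir, image_list,
                            mode="train", image_size=(256, 256))
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    for image, mask, sdm in train_loader:
        print(image.shape, mask.shape, sdm.shape)
        break

    # Test: each item is a grid of patches of shape
    # (n_patches_h, n_patches_w, C, patch_h, patch_w); flatten the grid dimensions
    # before feeding the patches to a model.
    test_ds = make_dataset("seg_2ch", image_dir, mask_dir, sdm_dir, image_list, mode="test")
    image, mask, sdm = test_ds[0]
    patches = image.flatten(0, 1)  # (n_patches_h * n_patches_w, C, patch_h, patch_w)
    print(patches.shape, "from", test_ds.get_mask_path())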