import os import numpy as np import cv2 import torch from torch.utils.data import Dataset import albumentations as A from mwm.components.image_processing import ( normalize_image, get_gt_mask_png, read_image_png, TestTimeTransform ) from mwm import logger def _get_transform(image_size, mode, overlap=0.1): if mode == "train": return A.Compose([ A.RandomCrop(width=image_size[0], height=image_size[1]), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5) ], additional_targets={'sdm': 'mask'} ) elif mode == "val": return A.Compose([ A.CenterCrop(width=image_size[0], height=image_size[1]) ], additional_targets={'sdm': 'mask'} ) elif mode == "test": return TestTimeTransform(width=image_size[0], height=image_size[1]) else: logger.error(f"Invalid mode: {mode}") raise ValueError(f"Invalid mode: {mode}") # Utils for custom datasets def make_dataset(dataset_name, image_dir, mask_dir, sdm_dir, image_list, mode, image_size=[256, 256]): transform = _get_transform(image_size, mode) if dataset_name == "seg_2ch": dataset = Seg2ChannelDataset(image_dir, mask_dir, sdm_dir, image_list, transform) logger.info(f"Dataset: {dataset_name} successfully processed. ") return dataset else: logger.error(f"Invalid dataset: {dataset_name}") raise ValueError(f"Invalid dataset: {dataset_name}") # Classes for custom datasets class Seg2ChannelDataset(Dataset): def __init__(self, image_dir, mask_dir, sdm_dir, image_list, transform=None): self.image_dir = image_dir self.mask_dir = mask_dir self.sdm_dir = sdm_dir self.image_list = image_list # This is when image_list is pre-selected for train/val/test split self.transform = transform # For info retrieval where needed (e.g. at evaluation) self.this_image_path = "" self.this_mask_path = "" def __len__(self): return len(self.image_list) def __getitem__(self, idx): """ Args: idx (int): Index of the image to retrieve Returns: image (torch.Tensor): Image tensor of shape (C, H, W). Each image only returns one "sample" during training and validation, but multiple patches during testing: (n_patches_h, n_patches_w, C, patch_h, patch_w) mask (torch.Tensor): Mask tensor of shape (C, H, W) sdm (torch.Tensor): SDM tensor of shape (1, H, W) (optional) """ img_path = os.path.join(self.image_dir, self.image_list[idx]) mask_path = os.path.join(self.mask_dir, self.image_list[idx]) # Assuming masks have the same name self.this_image_path = img_path self.this_mask_path = mask_path # Read image and mask image = read_image_png(img_path) if os.path.exists(mask_path): mask_raw = read_image_png(mask_path) else: # inference time don't have mask mask_raw = np.zeros_like(image).astype(np.uint8) # dummy mask if self.sdm_dir: sdm_path = os.path.join(self.sdm_dir, self.image_list[idx].replace(".png", ".npy")) # Assuming sdms have the same name sdm = np.load(sdm_path) # load sdm as numpy array else: sdm = np.zeros_like(mask_raw).astype(np.float32) # dummy sdm # Normalize & Convert to tensors image = image / 255.0 # when import from preprocessed image dir: /norm_images mask = get_gt_mask_png(mask_raw[:,:,0])[:,:,1:] # leave out the 1st channel (empty), [0 1] # mask = get_gt_mask_png(mask_raw[:,:,0])[:,:,-1] # test with nuclei channel only # mask = np.expand_dims(mask, axis=-1) # Add channel dimension # mask = mask / 255.0 # Normalize (Assuming mask values are 0 or 255) if self.transform: augmented = self.transform(image=image, mask=mask, sdm=sdm) image = augmented["image"] mask = augmented["mask"] sdm = augmented["sdm"] if len(image.shape) == 5: # Test time image = torch.tensor(image, dtype=torch.float32).permute(0, 1, -1, -3, -2) else: image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) if mask is not None: # Test time doesn't have mask mask = torch.tensor(mask, dtype=torch.float32).permute(2, 0, 1) if sdm is not None: # sdm is optional, depending on the loss in use sdm = torch.tensor(sdm, dtype=torch.float32).unsqueeze(0) return image, mask, sdm def get_mask_path(self): return self.this_mask_path