Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import torch | |
| from mwm.utils.common import load_json | |
| from mwm.components.model_architecture import make_model | |
| from mwm.components.dataset import make_dataset | |
| from mwm.components.image_processing import ( | |
| post_processing_watershed_2ch, | |
| post_processing_denoise_2ch | |
| ) | |
def inference(config_path: str, image_path: str):
    """
    Run the segmentation model on a single image and post-process the result.

    Args:
        config_path (str): Path to the JSON configuration file
            (example config: "app/config.json"). The loaded config must expose
            ``network``, ``model_path``, ``dataset`` and ``image_size``.
        image_path (str): Path to the input image for segmentation.

    Returns:
        tuple: ``(labels_pred, output_stitched)``, in this order:
            - labels_pred (np.ndarray): Label-level prediction produced by the
              watershed post-processing, with shape (h, w).
            - output_stitched (np.ndarray): Mask-level, 2-channel prediction:
              the model probabilities stitched back to the full frame and then
              denoised. NOTE(review): earlier docs described this as uint8;
              presumably post_processing_denoise_2ch performs that
              conversion -- confirm.
    """
    device = torch.device("cpu")

    # Load config
    config = load_json(Path(config_path))

    # Build the network without pretrained encoder weights; the trained
    # weights are loaded from config.model_path right after.
    model = make_model(config.network, encoder_weights=None)
    # NOTE(review): torch.load unpickles the checkpoint, so only load files
    # from trusted sources. Since only a state_dict is needed, consider
    # passing weights_only=True (available in torch >= 1.13).
    model.load_state_dict(torch.load(config.model_path, map_location=device))
    model.eval()

    # The dataset object defines the pre-processing of the input image and the
    # matching reconstruction used after inference (see reconstruct_full_frame
    # below).
    test_dataset = make_dataset(
        config.dataset,
        image_dir="",  # root path (empty: image_path is presumably used as given)
        mask_dir="",
        sdm_dir=None,
        image_list=[image_path],
        mode="test",
        image_size=config.image_size,
    )

    # Only the image is used at inference time; the other two items are
    # ignored. Example shape: torch.Size([3, 3, 3, 256, 256]).
    image, _, _ = test_dataset[0]

    # Collapse all leading dimensions into one batch dimension while keeping
    # the trailing (c, h, w), e.g. [3, 3, 3, 256, 256] -> [9, 3, 256, 256].
    batch = image.reshape(-1, *image.shape[-3:]).to(device)

    with torch.no_grad():
        output = model(batch)

    # Drop singleton dimensions like a plain `.squeeze()` would, but never the
    # batch (tile) dimension: with a single tile, `.squeeze()` would turn
    # (1, 2, H, W) into (2, H, W), unlike the (n_tiles, 2, H, W) layout of the
    # multi-tile case. For more than one tile this is identical to squeeze().
    output = output.reshape(
        output.shape[0], *(size for size in output.shape[1:] if size != 1)
    )

    # Move to CPU and to numpy, e.g. shape (9, 2, 256, 256).
    output = output.cpu().numpy()

    # Mask-level result: 2-channel probabilities stitched into one frame and
    # cut to the original image size, then denoised.
    output_stitched = test_dataset.transform.reconstruct_full_frame(output)
    output_stitched = post_processing_denoise_2ch(output_stitched)

    # Label-level result (final).
    labels_pred = post_processing_watershed_2ch(output_stitched)

    return labels_pred, output_stitched