Spaces:
Running on Zero
from typing import List

import numpy as np
import PIL
import PIL.Image
import PIL.ImageDraw
import PIL.ImageOps
import torch
import torchvision.transforms.functional as F
def multiply_grayscale_images(image1, image2):
    """Multiply two same-size grayscale PIL images element-wise.

    The product is normalized by 255 (the standard "multiply" blend), so
    white (255) is the identity and black (0) masks pixels out.  For the
    binary 0/255 masks used elsewhere in this file the result is identical
    to the raw clipped product, but intermediate gray values now combine
    correctly instead of saturating to white.

    Args:
        image1: grayscale (mode "L") PIL image.
        image2: grayscale PIL image of the same size.

    Returns:
        A new 8-bit grayscale PIL image containing the normalized product.
    """
    # float32 avoids uint8 overflow during the multiplication.
    a = np.asarray(image1, dtype=np.float32)
    b = np.asarray(image2, dtype=np.float32)
    product = a * b / 255.0
    product = np.clip(product, 0, 255).astype(np.uint8)
    return PIL.Image.fromarray(product)
def create_color_masks(image: PIL.Image.Image):
    """Split an RGB image into per-color element masks plus a background mask.

    Every distinct non-black color produces one binary mask; the mask
    covering the largest fraction of the image is treated as the background
    and excluded from the element list.  The returned background mask is
    white exactly where no element mask is set.

    Args:
        image: input PIL image; converted to RGB internally.

    Returns:
        (background_mask, elements) where background_mask is a mode-"L"
        PIL image and elements is a list of (color_str, mask_image) tuples,
        color_str being the "R_G_B" string of the element's color.
    """
    image = image.convert("RGB")
    image_np = np.array(image)  # H x W x 3

    # One (color_str, boolean mask) pair per distinct non-black color.
    unique_colors = np.unique(image_np.reshape(-1, 3), axis=0)
    masks = []
    for color in unique_colors:
        if not color.any():  # pure black is treated as "no color" and skipped
            continue
        mask = np.all(image_np == color, axis=-1)
        masks.append(('_'.join(map(str, color)), mask))

    # The mask with the largest covered area is assumed to be the background.
    background_index = -1
    if masks:
        areas = [mask.mean() for _, mask in masks]  # fraction of pixels set
        background_index = int(np.argmax(areas))

    elements = [
        (color_str, PIL.Image.fromarray(mask.astype(np.uint8) * 255))
        for idx, (color_str, mask) in enumerate(masks)
        if idx != background_index
    ]

    # Background = white minus every element mask.
    final_background_mask_image = PIL.Image.new("L", image.size, 255)
    for _, mask_image in elements:
        final_background_mask_image = multiply_grayscale_images(
            final_background_mask_image, PIL.ImageOps.invert(mask_image))
    return final_background_mask_image, elements
def create_text_masks(polygons, width, height):
    """Rasterize text polygons into binary masks.

    Args:
        polygons: iterable of flat coordinate lists [x0, y0, x1, y1, ...].
        width: width of each output mask in pixels.
        height: height of each output mask in pixels.

    Returns:
        List of mode-"L" PIL images, one per polygon, with the polygon
        filled white (255) on a black (0) background.
    """
    text_masks = []
    for polygon_coords in polygons:
        mask = PIL.Image.new('L', (width, height), 0)
        draw = PIL.ImageDraw.Draw(mask)
        # Pair the flat [x, y, x, y, ...] list into (x, y) tuples.
        polygon_points = list(zip(polygon_coords[0::2], polygon_coords[1::2]))
        draw.polygon(polygon_points, fill=255)
        text_masks.append(mask)
    return text_masks
class GetLayerMask:
    """ComfyUI node: split an image into background/element color-layer
    masks plus one mask per text polygon.

    NOTE(review): ``json_data`` was annotated ``str`` in the original but
    is iterated as a sequence of dicts with "polygon" and "label" keys --
    presumably already-parsed JSON; annotation adjusted to match usage.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI invokes this on the class with no argument, so it must
        # be a classmethod (the original bare ``def INPUT_TYPES(s)`` would
        # raise a missing-argument TypeError).
        return {
            "required": {
                "image": ("IMAGE",),
                "json_data": ("JSON",),
            },
        }

    RETURN_TYPES = ("MASK", "MASK", "JSON")
    FUNCTION = "main"
    CATEGORY = "tensorops"

    def main(self, image: torch.Tensor, json_data):
        """Build layer masks from *image* and text masks from *json_data*.

        Args:
            image: image batch, assumed (B, H, W, C) as ComfyUI provides
                -- TODO confirm; only the first image in the batch is used.
            json_data: sequence of dicts with "polygon" (flat coordinate
                list) and "label" keys.

        Returns:
            (layer_masks, text_masks, labels): stacked background+element
            masks, stacked per-polygon masks, and the list of labels.
        """
        # (B, H, W, C) -> (B, C, H, W) so torchvision can build a PIL image.
        image_pil = F.to_pil_image(image.permute(0, 3, 1, 2)[0])
        bg, elements = create_color_masks(image_pil)

        # Collect polygons/labels from the (already parsed) JSON payload.
        polygons = [item["polygon"] for item in json_data]
        labels = [item["label"] for item in json_data]

        def to_tensor(mask_image):
            # Normalize an 8-bit mask image to a float tensor in [0, 1]
            # with a leading batch dimension.
            arr = np.array(mask_image).astype(np.float32) / 255.0
            return torch.from_numpy(arr)[None,]

        width, height = bg.size
        text_masks = [to_tensor(m)
                      for m in create_text_masks(polygons, width, height)]
        if text_masks:
            text_tensor = torch.cat(text_masks, dim=0)
        else:
            # Guard the no-polygon case: torch.cat([]) would raise.
            text_tensor = torch.zeros((0, height, width), dtype=torch.float32)

        layers = [to_tensor(bg)] + [to_tensor(m) for _, m in elements]
        return (torch.cat(layers, dim=0), text_tensor, labels)