import torch
import torchvision.transforms.functional as F
import io
import os
from typing import List
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageColor, ImageFont
import random
import numpy as np
import re

# workaround for unnecessary flash_attn requirement
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports
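
# Note: the Florence-2 remote code (modeling_florence2.py) lists flash_attn among its
# imports, which transformers' dynamic-module loader would otherwise require to be
# installed; patching get_imports (see loadmodel below) drops that entry so the model
# can load with sdpa or eager attention instead.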

import comfy.model_management as mm
from comfy.utils import ProgressBar
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))

from transformers import AutoModelForCausalLM, AutoProcessor

class DownloadAndLoadFlorence2Model:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                [
                    'microsoft/Florence-2-base',
                    'microsoft/Florence-2-base-ft',
                    'microsoft/Florence-2-large',
                    'microsoft/Florence-2-large-ft',
                    'HuggingFaceM4/Florence-2-DocVQA'
                ],
                {
                    "default": 'microsoft/Florence-2-base'
                }),
            "precision": (['fp16', 'bf16', 'fp32'],
                {
                    "default": 'fp16'
                }),
            "attention": (
                ['flash_attention_2', 'sdpa', 'eager'],
                {
                    "default": 'sdpa'
                }),
            },
        }

    RETURN_TYPES = ("FL2MODEL",)
    RETURN_NAMES = ("florence2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "Florence2"

    def loadmodel(self, model, precision, attention):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(folder_paths.models_dir, "LLM", model_name)

        if not os.path.exists(model_path):
            print(f"Downloading Florence2 model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              local_dir=model_path,
                              local_dir_use_symlinks=False)

        print(f"using {attention} for attention")
        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):  # workaround for unnecessary flash_attn requirement
            model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, device_map=device, torch_dtype=dtype, trust_remote_code=True)
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        florence2_model = {
            'model': model,
            'processor': processor,
            'dtype': dtype
            }

        return (florence2_model,)
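
# Minimal usage sketch (illustrative only; it still needs the ComfyUI runtime for
# comfy.model_management and folder_paths, and assumes the weights are available):
#
#   loader = DownloadAndLoadFlorence2Model()
#   (florence2_model,) = loader.loadmodel('microsoft/Florence-2-base', 'fp16', 'sdpa')
#   # florence2_model is a dict with 'model', 'processor' and 'dtype' keys,
#   # which Florence2Run.encode() below consumes.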

def calculate_bounding_box(width, height, flat_points) -> List[float]:
    """
    Calculate the normalized bounding box of a polygon.

    Args:
        width (int): Image width used to normalize the x coordinates.
        height (int): Image height used to normalize the y coordinates.
        flat_points (list of float): Flat list of x, y coordinates defining the polygon points.

    Returns:
        list: [min_x, min_y, max_x, max_y], each normalized to the 0..1 range.
    """
    if not flat_points or len(flat_points) % 2 != 0:
        raise ValueError("The list of points must be non-empty and have an even number of elements")

    x_coords = flat_points[0::2]
    y_coords = flat_points[1::2]

    min_x = min(x_coords)
    max_x = max(x_coords)
    min_y = min(y_coords)
    max_y = max(y_coords)

    return [min_x / width, min_y / height, max_x / width, max_y / height]
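
# Example (hypothetical values): for a 100x100 image and the flat polygon
# [10, 20, 30, 40, 20, 10], the result is [0.1, 0.1, 0.3, 0.4].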

class Florence2Run:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "florence2_model": ("FL2MODEL", ),
                "text_input": ("STRING", {"default": "", "multiline": True}),
                "task": (
                    [
                        'region_caption',
                        'dense_region_caption',
                        'region_proposal',
                        'caption',
                        'detailed_caption',
                        'more_detailed_caption',
                        'caption_to_phrase_grounding',
                        'referring_expression_segmentation',
                        'ocr',
                        'ocr_with_region',
                        'docvqa'
                    ],
                ),
                "fill_mask": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "keep_model_loaded": ("BOOLEAN", {"default": False}),
                "max_new_tokens": ("INT", {"default": 1024, "min": 1, "max": 4096}),
                "num_beams": ("INT", {"default": 3, "min": 1, "max": 64}),
                "do_sample": ("BOOLEAN", {"default": True}),
                "output_mask_select": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "JSON")
    RETURN_NAMES = ("image", "mask", "caption", "data")
    FUNCTION = "encode"
    CATEGORY = "Florence2"

    def encode(self, image, text_input, florence2_model, task, fill_mask, keep_model_loaded=False,
               num_beams=3, max_new_tokens=1024, do_sample=True, output_mask_select=""):
        device = mm.get_torch_device()
        _, height, width, _ = image.shape
        offload_device = mm.unet_offload_device()
        annotated_image_tensor = None
        mask_tensor = None
        processor = florence2_model['processor']
        model = florence2_model['model']
        dtype = florence2_model['dtype']
        model.to(device)

        colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'olive', 'cyan', 'red',
                    'lime', 'indigo', 'violet', 'aqua', 'magenta', 'gold', 'tan', 'skyblue']

        prompts = {
            'region_caption': '<OD>',
            'dense_region_caption': '<DENSE_REGION_CAPTION>',
            'region_proposal': '<REGION_PROPOSAL>',
            'caption': '<CAPTION>',
            'detailed_caption': '<DETAILED_CAPTION>',
            'more_detailed_caption': '<MORE_DETAILED_CAPTION>',
            'caption_to_phrase_grounding': '<CAPTION_TO_PHRASE_GROUNDING>',
            'referring_expression_segmentation': '<REFERRING_EXPRESSION_SEGMENTATION>',
            'ocr': '<OCR>',
            'ocr_with_region': '<OCR_WITH_REGION>',
            'docvqa': '<DocVQA>'
        }
        task_prompt = prompts.get(task, '<OD>')

        if (task not in ['referring_expression_segmentation', 'caption_to_phrase_grounding', 'docvqa']) and text_input:
            raise ValueError("Text input (prompt) is only supported for 'referring_expression_segmentation', 'caption_to_phrase_grounding', and 'docvqa'")

        if text_input != "":
            prompt = task_prompt + " " + text_input
        else:
            prompt = task_prompt

        image = image.permute(0, 3, 1, 2)

        out = []
        out_masks = []
        out_results = []
        out_data = []
        pbar = ProgressBar(len(image))
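
        # ComfyUI IMAGE tensors are batched (B, H, W, C) floats in 0..1; after the
        # permute above each img is (C, H, W) and is processed independently below.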
        for img in image:
            image_pil = F.to_pil_image(img)
            inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)

            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                num_beams=num_beams,
            )

            results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            print(results)

            # clean up the special tokens from the final list
            if task == 'ocr_with_region':
                clean_results = str(results)
                cleaned_string = re.sub(r'</?s>|<[^>]*>', '\n', clean_results)
                clean_results = re.sub(r'\n+', '\n', cleaned_string)
            else:
                clean_results = str(results)
                clean_results = clean_results.replace('</s>', '')
                clean_results = clean_results.replace('<s>', '')

            # return a single string if there is only one image, for compatibility with nodes that can't handle string lists
            if len(image) == 1:
                out_results = clean_results
            else:
                out_results.append(clean_results)

            W, H = image_pil.size
            parsed_answer = processor.post_process_generation(results, task=task_prompt, image_size=(W, H))
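
            # parsed_answer is keyed by the task token: the branches below read
            # 'bboxes'/'labels' for detection-style tasks, 'polygons'/'labels' for
            # segmentation and 'quad_boxes'/'labels' for OCR with regions.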

            if task in ('region_caption', 'dense_region_caption', 'caption_to_phrase_grounding', 'region_proposal'):
                fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100)
                fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
                ax.imshow(image_pil)
                bboxes = parsed_answer[task_prompt]['bboxes']
                labels = parsed_answer[task_prompt]['labels']

                mask_indexes = []
                # Determine mask indexes outside the loop
                if output_mask_select != "":
                    mask_indexes = [n for n in output_mask_select.split(",")]
                    print(mask_indexes)
                else:
                    mask_indexes = [str(i) for i in range(len(bboxes))]

                # Initialize mask_layer only if needed
                if fill_mask:
                    mask_layer = Image.new('RGB', image_pil.size, (0, 0, 0))
                    mask_draw = ImageDraw.Draw(mask_layer)

                for index, (bbox, label) in enumerate(zip(bboxes, labels)):
                    # Modify the label to include the index
                    indexed_label = f"{index}.{label}"

                    if fill_mask:
                        if str(index) in mask_indexes:
                            print("match index:", str(index), "in mask_indexes:", mask_indexes)
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))
                        if label in mask_indexes:
                            print("match label")
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(
                        (bbox[0], bbox[1]),  # (x, y) anchor: top-left corner in image coordinates
                        bbox[2] - bbox[0],   # Width
                        bbox[3] - bbox[1],   # Height
                        linewidth=1,
                        edgecolor='r',
                        facecolor='none',
                        label=indexed_label
                    )

                    # Calculate text width with a rough estimation
                    text_width = len(label) * 6  # Adjust multiplier based on your font size
                    text_height = 12             # Adjust based on your font size

                    # Initial text position
                    text_x = bbox[0]
                    text_y = bbox[1] - text_height  # Position text above the top-left of the bbox

                    # Adjust text_x if text is going off the left or right edge
                    if text_x < 0:
                        text_x = 0
                    elif text_x + text_width > W:
                        text_x = W - text_width

                    # Adjust text_y if text is going off the top edge
                    if text_y < 0:
                        text_y = bbox[3]  # Move text below the bottom of the bbox instead

                    # Add the rectangle to the plot
                    ax.add_patch(rect)
                    facecolor = random.choice(colormap) if len(image) == 1 else 'red'
                    # Add the label
                    plt.text(
                        text_x,
                        text_y,
                        indexed_label,
                        color='white',
                        fontsize=12,
                        bbox=dict(facecolor=facecolor, alpha=0.5)
                    )

                if fill_mask:
                    mask_tensor = F.to_tensor(mask_layer)
                    mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                    mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                    mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                    mask_tensor = mask_tensor[:, :, :, 0]
                    out_masks.append(mask_tensor)

                # Remove axis and padding around the image
                ax.axis('off')
                ax.margins(0, 0)
                ax.get_xaxis().set_major_locator(plt.NullLocator())
                ax.get_yaxis().set_major_locator(plt.NullLocator())
                fig.canvas.draw()
                buf = io.BytesIO()
                plt.savefig(buf, format='png', pad_inches=0)
                buf.seek(0)
                annotated_image_pil = Image.open(buf)

                annotated_image_tensor = F.to_tensor(annotated_image_pil)
                out_tensor = annotated_image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(out_tensor)

                pbar.update(1)
                plt.close(fig)

            elif task == 'referring_expression_segmentation':
                # Create a new black image
                mask_image = Image.new('RGB', (W, H), 'black')
                mask_draw = ImageDraw.Draw(mask_image)

                predictions = parsed_answer[task_prompt]

                # Iterate over polygons and labels
                for polygons, label in zip(predictions['polygons'], predictions['labels']):
                    color = random.choice(colormap)
                    for _polygon in polygons:
                        _polygon = np.array(_polygon).reshape(-1, 2)
                        # Clamp polygon points to image boundaries
                        _polygon = np.clip(_polygon, [0, 0], [W - 1, H - 1])
                        if len(_polygon) < 3:
                            print('Invalid polygon:', _polygon)
                            continue

                        _polygon = _polygon.reshape(-1).tolist()

                        # Draw the polygon
                        if fill_mask:
                            overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0))
                            image_pil = image_pil.convert('RGBA')
                            draw = ImageDraw.Draw(overlay)
                            color_with_opacity = ImageColor.getrgb(color) + (180,)
                            draw.polygon(_polygon, outline=color, fill=color_with_opacity, width=3)
                            image_pil = Image.alpha_composite(image_pil, overlay)
                        else:
                            draw = ImageDraw.Draw(image_pil)
                            draw.polygon(_polygon, outline=color, width=3)

                        # draw mask
                        mask_draw.polygon(_polygon, outline="white", fill="white")

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

                mask_tensor = F.to_tensor(mask_image)
                mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                mask_tensor = mask_tensor[:, :, :, 0]
                out_masks.append(mask_tensor)
                pbar.update(1)

            elif task == 'ocr_with_region':
                try:
                    font = ImageFont.load_default().font_variant(size=24)
                except:
                    font = ImageFont.load_default()
                predictions = parsed_answer[task_prompt]

                scale = 1
                draw = ImageDraw.Draw(image_pil)
                bboxes, labels = predictions['quad_boxes'], predictions['labels']
                for box, label in zip(bboxes, labels):
                    bbox = calculate_bounding_box(width, height, box)
                    out_data.append({"label": label, "polygon": box, "box": bbox})
                    color = random.choice(colormap)
                    new_box = (np.array(box) * scale).tolist()
                    draw.polygon(new_box, width=3, outline=color)
                    draw.text((new_box[0] + 8, new_box[1] + 2),
                              "{}".format(label),
                              align="right",
                              font=font,
                              fill=color)

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

            elif task == 'docvqa':
                if text_input == "":
                    raise ValueError("Text input (prompt) is required for 'docvqa'")

                prompt = "<DocVQA> " + text_input

                inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)
                generated_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    num_beams=num_beams,
                )
                results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                clean_results = results.replace('</s>', '').replace('<s>', '')

                if len(image) == 1:
                    out_results = clean_results
                else:
                    out_results.append(clean_results)

                out.append(F.to_tensor(image_pil).unsqueeze(0).permute(0, 2, 3, 1).cpu().float())

                pbar.update(1)

        if len(out) > 0:
            out_tensor = torch.cat(out, dim=0)
        else:
            out_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32, device="cpu")
        if len(out_masks) > 0:
            out_mask_tensor = torch.cat(out_masks, dim=0)
        else:
            out_mask_tensor = torch.zeros((1, 64, 64), dtype=torch.float32, device="cpu")

        if not keep_model_loaded:
            print("Offloading model...")
            model.to(offload_device)
            mm.soft_empty_cache()

        return (out_tensor, out_mask_tensor, out_results, out_data)
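
# ComfyUI picks these mappings up when the package is loaded from custom_nodes:
# NODE_CLASS_MAPPINGS registers the node classes and NODE_DISPLAY_NAME_MAPPINGS
# sets the names shown in the node search UI.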
NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": DownloadAndLoadFlorence2Model,
    "Florence2Run": Florence2Run,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": "DownloadAndLoadFlorence2Model",
    "Florence2Run": "Florence2Run",
}