#!/usr/bin/env python

import gc
import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

model_id = os.getenv("MODEL_ID", "google/medgemma-4b-it")

# cuDNN settings: disable autotuning for deterministic, lower-variance behavior
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Initialize processor and model with memory optimizations
processor = AutoProcessor.from_pretrained(model_id)

# Use 8-bit quantization on GPU for a smaller memory footprint.
# Alternative for even more savings: 4-bit quantization, e.g.
# BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
#                    bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
quantization_config = BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None

# Load model with aggressive memory optimizations
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=os.environ.get("HF_TOKEN"),  # gated model: set HF_TOKEN in the environment
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    quantization_config=quantization_config,
)

# Enable gradient checkpointing where supported (only affects training passes; harmless for inference)
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()

MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))  # Reduced from 5 to 3


def cleanup_memory():
    """Aggressive memory cleanup."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    image_count = 0
    video_count = 0
    for path in paths:
        if path.endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count


def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        if item["content"][0].endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count


def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    new_image_count, new_video_count = count_files_in_new_message(message["files"])
    history_image_count, history_video_count = count_files_in_history(history)
    image_count = history_image_count + new_image_count
    video_count = history_video_count + new_video_count

    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1:
        if image_count > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if "<image>" in message["text"]:
            gr.Warning("Using <image> tags with video files is not supported.")
            return False
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False
    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
        gr.Warning("The number of <image> tags in the text does not match the number of images.")
        return False
    return True
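
# Illustrative only (not part of the original app; file names are hypothetical):
# the shape of the `message` dict that the multimodal textbox hands to the
# handlers below. The text may interleave one <image> tag per uploaded file,
# which validate_media_constraints() checks against the upload count:
#
#   message = {
#       "text": "Compare the baseline <image> with the follow-up scan <image>.",
#       "files": ["/tmp/scan_baseline.png", "/tmp/scan_followup.png"],
#   }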

def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Reduce max frames for memory efficiency
    max_frames = min(MAX_NUM_IMAGES, 8)
    frame_interval = max(total_frames // max_frames, 1)
    frames: list[tuple[Image.Image, float]] = []

    for i in range(0, min(total_frames, max_frames * frame_interval), frame_interval):
        if len(frames) >= max_frames:
            break
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            # Resize if too large to reduce memory usage
            max_size = 512
            if max(pil_image.size) > max_size:
                pil_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))

    vidcap.release()
    return frames


def process_video(video_path: str) -> list[dict]:
    content = []
    frames = downsample_video(video_path)
    for frame in frames:
        pil_image, timestamp = frame
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name, optimize=True)
        content.append({"type": "text", "text": f"Frame {timestamp}:"})
        content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content


def process_interleaved_images(message: dict) -> list[dict]:
    logger.debug(f"{message['files']=}")
    parts = re.split(r"(<image>)", message["text"])
    logger.debug(f"{parts=}")

    content = []
    image_index = 0
    for part in parts:
        logger.debug(f"{part=}")
        if part == "<image>":
            # Resize images before processing
            img = Image.open(message["files"][image_index])
            max_size = 512
            if max(img.size) > max_size:
                img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
                # Save resized image
                with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
                    img.save(temp_file.name, optimize=True, quality=85)
                content.append({"type": "image", "url": temp_file.name})
            else:
                content.append({"type": "image", "url": message["files"][image_index]})
            logger.debug(f"file: {message['files'][image_index]}")
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif isinstance(part, str) and part != "<image>":
            content.append({"type": "text", "text": part})
    logger.debug(f"{content=}")
    return content


def process_new_user_message(message: dict) -> list[dict]:
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]

    if message["files"][0].endswith(".mp4"):
        return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]

    if "<image>" in message["text"]:
        return process_interleaved_images(message)

    # Process regular images with resizing
    processed_images = []
    for path in message["files"]:
        img = Image.open(path)
        max_size = 512
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
                img.save(temp_file.name, optimize=True, quality=85)
            processed_images.append({"type": "image", "url": temp_file.name})
        else:
            processed_images.append({"type": "image", "url": path})

    return [
        {"type": "text", "text": message["text"]},
        *processed_images,
    ]


def process_history(history: list[dict]) -> list[dict]:
    messages = []
    current_user_content: list[dict] = []
    # Limit history to prevent memory overflow
    recent_history = history[-10:] if len(history) > 10 else history
    for item in recent_history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                current_user_content.append({"type": "image", "url": content[0]})
    return messages
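
# Illustrative only (not part of the original app; paths and texts are hypothetical):
# the `messages` structure that the helpers above assemble for
# processor.apply_chat_template() in run():
#
#   messages = [
#       {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
#       {"role": "user", "content": [
#           {"type": "text", "text": "What does this scan show?"},
#           {"type": "image", "url": "/tmp/scan_resized.jpg"},
#       ]},
#       {"role": "assistant", "content": [{"type": "text", "text": "The scan shows ..."}]},
#   ]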

@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 1024,
) -> Iterator[str]:
    # Cleanup memory before processing
    cleanup_memory()

    if not validate_media_constraints(message, history):
        yield ""
        return

    try:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
        messages.extend(process_history(history))
        messages.append({"role": "user", "content": process_new_user_message(message)})

        # Apply chat template with memory optimization
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device)

        # Reduce max_new_tokens to save memory
        max_new_tokens = min(max_new_tokens, 512)

        streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            inputs,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            temperature=0.8,  # Slightly reduced for more focused responses
            top_p=0.9,  # Reduced for efficiency
            top_k=50,  # Reduced for efficiency
            min_p=0.05,  # Added for better token filtering
            do_sample=True,
            pad_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        output = ""
        for delta in streamer:
            output += delta
            yield output
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        yield f"Error: {str(e)}"
    finally:
        # Cleanup after generation
        cleanup_memory()


# Streamlined CSS
custom_css = """
.gr-chatbot {
    background-color: #ffffff;
    border-radius: 8px;
    border: 1px solid #e5e7eb;
}
.gr-textbox textarea {
    min-height: 100px;
    border-radius: 8px;
    padding: 12px;
}
.gr-button {
    background-color: #4f46e5 !important;
    color: white !important;
    border-radius: 6px !important;
    padding: 8px 16px !important;
}
.gr-interface {
    max-width: 800px;
    margin: 0 auto;
    padding: 16px;
}
"""

DESCRIPTION = """\
## Medical Vision-Language Assistant (Memory Optimized)

This AI assistant analyzes medical images and videos with memory efficiency optimizations.

**Features:**
- Medical image analysis (max 512px resolution for efficiency)
- Video frame processing (limited frames)
- Reduced memory footprint
- Optimized for resource-constrained environments
"""

demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(
        type="messages",
        scale=1,
        allow_tags=["image"],
        bubble_full_width=False,
        height=400,  # Fixed height to save memory
    ),
    textbox=gr.MultimodalTextbox(
        file_types=["image", ".mp4"],
        file_count="multiple",
        autofocus=True,
        placeholder="Upload images/video and ask questions...",
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System Prompt",
            value="You are a medical AI assistant. Provide concise, accurate analysis.",
            lines=2,
        ),
        gr.Slider(
            label="Response Length",
            minimum=50,
            maximum=512,
            step=10,
            value=256,
            info="Shorter responses use less memory",
        ),
    ],
    stop_btn=None,
    title="Medical Vision Assistant",
    description=DESCRIPTION,
    cache_examples=False,
    css=custom_css,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
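
# Client-side usage sketch (assumption, not part of the original app): query the
# running app with gradio_client. This assumes the default "/chat" endpoint that
# gr.ChatInterface exposes and a locally reachable server; the file name is
# hypothetical.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   reply = client.predict(
#       {"text": "Describe this image.", "files": [handle_file("scan.png")]},
#       api_name="/chat",
#   )
#   print(reply)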