Spaces: Running on Zero
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
import tempfile
import os
import gc
import spaces  # <--- REQUIRED IMPORT FOR ZEROGPU
from transformers import (
    Sam3Model, Sam3Processor,
    Sam3TrackerModel, Sam3TrackerProcessor,
    Sam3VideoModel, Sam3VideoProcessor,
    Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
)
# --- CONFIGURATION ---
MODELS = {}
# On ZeroGPU we can hard-code "cuda": the @spaces.GPU decorator guarantees a GPU.
device = "cuda"
print("🖥️ ZeroGPU configuration active.")
def cleanup_memory():
    """Force VRAM cleanup."""
    if MODELS:
        print("🧹 Preventive memory cleanup...")
        MODELS.clear()
    gc.collect()
    torch.cuda.empty_cache()
def get_model(model_type):
    """Loads the requested model. On ZeroGPU this happens inside the decorated function."""
    if model_type in MODELS:
        return MODELS[model_type]
    # Even with 70GB of VRAM on ZeroGPU, swapping models out is good practice for stability.
    cleanup_memory()
    print(f"⏳ Loading {model_type} on H200...")
    try:
        if model_type == "sam3_image_text":
            model = Sam3Model.from_pretrained("facebook/sam3").to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_image_tracker":
            model = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
            processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_video_text":
            # On an H200 float32 would fit, but bfloat16 is still faster.
            model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_video_tracker":
            model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        print(f"✅ {model_type} loaded.")
        return MODELS[model_type]
    except Exception as e:
        print(f"❌ Loading error: {e}")
        cleanup_memory()
        raise
# --- UTILITIES (no GPU needed here) ---
def overlay_masks(image, masks, scores=None, alpha=0.5):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGBA")
    if masks is None or len(masks) == 0:
        return image
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    masks = masks.astype(np.uint8)
    if masks.ndim == 4: masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1: masks = masks[0]
    n_masks = masks.shape[0] if masks.ndim == 3 else 1
    if masks.ndim == 2:
        masks = [masks]
        n_masks = 1
    try:
        cmap = matplotlib.colormaps["rainbow"].resampled(max(n_masks, 1))
    except AttributeError:
        import matplotlib.cm as cm
        cmap = cm.get_cmap("rainbow").resampled(max(n_masks, 1))
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]
    overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    for i, mask in enumerate(masks):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        if mask_img.size != image.size:
            mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
        color = colors[i]
        color_layer = Image.new("RGBA", image.size, color + (0,))
        mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
        color_layer.putalpha(mask_alpha)
        overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
    return Image.alpha_composite(image, overlay_layer).convert("RGB")
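# How overlay_masks above works: each binary mask is resized to the image size,
# turned into an alpha channel scaled by `alpha`, filled with a distinct color
# from the "rainbow" colormap, and alpha-composited over the frame; the result
# is flattened back to RGB so it can be shown in Gradio or written to video.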
def get_first_frame(video_path):
    """Extracts the first frame of a video so the user can click on it."""
    if not video_path: return None
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if ret:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None

def draw_points_on_image(image, points):
    """Draws red points on the image as visual feedback."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Work on a copy so the original image is left untouched
    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    for pt in points:
        x, y = pt
        r = 5
        draw.ellipse((x - r, y - r, x + r, y + r), fill="red", outline="white")
    return draw_img
# --- HELPERS FOR ZEROGPU DYNAMIC DURATION ---
def compute_duration_text(video_path, text_prompt, max_frames, timeout_seconds):
    return timeout_seconds

def compute_duration_tracker(video_path, points_state, labels_state, max_frames, timeout_seconds):
    return timeout_seconds
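# These helpers are intended for ZeroGPU's dynamic-duration feature: passed as
# @spaces.GPU(duration=...), they are called with the same arguments as the
# decorated function and return the user-selected timeout, which sizes the GPU
# allocation for that request.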
# --- LOGIC WITH ZEROGPU DECORATORS ---
@spaces.GPU
def process_image_text(image, text_prompt, threshold, mask_threshold):
    if image is None or not text_prompt:
        return image, "Please provide an image and a text prompt."
    try:
        model, processor = get_model("sam3_image_text")
        inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs, threshold=threshold, mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
        final_img = overlay_masks(image, results["masks"])
        info = f"Objects found: {len(results['masks'])}\nScores: {results['scores'].cpu().numpy()}"
        return final_img, info
    except Exception as e:
        return image, f"Error: {str(e)}"
# Image tracker with multi-point support
@spaces.GPU
def process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask):
    if image is None:
        return image, [], []
    if points_state is None:
        points_state = []
        labels_state = []
    points_state.append([x, y])
    labels_state.append(1)
    try:
        model, processor = get_model("sam3_image_tracker")
        input_points = [[points_state]]
        input_labels = [[labels_state]]
        inputs = processor(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=multimask)
        masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
        masks_to_show = masks[0]
        if multimask and masks_to_show.shape[0] > 1:
            scores = outputs.iou_scores.cpu().numpy()[0, 0]
            best_idx = np.argmax(scores)
            masks_to_show = masks_to_show[best_idx:best_idx + 1]
        final_img = overlay_masks(image, masks_to_show)
        # Draw the accumulated click points on top of the result
        final_img = draw_points_on_image(final_img, points_state)
        return final_img, points_state, labels_state
    except Exception as e:
        print(f"Tracker Error: {e}")
        return image, points_state, labels_state

def process_image_tracker_wrapper(image, evt: gr.SelectData, points_state, labels_state, multimask):
    if evt is None:
        return image, points_state, labels_state
    x, y = evt.index
    return process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask)
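# Assumption about process_image_tracker_gpu above: the [[points_state]] /
# [[labels_state]] nesting follows the usual SAM-style processor convention of
# (batch, object, points, xy) for points and (batch, object, points) for labels,
# i.e. one image, one object, N refinement clicks.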
@spaces.GPU(duration=compute_duration_text)
def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
    if not video_path or not text_prompt:
        return None, "Missing video or prompt."
    try:
        model, processor = get_model("sam3_video_text")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        inference_session = processor.add_text_prompt(inference_session=inference_session, text=text_prompt)
        output_path = tempfile.mktemp(suffix=".mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for model_outputs in model.propagate_in_video_iterator(inference_session=inference_session, max_frame_num_to_track=len(frames)):
            processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
            frame_idx = model_outputs.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            if 'masks' in processed_outputs:
                masks = processed_outputs['masks']
                if masks.ndim == 4: masks = masks.squeeze(1)
                res_frame = overlay_masks(orig_frame, masks)
            else:
                res_frame = orig_frame
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Done!"
    except Exception as e:
        return None, f"Error: {str(e)}"
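# Note on process_video_text above: the whole clip (up to max_frames frames) is
# decoded into RAM before inference and re-encoded at the source fps and
# resolution, so the "Max Frames to Process" slider is what bounds memory use
# and processing time.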
# --- VIDEO TRACKER, MULTI-POINT ---
# CPU-only function that adds a point VISUALLY (without calling the GPU)
def add_point_video_preview(video_path, evt: gr.SelectData, points_state, labels_state):
    """Appends a point to the list and updates the preview image with a red dot."""
    if not video_path:
        return None, points_state, labels_state
    # Re-read the raw first frame (without points).
    # For simplicity it is reloaded on every click;
    # a possible optimization is to cache the original image in a gr.State.
    orig_frame = get_first_frame(video_path)
    if orig_frame is None:
        return None, points_state, labels_state
    orig_img = Image.fromarray(orig_frame)
    x, y = evt.index
    if points_state is None:
        points_state = []
        labels_state = []
    points_state.append([x, y])
    labels_state.append(1)
    # Draw ALL accumulated points on the original image
    preview_img = draw_points_on_image(orig_img, points_state)
    return preview_img, points_state, labels_state
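# Because add_point_video_preview never touches the model, refining the selection
# costs no ZeroGPU time; the GPU is only requested when "Start Object Tracking"
# calls process_video_tracker_gpu below.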
@spaces.GPU(duration=compute_duration_tracker)
def process_video_tracker_gpu(video_path, points_state, labels_state, max_frames, timeout_seconds):
    if not video_path or not points_state:
        return None, "Please click on the frame first."
    try:
        model, processor = get_model("sam3_video_tracker")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        # Send ALL the accumulated points as a single object prompt
        input_points = [[points_state]]  # [object=1 [points...]]
        input_labels = [[labels_state]]
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=0,
            obj_ids=1,
            input_points=input_points,
            input_labels=input_labels
        )
        output_path = tempfile.mktemp(suffix=".mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        model(inference_session=inference_session, frame_idx=0)  # Lock in the prompt on frame 0
        for sam3_out in model.propagate_in_video_iterator(inference_session):
            masks = processor.post_process_masks([sam3_out.pred_masks], original_sizes=[[height, width]], binarize=True)[0]
            frame_idx = sam3_out.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            res_frame = overlay_masks(orig_frame, masks[:, 0, :, :])
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Tracking Finished!"
    except Exception as e:
        print(f"Video Tracker Error: {e}")
        return None, f"Fatal Error: {str(e)}"
# --- GRADIO INTERFACE ---
with gr.Blocks(title="SAM3 Ultimate Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 SAM 3 : Unified Promptable Segmentation")
    gr.Markdown("This application allows you to utilize **all the powerful features of the SAM 3 model** for segmenting images and videos using text or visual prompts.")
    with gr.Tabs():
        # TAB 1 : IMAGE + TEXT
        with gr.Tab("🖼️ Image - Text Prompt"):
            gr.Markdown("### Segment objects by description\nSimply upload an image and type the name of the objects you want to detect (e.g., 'cat', 'wheel', 'person'). The model will find all instances.")
            with gr.Row():
                with gr.Column():
                    i1_input = gr.Image(type="pil", label="Input Image")
                    i1_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: cat, wheel, person")
                    with gr.Accordion("Advanced Settings", open=False):
                        i1_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Confidence Threshold")
                        i1_mask_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Mask Threshold")
                    i1_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    i1_output = gr.Image(type="pil", label="Result")
                    i1_info = gr.Textbox(label="Details", lines=2)
            i1_btn.click(process_image_text, [i1_input, i1_text, i1_thresh, i1_mask_thresh], [i1_output, i1_info])

        # TAB 2 : IMAGE + TRACKER
        with gr.Tab("🖱️ Image - Visual Tracker"):
            gr.Markdown("### Segment objects by clicking\nUpload an image and click on the object you wish to segment. You can click multiple times to refine the selection.")
            with gr.Row():
                with gr.Column():
                    i2_input = gr.Image(type="pil", label="Input Image (Click to add points)", interactive=True)
                    i2_multimask = gr.Checkbox(label="Return Multiple Masks (Handling Ambiguity)", value=False)
                    i2_clear = gr.Button("Reset Points")
                    points_state = gr.State([])
                    labels_state = gr.State([])
                with gr.Column():
                    i2_output = gr.Image(type="pil", label="Interactive Result")
            i2_input.select(process_image_tracker_wrapper, [i2_input, points_state, labels_state, i2_multimask], [i2_output, points_state, labels_state])
            i2_clear.click(lambda: (None, [], []), outputs=[i2_output, points_state, labels_state])
        # TAB 3 : VIDEO + TEXT
        with gr.Tab("🎥 Video - Text Prompt"):
            gr.Markdown("### Track objects in video by description\nUpload a video and describe what to track. The model will detect and segment all matching objects throughout the video.")
            with gr.Row():
                with gr.Column():
                    v3_input = gr.Video(label="Input Video", format="mp4")
                    v3_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
                    v3_max_frames = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
                    v3_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    v3_btn = gr.Button("Start Video Segmentation", variant="primary")
                with gr.Column():
                    v3_output = gr.Video(label="Result Video")
                    v3_status = gr.Textbox(label="Status")
            v3_btn.click(process_video_text, [v3_input, v3_text, v3_max_frames, v3_duration], [v3_output, v3_status])
        # TAB 4 : VIDEO + TRACKER
        with gr.Tab("🎯 Video - Visual Tracker"):
            gr.Markdown("### Track a specific object in video (Multi-point Support)\n1. Upload a video.\n2. Click on the object in the 'First Frame'. **You can click multiple times** to refine the selection.\n3. Click 'Start Object Tracking' when ready.")
            with gr.Row():
                with gr.Column():
                    v4_input = gr.Video(label="Input Video", format="mp4")
                    v4_frame0 = gr.Image(label="First Frame (Click to add points)", interactive=True)
                    v4_max_frames = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
                    v4_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    with gr.Row():
                        v4_btn = gr.Button("Start Object Tracking", variant="primary")
                        v4_clear = gr.Button("Reset Tracking")
                    # States holding the accumulated click points
                    v4_points_state = gr.State([])
                    v4_labels_state = gr.State([])
                with gr.Column():
                    v4_output = gr.Video(label="Result Video")
                    v4_status = gr.Textbox(label="Status")
            # --- FIX: merged upload handler ---
            # A single upload event both shows the first frame and resets the
            # point/label states, avoiding a conflict between display and reset.
            def on_video_upload(video_path):
                # 1. Grab the first frame
                frame = get_first_frame(video_path)
                # 2. Reset the states (empty points and labels)
                # Returns: image, empty points, empty labels
                return frame, [], []

            v4_input.change(on_video_upload, inputs=v4_input, outputs=[v4_frame0, v4_points_state, v4_labels_state])

            # 1. Click -> add a visual point (CPU) + update the states
            v4_frame0.select(
                add_point_video_preview,
                inputs=[v4_input, v4_points_state, v4_labels_state],
                outputs=[v4_frame0, v4_points_state, v4_labels_state]
            )
            # 2. Start button -> send the full list of points to the GPU
            v4_btn.click(process_video_tracker_gpu, [v4_input, v4_points_state, v4_labels_state, v4_max_frames, v4_duration], [v4_output, v4_status])

            # 3. Reset button -> clear the points and reload a clean first frame
            def reset_tracking_view(video_path):
                img = get_first_frame(video_path)
                return None, "", [], [], img

            v4_clear.click(reset_tracking_view, inputs=[v4_input], outputs=[v4_output, v4_status, v4_points_state, v4_labels_state, v4_frame0])
if __name__ == "__main__":
    demo.launch(share=False, debug=True)