import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
import tempfile
import os
import gc
import spaces  # <--- REQUIRED IMPORT FOR ZEROGPU
from transformers import (
    Sam3Model, Sam3Processor,
    Sam3TrackerModel, Sam3TrackerProcessor,
    Sam3VideoModel, Sam3VideoProcessor,
    Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
)
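
# NOTE: the Sam3* classes above assume a transformers build that ships SAM 3 support;
# on older releases these imports fail with an ImportError.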

# --- CONFIGURATION ---
MODELS = {}
# On ZeroGPU we can hard-code "cuda" because the @spaces.GPU decorator guarantees a GPU.
device = "cuda"
print("🖥️ ZeroGPU configuration active.")


def cleanup_memory():
    """Force VRAM cleanup."""
    if MODELS:
        print("🧹 Preemptive memory cleanup...")
        MODELS.clear()
    gc.collect()
    torch.cuda.empty_cache()

def get_model(model_type):
    """Load the requested model. On ZeroGPU this happens inside the @spaces.GPU-decorated function."""
    if model_type in MODELS:
        return MODELS[model_type]
    # Even with ~70 GB of VRAM on ZeroGPU, swapping models in and out is good practice for stability.
    cleanup_memory()
    print(f"⏳ Loading {model_type} on H200...")
    try:
        if model_type == "sam3_image_text":
            model = Sam3Model.from_pretrained("facebook/sam3").to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_image_tracker":
            model = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
            processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_video_text":
            # The H200 could afford float32, but bfloat16 is still faster.
            model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        elif model_type == "sam3_video_tracker":
            model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)
        print(f"✅ {model_type} loaded.")
        return MODELS[model_type]
    except Exception as e:
        print(f"❌ Loading error: {e}")
        cleanup_memory()
        raise

# --- UTILITIES (no GPU needed here) ---
def overlay_masks(image, masks, scores=None, alpha=0.5):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGBA")
    if masks is None or len(masks) == 0:
        return image
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    masks = masks.astype(np.uint8)
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    n_masks = masks.shape[0] if masks.ndim == 3 else 1
    if masks.ndim == 2:
        masks = [masks]
        n_masks = 1
    try:
        cmap = matplotlib.colormaps["rainbow"].resampled(max(n_masks, 1))
    except AttributeError:
        # Older matplotlib releases have neither `matplotlib.colormaps` nor `.resampled`.
        import matplotlib.cm as cm
        cmap = cm.get_cmap("rainbow", max(n_masks, 1))
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]
    overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    for i, mask in enumerate(masks):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        if mask_img.size != image.size:
            mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
        color = colors[i]
        color_layer = Image.new("RGBA", image.size, color + (0,))
        mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
        color_layer.putalpha(mask_alpha)
        overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
    return Image.alpha_composite(image, overlay_layer).convert("RGB")
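
# Illustrative usage (not called directly here, names are just examples):
#   overlay_masks(pil_image, mask_tensor)              -> RGB PIL image with colored masks blended in
#   overlay_masks(np_image, mask_tensor, alpha=0.3)    -> same, with lighter overlays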

def get_first_frame(video_path):
    """Extract the first frame of a video so the user can click on it."""
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if ret:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None


def draw_points_on_image(image, points):
    """Draw red dots on the image as visual feedback for the clicked points."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Draw on a copy so the original image stays untouched.
    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    for pt in points:
        x, y = pt
        r = 5
        draw.ellipse((x - r, y - r, x + r, y + r), fill="red", outline="white")
    return draw_img

# --- HELPERS FOR ZEROGPU DYNAMIC DURATION ---
def compute_duration_text(video_path, text_prompt, max_frames, timeout_seconds):
    return timeout_seconds


def compute_duration_tracker(video_path, points_state, labels_state, max_frames, timeout_seconds):
    return timeout_seconds
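
# These helpers are passed as `duration=` to @spaces.GPU below; ZeroGPU calls them with the same
# arguments as the decorated function and uses the returned number of seconds as the GPU time
# budget for that specific call (here, simply the timeout value picked in the UI).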

# --- GPU-DECORATED LOGIC FOR ZEROGPU ---
@spaces.GPU
def process_image_text(image, text_prompt, threshold, mask_threshold):
    if image is None or not text_prompt:
        return image, "Please provide an image and a text prompt."
    try:
        model, processor = get_model("sam3_image_text")
        inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs, threshold=threshold, mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
        final_img = overlay_masks(image, results["masks"])
        info = f"Objects found: {len(results['masks'])}\nScores: {results['scores'].cpu().numpy()}"
        return final_img, info
    except Exception as e:
        return image, f"Error: {str(e)}"

# Image tracker with multi-point support
@spaces.GPU
def process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask):
    if image is None:
        return image, [], []
    if points_state is None:
        points_state = []
        labels_state = []
    points_state.append([x, y])
    labels_state.append(1)
    try:
        model, processor = get_model("sam3_image_tracker")
        input_points = [[points_state]]
        input_labels = [[labels_state]]
        inputs = processor(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=multimask)
        masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
        masks_to_show = masks[0]
        if multimask and masks_to_show.shape[0] > 1:
            scores = outputs.iou_scores.cpu().numpy()[0, 0]
            best_idx = np.argmax(scores)
            masks_to_show = masks_to_show[best_idx:best_idx + 1]
        final_img = overlay_masks(image, masks_to_show)
        # Draw the clicked points on top of the result
        final_img = draw_points_on_image(final_img, points_state)
        return final_img, points_state, labels_state
    except Exception as e:
        print(f"Tracker Error: {e}")
        return image, points_state, labels_state


def process_image_tracker_wrapper(image, evt: gr.SelectData, points_state, labels_state, multimask):
    if evt is None:
        return image, points_state, labels_state
    x, y = evt.index
    return process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask)
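
# The CPU-side wrapper above extracts the click coordinates from Gradio's SelectData event and
# forwards plain values to the @spaces.GPU function, keeping event handling out of the GPU call.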

@spaces.GPU(duration=compute_duration_text)
def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
    if not video_path or not text_prompt:
        return None, "Missing video or prompt."
    try:
        model, processor = get_model("sam3_video_text")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        inference_session = processor.add_text_prompt(inference_session=inference_session, text=text_prompt)
        output_path = tempfile.mktemp(suffix=".mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for model_outputs in model.propagate_in_video_iterator(inference_session=inference_session, max_frame_num_to_track=len(frames)):
            processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
            frame_idx = model_outputs.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            if 'masks' in processed_outputs:
                masks = processed_outputs['masks']
                if masks.ndim == 4:
                    masks = masks.squeeze(1)
                res_frame = overlay_masks(orig_frame, masks)
            else:
                res_frame = orig_frame
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Done!"
    except Exception as e:
        return None, f"Error: {str(e)}"
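
# Note: all selected frames are decoded into memory up front and the read loop stops at
# `max_frames`, which bounds how much work (and VRAM) a single ZeroGPU call has to cover.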

# --- MULTI-POINT VIDEO TRACKER ---
# CPU-only handler that adds a point VISUALLY (without requesting the GPU)
def add_point_video_preview(video_path, evt: gr.SelectData, points_state, labels_state):
    """Append a point to the list and refresh the preview image with a red dot."""
    if not video_path:
        return None, points_state, labels_state
    # Reload the raw first frame (without points) every time to keep things simple.
    # Possible optimization: store the original frame in a gr.State instead.
    orig_frame = get_first_frame(video_path)
    if orig_frame is None:
        return None, points_state, labels_state
    orig_img = Image.fromarray(orig_frame)
    x, y = evt.index
    if points_state is None:
        points_state = []
        labels_state = []
    points_state.append([x, y])
    labels_state.append(1)
    # Draw ALL accumulated points on the original frame
    preview_img = draw_points_on_image(orig_img, points_state)
    return preview_img, points_state, labels_state
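
# Because this preview handler is not decorated with @spaces.GPU, adding points costs no ZeroGPU
# quota; the GPU is only requested when the user presses "Start Object Tracking".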

@spaces.GPU(duration=compute_duration_tracker)
def process_video_tracker_gpu(video_path, points_state, labels_state, max_frames, timeout_seconds):
    if not video_path or not points_state:
        return None, "Please click on the frame first."
    try:
        model, processor = get_model("sam3_video_tracker")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        # Send ALL accumulated points at once
        input_points = [[points_state]]  # shape: [1 object, [points...]]
        input_labels = [[labels_state]]
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=0,
            obj_ids=1,
            input_points=input_points,
            input_labels=input_labels
        )
        output_path = tempfile.mktemp(suffix=".mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        model(inference_session=inference_session, frame_idx=0)  # Anchor the prompts on frame 0
        for sam3_out in model.propagate_in_video_iterator(inference_session):
            masks = processor.post_process_masks([sam3_out.pred_masks], original_sizes=[[height, width]], binarize=True)[0]
            frame_idx = sam3_out.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            res_frame = overlay_masks(orig_frame, masks[:, 0, :, :])
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Tracking Finished!"
    except Exception as e:
        print(f"Video Tracker Error: {e}")
        return None, f"Fatal Error: {str(e)}"

# --- GRADIO INTERFACE ---
with gr.Blocks(title="SAM3 Ultimate Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 SAM 3: Unified Promptable Segmentation")
    gr.Markdown("This application allows you to utilize **all the powerful features of the SAM 3 model** for segmenting images and videos using text or visual prompts.")
    with gr.Tabs():
        # TAB 1: IMAGE + TEXT
        with gr.Tab("🖼️ Image - Text Prompt"):
            gr.Markdown("### Segment objects by description\nSimply upload an image and type the name of the objects you want to detect (e.g., 'cat', 'wheel', 'person'). The model will find all instances.")
            with gr.Row():
                with gr.Column():
                    i1_input = gr.Image(type="pil", label="Input Image")
                    i1_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: cat, wheel, person")
                    with gr.Accordion("Advanced Settings", open=False):
                        i1_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Confidence Threshold")
                        i1_mask_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Mask Threshold")
                    i1_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    i1_output = gr.Image(type="pil", label="Result")
                    i1_info = gr.Textbox(label="Details", lines=2)
            i1_btn.click(process_image_text, [i1_input, i1_text, i1_thresh, i1_mask_thresh], [i1_output, i1_info])
        # TAB 2: IMAGE + VISUAL TRACKER
        with gr.Tab("🖱️ Image - Visual Tracker"):
            gr.Markdown("### Segment objects by clicking\nUpload an image and click on the object you wish to segment. You can click multiple times to refine the selection.")
            with gr.Row():
                with gr.Column():
                    i2_input = gr.Image(type="pil", label="Input Image (Click to add points)", interactive=True)
                    i2_multimask = gr.Checkbox(label="Return Multiple Masks (Handling Ambiguity)", value=False)
                    i2_clear = gr.Button("Reset Points")
                    points_state = gr.State([])
                    labels_state = gr.State([])
                with gr.Column():
                    i2_output = gr.Image(type="pil", label="Interactive Result")
            i2_input.select(process_image_tracker_wrapper, [i2_input, points_state, labels_state, i2_multimask], [i2_output, points_state, labels_state])
            i2_clear.click(lambda: (None, [], []), outputs=[i2_output, points_state, labels_state])
        # TAB 3: VIDEO + TEXT
        with gr.Tab("🎥 Video - Text Prompt"):
            gr.Markdown("### Track objects in video by description\nUpload a video and describe what to track. The model will detect and segment all matching objects throughout the video.")
            with gr.Row():
                with gr.Column():
                    v3_input = gr.Video(label="Input Video", format="mp4")
                    v3_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
                    v3_max_frames = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
                    v3_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    v3_btn = gr.Button("Start Video Segmentation", variant="primary")
                with gr.Column():
                    v3_output = gr.Video(label="Result Video")
                    v3_status = gr.Textbox(label="Status")
            v3_btn.click(process_video_text, [v3_input, v3_text, v3_max_frames, v3_duration], [v3_output, v3_status])
        # TAB 4: VIDEO + VISUAL TRACKER
        with gr.Tab("🎯 Video - Visual Tracker"):
            gr.Markdown("### Track a specific object in video (Multi-point Support)\n1. Upload a video.\n2. Click on the object in the 'First Frame'. **You can click multiple times** to refine the selection.\n3. Click 'Start Object Tracking' when ready.")
            with gr.Row():
                with gr.Column():
                    v4_input = gr.Video(label="Input Video", format="mp4")
                    v4_frame0 = gr.Image(label="First Frame (Click to add points)", interactive=True)
                    v4_max_frames = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
                    v4_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    with gr.Row():
                        v4_btn = gr.Button("Start Object Tracking", variant="primary")
                        v4_clear = gr.Button("Reset Tracking")
                    # States holding the accumulated points and labels
                    v4_points_state = gr.State([])
                    v4_labels_state = gr.State([])
                with gr.Column():
                    v4_output = gr.Video(label="Result Video")
                    v4_status = gr.Textbox(label="Status")

            # --- FIX ---
            # The upload handler both shows the first frame and resets the point states,
            # merging the two events to avoid a display-vs-reset conflict.
            def on_video_upload(video_path):
                # 1. Grab the first frame
                frame = get_first_frame(video_path)
                # 2. Reset the states (empty points and labels)
                # Returns: frame, empty points, empty labels
                return frame, [], []

            v4_input.change(on_video_upload, inputs=v4_input, outputs=[v4_frame0, v4_points_state, v4_labels_state])

            # 1. Click -> add a visual point (CPU) and update the states
            v4_frame0.select(
                add_point_video_preview,
                inputs=[v4_input, v4_points_state, v4_labels_state],
                outputs=[v4_frame0, v4_points_state, v4_labels_state]
            )
            # 2. Start button -> send the full list of points to the GPU
            v4_btn.click(process_video_tracker_gpu, [v4_input, v4_points_state, v4_labels_state, v4_max_frames, v4_duration], [v4_output, v4_status])

            # 3. Reset button -> clear the points and reload the clean first frame
            def reset_tracking_view(video_path):
                img = get_first_frame(video_path)
                return None, "", [], [], img

            v4_clear.click(reset_tracking_view, inputs=[v4_input], outputs=[v4_output, v4_status, v4_points_state, v4_labels_state, v4_frame0])

if __name__ == "__main__":
    demo.launch(share=False, debug=True)