Spaces: Running on Zero
Clean up
Browse files
app.py CHANGED
@@ -1,6 +1,5 @@
 import colorsys
 import gc
-from typing import Optional
 
 import cv2
 import gradio as gr
@@ -10,24 +9,15 @@ from gradio.themes import Soft
 from PIL import Image, ImageDraw, ImageFont
 from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
 
+MODEL_ID = "facebook/sam3"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16
 
-
-
-    dtype = torch.bfloat16
-    return device, dtype
+TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE, device_map=DEVICE).eval()
+TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)
 
-
-
-_GLOBAL_MODEL_REPO_ID = "facebook/sam3"
-
-_GLOBAL_TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(
-    _GLOBAL_MODEL_REPO_ID, torch_dtype=_GLOBAL_DTYPE, device_map=_GLOBAL_DEVICE
-).eval()
-_GLOBAL_TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID)
-
-_GLOBAL_TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(_GLOBAL_MODEL_REPO_ID)
-_GLOBAL_TEXT_VIDEO_MODEL = _GLOBAL_TEXT_VIDEO_MODEL.to(_GLOBAL_DEVICE, dtype=_GLOBAL_DTYPE).eval()
-_GLOBAL_TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID)
+TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
+TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)
 print("Models loaded successfully!")
 
 
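For readers following along outside the diff, the new top-of-file setup is equivalent to the standalone snippet below. This is a sketch assuming a transformers build that ships the SAM3 classes; on machines without CUDA it falls back to CPU, where bfloat16 may be slow.

```python
import torch
from transformers import (
    Sam3TrackerVideoModel,
    Sam3TrackerVideoProcessor,
    Sam3VideoModel,
    Sam3VideoProcessor,
)

MODEL_ID = "facebook/sam3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# Point-prompt tracker: weights are placed directly on DEVICE via device_map.
TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE, device_map=DEVICE).eval()
TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)

# Text-prompt video model: loaded first, then moved and cast in one step.
TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)

print("Models loaded successfully!")
```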
@@ -149,9 +139,6 @@ def init_video_session(
     GLOBAL_STATE.inference_session = None
     GLOBAL_STATE.active_tab = active_tab
 
-    device = _GLOBAL_DEVICE
-    dtype = _GLOBAL_DTYPE
-
     video_path: str | None = None
     if isinstance(video, dict):
         video_path = video.get("name") or video.get("path") or video.get("data")
@@ -182,23 +169,23 @@ def init_video_session(
     raw_video = [np.array(frame) for frame in frames]
 
     if active_tab == "text":
-        processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+        processor = TEXT_VIDEO_PROCESSOR
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=frames,
-            inference_device=device,
+            inference_device=DEVICE,
             processing_device="cpu",
             video_storage_device="cpu",
-            dtype=dtype,
+            dtype=DTYPE,
         )
     else:
-        processor = _GLOBAL_TRACKER_PROCESSOR
+        processor = TRACKER_PROCESSOR
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=raw_video,
-            inference_device=device,
+            inference_device=DEVICE,
             video_storage_device="cpu",
             processing_device="cpu",
-            inference_state_device=device,
-            dtype=dtype,
+            inference_state_device=DEVICE,
+            dtype=DTYPE,
         )
 
     first_frame = frames[0]
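The two call sites above only differ in the frame format and one extra keyword; collapsed into a helper they look roughly like the sketch below. The keyword set is taken from the app code in this diff, not from separate API docs, and `frames` is assumed to be the list of PIL frames the app already extracts.

```python
import numpy as np

def make_session(frames, use_text_prompts: bool):
    if use_text_prompts:
        # Text-prompted segmentation: PIL frames go straight to the processor.
        return TEXT_VIDEO_PROCESSOR.init_video_session(
            video=frames,
            inference_device=DEVICE,
            processing_device="cpu",
            video_storage_device="cpu",
            dtype=DTYPE,
        )
    # Point-prompted tracking: frames become numpy arrays and the inference
    # state is kept on the same device as the model.
    raw_video = [np.array(frame) for frame in frames]
    return TRACKER_PROCESSOR.init_video_session(
        video=raw_video,
        inference_device=DEVICE,
        video_storage_device="cpu",
        processing_device="cpu",
        inference_state_device=DEVICE,
        dtype=DTYPE,
    )
```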
@@ -206,12 +193,12 @@ def init_video_session(
     if active_tab == "text":
         status = (
             f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
-            f"Device: {device}, dtype: bfloat16. Ready for text prompting."
+            f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
         )
     else:
         status = (
             f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
-            f"Device: {device}, dtype: bfloat16. Video session initialized."
+            f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
         )
     return GLOBAL_STATE, 0, max_idx, first_frame, status
 
@@ -384,8 +371,8 @@ def on_image_click(
     if state is None or state.inference_session is None:
         return img
 
-    model = _GLOBAL_TRACKER_MODEL
-    processor = _GLOBAL_TRACKER_PROCESSOR
+    model = TRACKER_MODEL
+    processor = TRACKER_PROCESSOR
 
     x = y = None
     if evt is not None:
@@ -492,8 +479,8 @@ def on_text_prompt(
     if state is None or state.inference_session is None:
         return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
 
-    model = _GLOBAL_TEXT_VIDEO_MODEL
-    processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+    model = TEXT_VIDEO_MODEL
+    processor = TEXT_VIDEO_PROCESSOR
 
     if not text_prompt or not text_prompt.strip():
         active_prompts = _get_active_prompts_display(state)
@@ -626,8 +613,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         yield GLOBAL_STATE, "Text video model not loaded.", gr.update()
         return
 
-    model = _GLOBAL_TEXT_VIDEO_MODEL
-    processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+    model = TEXT_VIDEO_MODEL
+    processor = TEXT_VIDEO_PROCESSOR
 
     # Collect all unique prompts from existing frame annotations
     text_prompt_to_obj_ids = {}
@@ -723,8 +710,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         yield GLOBAL_STATE, "Tracker model not loaded.", gr.update()
         return
 
-    model = _GLOBAL_TRACKER_MODEL
-    processor = _GLOBAL_TRACKER_PROCESSOR
+    model = TRACKER_MODEL
+    processor = TRACKER_PROCESSOR
 
     for sam2_video_output in model.propagate_in_video_iterator(
         inference_session=GLOBAL_STATE.inference_session
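Both propagation paths follow the same pattern: pick the matching model/processor pair, then stream per-frame outputs from the iterator. A minimal sketch of that loop is below; the fields on each per-frame output are not visible in this diff, so the sketch only counts frames.

```python
import torch

def count_propagated_frames(session) -> int:
    # Assumes TRACKER_MODEL and a session created by TRACKER_PROCESSOR.init_video_session(...).
    frame_count = 0
    with torch.inference_mode():
        for _frame_output in TRACKER_MODEL.propagate_in_video_iterator(inference_session=session):
            frame_count += 1
    return frame_count
```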
@@ -826,27 +813,27 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
 
     if GLOBAL_STATE.active_tab == "text":
         if GLOBAL_STATE.video_frames:
-            processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+            processor = TEXT_VIDEO_PROCESSOR
             GLOBAL_STATE.inference_session = processor.init_video_session(
                 video=GLOBAL_STATE.video_frames,
-                inference_device=_GLOBAL_DEVICE,
+                inference_device=DEVICE,
                 processing_device="cpu",
                 video_storage_device="cpu",
-                dtype=_GLOBAL_DTYPE,
+                dtype=DTYPE,
             )
     elif GLOBAL_STATE.inference_session is not None and hasattr(
         GLOBAL_STATE.inference_session, "reset_inference_session"
     ):
         GLOBAL_STATE.inference_session.reset_inference_session()
     elif GLOBAL_STATE.video_frames:
-        processor = _GLOBAL_TRACKER_PROCESSOR
+        processor = TRACKER_PROCESSOR
         raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=raw_video,
-            inference_device=_GLOBAL_DEVICE,
+            inference_device=DEVICE,
             video_storage_device="cpu",
             processing_device="cpu",
-            dtype=_GLOBAL_DTYPE,
+            dtype=DTYPE,
         )
 
     GLOBAL_STATE.masks_by_frame.clear()
@@ -894,9 +881,7 @@ def _on_video_change_text(GLOBAL_STATE: gr.State, video):
     )
 
 
-
-
-with gr.Blocks(title="SAM3", theme=theme) as demo:
+with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
     GLOBAL_STATE = gr.State(AppState())
 
     gr.Markdown(
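The last hunk inlines the Soft theme instead of going through a separate `theme` variable. Stripped of the app-specific layout, the pattern is just the sketch below; the hue names are standard gradio hue presets and the Markdown placeholder is illustrative only.

```python
import gradio as gr
from gradio.themes import Soft

with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
    gr.Markdown("SAM3 video segmentation demo")

if __name__ == "__main__":
    demo.launch()
```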