hysts (HF Staff) committed
Commit 0551705 · Parent(s): 2ca03b4
Files changed (1):
  1. app.py +31 -46
app.py CHANGED
@@ -1,6 +1,5 @@
 import colorsys
 import gc
-from typing import Optional
 
 import cv2
 import gradio as gr
@@ -10,24 +9,15 @@ from gradio.themes import Soft
 from PIL import Image, ImageDraw, ImageFont
 from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
 
-
-def get_device_and_dtype() -> tuple[str, torch.dtype]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = torch.bfloat16
-    return device, dtype
-
-
-_GLOBAL_DEVICE, _GLOBAL_DTYPE = get_device_and_dtype()
-_GLOBAL_MODEL_REPO_ID = "facebook/sam3"
-
-_GLOBAL_TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(
-    _GLOBAL_MODEL_REPO_ID, torch_dtype=_GLOBAL_DTYPE, device_map=_GLOBAL_DEVICE
-).eval()
-_GLOBAL_TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID)
-
-_GLOBAL_TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(_GLOBAL_MODEL_REPO_ID)
-_GLOBAL_TEXT_VIDEO_MODEL = _GLOBAL_TEXT_VIDEO_MODEL.to(_GLOBAL_DEVICE, dtype=_GLOBAL_DTYPE).eval()
-_GLOBAL_TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID)
+MODEL_ID = "facebook/sam3"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16
+
+TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE, device_map=DEVICE).eval()
+TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)
+
+TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
+TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)
 print("Models loaded successfully!")
@@ -149,9 +139,6 @@ def init_video_session(
     GLOBAL_STATE.inference_session = None
     GLOBAL_STATE.active_tab = active_tab
 
-    device = _GLOBAL_DEVICE
-    dtype = _GLOBAL_DTYPE
-
     video_path: str | None = None
     if isinstance(video, dict):
         video_path = video.get("name") or video.get("path") or video.get("data")
@@ -182,23 +169,23 @@ def init_video_session(
     raw_video = [np.array(frame) for frame in frames]
 
     if active_tab == "text":
-        processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+        processor = TEXT_VIDEO_PROCESSOR
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=frames,
-            inference_device=device,
+            inference_device=DEVICE,
             processing_device="cpu",
             video_storage_device="cpu",
-            dtype=dtype,
+            dtype=DTYPE,
         )
     else:
-        processor = _GLOBAL_TRACKER_PROCESSOR
+        processor = TRACKER_PROCESSOR
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=raw_video,
-            inference_device=device,
+            inference_device=DEVICE,
             video_storage_device="cpu",
             processing_device="cpu",
-            inference_state_device=device,
-            dtype=dtype,
+            inference_state_device=DEVICE,
+            dtype=DTYPE,
         )
 
     first_frame = frames[0]
@@ -206,12 +193,12 @@
     if active_tab == "text":
         status = (
             f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
-            f"Device: {device}, dtype: bfloat16. Ready for text prompting."
+            f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
         )
     else:
         status = (
             f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
-            f"Device: {device}, dtype: bfloat16. Video session initialized."
+            f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
         )
     return GLOBAL_STATE, 0, max_idx, first_frame, status
@@ -384,8 +371,8 @@ def on_image_click(
     if state is None or state.inference_session is None:
         return img
 
-    model = _GLOBAL_TRACKER_MODEL
-    processor = _GLOBAL_TRACKER_PROCESSOR
+    model = TRACKER_MODEL
+    processor = TRACKER_PROCESSOR
 
     x = y = None
     if evt is not None:
@@ -492,8 +479,8 @@ def on_text_prompt(
     if state is None or state.inference_session is None:
         return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
 
-    model = _GLOBAL_TEXT_VIDEO_MODEL
-    processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+    model = TEXT_VIDEO_MODEL
+    processor = TEXT_VIDEO_PROCESSOR
 
     if not text_prompt or not text_prompt.strip():
         active_prompts = _get_active_prompts_display(state)
@@ -626,8 +613,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         yield GLOBAL_STATE, "Text video model not loaded.", gr.update()
         return
 
-    model = _GLOBAL_TEXT_VIDEO_MODEL
-    processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+    model = TEXT_VIDEO_MODEL
+    processor = TEXT_VIDEO_PROCESSOR
 
     # Collect all unique prompts from existing frame annotations
     text_prompt_to_obj_ids = {}
@@ -723,8 +710,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         yield GLOBAL_STATE, "Tracker model not loaded.", gr.update()
         return
 
-    model = _GLOBAL_TRACKER_MODEL
-    processor = _GLOBAL_TRACKER_PROCESSOR
+    model = TRACKER_MODEL
+    processor = TRACKER_PROCESSOR
 
     for sam2_video_output in model.propagate_in_video_iterator(
         inference_session=GLOBAL_STATE.inference_session
@@ -826,27 +813,27 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
 
     if GLOBAL_STATE.active_tab == "text":
         if GLOBAL_STATE.video_frames:
-            processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
+            processor = TEXT_VIDEO_PROCESSOR
             GLOBAL_STATE.inference_session = processor.init_video_session(
                 video=GLOBAL_STATE.video_frames,
-                inference_device=_GLOBAL_DEVICE,
+                inference_device=DEVICE,
                 processing_device="cpu",
                 video_storage_device="cpu",
-                dtype=_GLOBAL_DTYPE,
+                dtype=DTYPE,
             )
         elif GLOBAL_STATE.inference_session is not None and hasattr(
             GLOBAL_STATE.inference_session, "reset_inference_session"
         ):
             GLOBAL_STATE.inference_session.reset_inference_session()
     elif GLOBAL_STATE.video_frames:
-        processor = _GLOBAL_TRACKER_PROCESSOR
+        processor = TRACKER_PROCESSOR
         raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
         GLOBAL_STATE.inference_session = processor.init_video_session(
             video=raw_video,
-            inference_device=_GLOBAL_DEVICE,
+            inference_device=DEVICE,
             video_storage_device="cpu",
             processing_device="cpu",
-            dtype=_GLOBAL_DTYPE,
+            dtype=DTYPE,
         )
 
     GLOBAL_STATE.masks_by_frame.clear()
@@ -894,9 +881,7 @@ def _on_video_change_text(GLOBAL_STATE: gr.State, video):
     )
 
 
-theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")
-
-with gr.Blocks(title="SAM3", theme=theme) as demo:
+with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
     GLOBAL_STATE = gr.State(AppState())
 
     gr.Markdown(
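
For context, here is a minimal sketch of how the refactored constants fit together on the click-to-track path. It assumes only the call signatures visible in this diff; the synthetic raw_video frames stand in for the cv2-decoded video that app.py actually builds, and the Gradio wiring and mask rendering are elided:

import numpy as np
import torch
from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor

MODEL_ID = "facebook/sam3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# Flat module-level singletons, replacing the former _GLOBAL_* names.
TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE, device_map=DEVICE).eval()
TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)

# Synthetic stand-in for the uploaded video (app.py converts PIL frames via np.array).
raw_video = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]

# Same keyword arguments as the tracker branch of init_video_session()/reset_session().
session = TRACKER_PROCESSOR.init_video_session(
    video=raw_video,
    inference_device=DEVICE,
    video_storage_device="cpu",
    processing_device="cpu",
    inference_state_device=DEVICE,
    dtype=DTYPE,
)

# Propagation then iterates per frame, as in propagate_masks().
for sam2_video_output in TRACKER_MODEL.propagate_in_video_iterator(inference_session=session):
    pass  # app.py turns each per-frame output into overlay masks here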
 
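The text-prompt path is analogous; a sketch under the same assumptions, pairing TEXT_VIDEO_MODEL with TEXT_VIDEO_PROCESSOR and the keyword arguments from the text branch of the diff (note there is no inference_state_device on this branch):

import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor

MODEL_ID = "facebook/sam3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# Loaded first, then moved: mirrors the .to(DEVICE, dtype=DTYPE) chain in the diff.
TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)

# Placeholder PIL frames; the text branch passes the decoded frames directly.
frames = [Image.new("RGB", (640, 480)) for _ in range(8)]

session = TEXT_VIDEO_PROCESSOR.init_video_session(
    video=frames,
    inference_device=DEVICE,
    processing_device="cpu",
    video_storage_device="cpu",
    dtype=DTYPE,
)
# on_text_prompt() and propagate_masks() then drive TEXT_VIDEO_MODEL against this session.

Either way the change is behavior-preserving: both model/processor pairs are still loaded eagerly at import time, and the net -15 lines (+31/-46) come entirely from flattening get_device_and_dtype() and the _GLOBAL_ prefixes into plain module constants.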