hysts HF Staff commited on
Commit
dd289a8
·
1 Parent(s): c5ddb51

Add missing type annotation

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -1017,7 +1017,7 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1017
  show_progress=True,
1018
  )
1019
 
1020
- def _sync_frame_idx_pointbox(state_in: AppState, idx: int):
1021
  if state_in is not None:
1022
  state_in.current_frame_idx = int(idx)
1023
  return update_frame_display(state_in, int(idx))
@@ -1035,7 +1035,7 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1035
  show_progress=True,
1036
  )
1037
 
1038
- def _sync_frame_idx_text(state_in: AppState, idx: int):
1039
  if state_in is not None:
1040
  state_in.current_frame_idx = int(idx)
1041
  return update_frame_display(state_in, int(idx))
@@ -1046,28 +1046,33 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1046
  outputs=preview_text,
1047
  )
1048
 
1049
- def _sync_obj_id(s: AppState, oid):
1050
  if s is not None and oid is not None:
1051
  s.current_obj_id = int(oid)
1052
 
1053
- obj_id_inp.change(fn=_sync_obj_id, inputs=[app_state, obj_id_inp])
 
 
 
1054
 
1055
- def _sync_label(s: AppState, lab: str):
1056
  if s is not None and lab is not None:
1057
  s.current_label = str(lab)
1058
 
1059
- label_radio.change(fn=_sync_label, inputs=[app_state, label_radio])
 
 
 
1060
 
1061
- def _sync_prompt_type(s: AppState, val: str):
1062
  if s is not None and val is not None:
1063
  s.current_prompt_type = str(val)
1064
  s.pending_box_start = None
1065
  is_points = str(val).lower() == "points"
1066
- updates = [
1067
  gr.update(visible=is_points),
1068
  gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
1069
- ]
1070
- return updates
1071
 
1072
  prompt_type.change(
1073
  fn=_sync_prompt_type,
@@ -1081,7 +1086,7 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1081
  outputs=preview_pointbox,
1082
  )
1083
 
1084
- def _on_text_apply(state: AppState, frame_idx: int, text: str):
1085
  img, status, active_prompts = on_text_prompt(state, frame_idx, text)
1086
  return img, status, active_prompts
1087
 
@@ -1093,11 +1098,11 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1093
 
1094
  reset_prompts_btn.click(
1095
  fn=reset_prompts,
1096
- inputs=[app_state],
1097
  outputs=[app_state, preview_text, text_status, active_prompts_display],
1098
  )
1099
 
1100
- def _render_video(s: AppState):
1101
  if s is None or s.num_frames == 0:
1102
  raise gr.Error("Load a video first.")
1103
  fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
@@ -1123,12 +1128,20 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1123
  print(f"Failed to render video with cv2: {e}")
1124
  raise gr.Error(f"Failed to render video: {e}")
1125
 
1126
- render_btn_pointbox.click(_render_video, inputs=[app_state], outputs=[playback_video_pointbox])
1127
- render_btn_text.click(_render_video, inputs=[app_state], outputs=[playback_video_text])
 
 
 
 
 
 
 
 
1128
 
1129
  propagate_btn_pointbox.click(
1130
  fn=propagate_masks,
1131
- inputs=[app_state],
1132
  outputs=[app_state, propagate_status_pointbox, frame_slider_pointbox],
1133
  )
1134
 
 
1017
  show_progress=True,
1018
  )
1019
 
1020
+ def _sync_frame_idx_pointbox(state_in: AppState, idx: int) -> Image.Image:
1021
  if state_in is not None:
1022
  state_in.current_frame_idx = int(idx)
1023
  return update_frame_display(state_in, int(idx))
 
1035
  show_progress=True,
1036
  )
1037
 
1038
+ def _sync_frame_idx_text(state_in: AppState, idx: int) -> Image.Image:
1039
  if state_in is not None:
1040
  state_in.current_frame_idx = int(idx)
1041
  return update_frame_display(state_in, int(idx))
 
1046
  outputs=preview_text,
1047
  )
1048
 
1049
+ def _sync_obj_id(s: AppState, oid) -> None:
1050
  if s is not None and oid is not None:
1051
  s.current_obj_id = int(oid)
1052
 
1053
+ obj_id_inp.change(
1054
+ fn=_sync_obj_id,
1055
+ inputs=[app_state, obj_id_inp],
1056
+ )
1057
 
1058
+ def _sync_label(s: AppState, lab: str) -> None:
1059
  if s is not None and lab is not None:
1060
  s.current_label = str(lab)
1061
 
1062
+ label_radio.change(
1063
+ fn=_sync_label,
1064
+ inputs=[app_state, label_radio],
1065
+ )
1066
 
1067
+ def _sync_prompt_type(s: AppState, val: str) -> tuple[dict, dict]:
1068
  if s is not None and val is not None:
1069
  s.current_prompt_type = str(val)
1070
  s.pending_box_start = None
1071
  is_points = str(val).lower() == "points"
1072
+ return (
1073
  gr.update(visible=is_points),
1074
  gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
1075
+ )
 
1076
 
1077
  prompt_type.change(
1078
  fn=_sync_prompt_type,
 
1086
  outputs=preview_pointbox,
1087
  )
1088
 
1089
+ def _on_text_apply(state: AppState, frame_idx: int, text: str) -> tuple[Image.Image, str, str]:
1090
  img, status, active_prompts = on_text_prompt(state, frame_idx, text)
1091
  return img, status, active_prompts
1092
 
 
1098
 
1099
  reset_prompts_btn.click(
1100
  fn=reset_prompts,
1101
+ inputs=app_state,
1102
  outputs=[app_state, preview_text, text_status, active_prompts_display],
1103
  )
1104
 
1105
+ def _render_video(s: AppState) -> str:
1106
  if s is None or s.num_frames == 0:
1107
  raise gr.Error("Load a video first.")
1108
  fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
 
1128
  print(f"Failed to render video with cv2: {e}")
1129
  raise gr.Error(f"Failed to render video: {e}")
1130
 
1131
+ render_btn_pointbox.click(
1132
+ fn=_render_video,
1133
+ inputs=app_state,
1134
+ outputs=playback_video_pointbox,
1135
+ )
1136
+ render_btn_text.click(
1137
+ fn=_render_video,
1138
+ inputs=app_state,
1139
+ outputs=playback_video_text,
1140
+ )
1141
 
1142
  propagate_btn_pointbox.click(
1143
  fn=propagate_masks,
1144
+ inputs=app_state,
1145
  outputs=[app_state, propagate_status_pointbox, frame_slider_pointbox],
1146
  )
1147