savvy7007 commited on
Commit
1c10d3b
·
verified ·
1 Parent(s): 09fe63b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -181
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # =========================
2
- # app.py (production-ready)
3
  # =========================
4
  import os
5
 
@@ -13,12 +13,27 @@ import cv2
13
  import tempfile
14
  import traceback
15
 
16
- # Lazy imports for heavy libs inside cached loaders to avoid early session init issues
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _has_cuda():
18
  try:
19
  import torch
20
  return torch.cuda.is_available()
21
  except Exception:
 
22
  return False
23
 
24
  # -----------------------------------
@@ -66,34 +81,50 @@ def load_models():
66
  """
67
  Load InsightFace detectors and the inswapper model once.
68
  Auto-select GPU if available, else CPU.
 
69
  """
70
- # Defer heavy imports until Streamlit session is ready
71
  import insightface
72
  from insightface.app import FaceAnalysis
73
 
74
- # Providers for ONNX Runtime (insightface uses ORT under the hood)
75
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if _has_cuda() else ["CPUExecutionProvider"]
 
76
 
77
  # Face detector/landmarks (retinaface + arcface in buffalo_l)
 
78
  app = FaceAnalysis(name="buffalo_l")
79
- # ctx_id: 0 (GPU) or -1 (CPU)
80
- ctx_id = 0 if _has_cuda() else -1
81
  app.prepare(ctx_id=ctx_id, det_size=(640, 640))
82
 
83
  # Face swapper (inswapper_128)
84
- # Let insightface download the model if not present
85
- swapper = insightface.model_zoo.get_model(
86
- "inswapper_128.onnx",
87
- download=True,
88
- download_zip=False,
89
- providers=providers
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  return app, swapper, providers, ctx_id
93
 
94
  # Initialize models
95
  with st.spinner("Loading models…"):
96
- app, swapper, providers, ctx_id = load_models()
 
 
 
 
97
 
98
  st.caption(
99
  f"Device: {'GPU (CUDA)' if ctx_id == 0 else 'CPU'} • ORT Providers: {', '.join(providers)}"
@@ -122,14 +153,15 @@ def _parse_fps_cap(original_fps, cap_choice):
122
  if not original_fps or original_fps <= 0:
123
  original_fps = 25.0
124
  if cap_choice == "Original":
125
- return original_fps, 1 # write_fps, frame_step
126
  try:
127
  tgt = float(cap_choice)
 
128
  step = max(1, int(round(original_fps / tgt)))
129
- write_fps = original_fps / step
130
  return write_fps, step
131
  except Exception:
132
- return original_fps, 1
133
 
134
  def _safe_imdecode(file_bytes):
135
  arr = np.frombuffer(file_bytes, np.uint8)
@@ -149,7 +181,12 @@ def swap_faces_in_video(
149
  progress
150
  ):
151
  # Validate source image
152
- source_faces = app.get(image_bgr)
 
 
 
 
 
153
  if not source_faces:
154
  st.error("❌ No face detected in the source image.")
155
  return None
@@ -157,7 +194,13 @@ def swap_faces_in_video(
157
  # Use the largest detected face if there are multiple
158
  source_face = max(
159
  source_faces,
160
- key=lambda f: (f.bbox[2]-f.bbox[0]) * (f.bbox[3]-f.bbox[1])
 
 
 
 
 
 
161
  )
162
 
163
  # Open video
@@ -170,7 +213,9 @@ def swap_faces_in_video(
170
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
171
  orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
172
  orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
173
- orig_fps = float(cap.get(cv2.CAP_PROP_FPS)) or 25.0
 
 
174
 
175
  # Decide processing size & FPS behavior
176
  proc_w, proc_h = _get_proc_size_choice(orig_w, orig_h, proc_res)
@@ -181,12 +226,14 @@ def swap_faces_in_video(
181
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_out:
182
  output_path = tmp_out.name
183
 
184
- # `mp4v` is widely compatible on Spaces/desktop browsers
185
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
186
  out = cv2.VideoWriter(output_path, fourcc, write_fps, (out_w, out_h))
187
  if not out.isOpened():
188
  cap.release()
189
- st.error("❌ Failed to open VideoWriter. Try a different resolution/FPS setting.")
 
 
 
190
  return None
191
 
192
  st.info(
@@ -207,161 +254,4 @@ def swap_faces_in_video(
207
  break
208
 
209
  # FPS cap by skipping frames
210
- if frame_step > 1 and (read_idx % frame_step != 0):
211
- read_idx += 1
212
- if frame_count > 0:
213
- progress.progress(min(1.0, read_idx / frame_count))
214
- continue
215
-
216
- # Resize for processing
217
- if (proc_w, proc_h) != (orig_w, orig_h):
218
- proc_frame = cv2.resize(frame, (proc_w, proc_h), interpolation=cv2.INTER_AREA)
219
- else:
220
- proc_frame = frame
221
-
222
- try:
223
- # Detect faces on processed frame
224
- target_faces = app.get(proc_frame)
225
-
226
- if target_faces:
227
- # Optionally limit faces to largest N for speed
228
- target_faces = sorted(
229
- target_faces,
230
- key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]),
231
- reverse=True
232
- )[:max_faces]
233
-
234
- # Swap into a working buffer
235
- result_frame = proc_frame.copy()
236
- for tface in target_faces:
237
- # Two-call fallback for different insightface versions
238
- try:
239
- result_frame = swapper.get(
240
- proc_frame, tface, source_face, paste_back=True
241
- )
242
- except Exception:
243
- result_frame = swapper.get(
244
- result_frame, tface, source_face, paste_back=True
245
- )
246
-
247
- # Upscale back to original if requested
248
- if keep_original_res and (proc_w, proc_h) != (orig_w, orig_h):
249
- result_frame = cv2.resize(result_frame, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
250
-
251
- out.write(result_frame)
252
-
253
- except Exception as e:
254
- # Log & write fallback frame (processed size or original size)
255
- print(f"[WARN] Frame {read_idx} failed: {e}")
256
- traceback.print_exc()
257
- fallback = proc_frame
258
- if keep_original_res and (proc_w, proc_h) != (orig_w, orig_h):
259
- fallback = cv2.resize(proc_frame, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
260
- out.write(fallback)
261
-
262
- read_idx += 1
263
- processed_frames += 1
264
-
265
- # Update progress
266
- if frame_count > 0:
267
- progress.progress(min(1.0, read_idx / frame_count))
268
- elif processed_frames % 30 == 0:
269
- # Fallback progress for unknown frame counts
270
- progress.progress(min(1.0, (processed_frames % 300) / 300.0))
271
-
272
- finally:
273
- cap.release()
274
- out.release()
275
-
276
- return output_path
277
-
278
- # -------------------------
279
- # UI: Uploads & Preview
280
- # -------------------------
281
- st.write("Upload a **source face image** and a **target video**, preview them, tweak speed options, then start swapping.")
282
-
283
- image_file = st.file_uploader("Upload Source Image", type=["jpg", "jpeg", "png"])
284
- video_file = st.file_uploader("Upload Target Video", type=["mp4", "mov", "mkv", "avi"])
285
-
286
- # Previews
287
- if image_file:
288
- st.subheader("📷 Source Image Preview")
289
- st.image(image_file, caption="Source Image", use_column_width=True)
290
-
291
- if video_file:
292
- st.subheader("🎬 Target Video Preview")
293
- st.video(video_file)
294
-
295
- # -------------------------
296
- # Run button
297
- # -------------------------
298
- if st.button("🚀 Start Face Swap"):
299
- if not image_file or not video_file:
300
- st.error("⚠️ Please upload both a source image and a target video.")
301
- else:
302
- # Read uploads safely (do not consume file pointer used by preview)
303
- try:
304
- image_bytes = image_file.getvalue()
305
- source_image = _safe_imdecode(image_bytes)
306
- if source_image is None:
307
- st.error("❌ Failed to decode source image. Please use a valid JPG/PNG.")
308
- st.stop()
309
- except Exception:
310
- st.error("❌ Failed to read the source image bytes.")
311
- st.stop()
312
-
313
- try:
314
- video_bytes = video_file.getvalue()
315
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
316
- tmp_video.write(video_bytes)
317
- tmp_video_path = tmp_video.name
318
- except Exception:
319
- st.error("❌ Failed to save the uploaded video to a temp file.")
320
- st.stop()
321
-
322
- with st.spinner("Processing video… This can take a while ⏳"):
323
- progress_bar = st.progress(0)
324
- output_video_path = swap_faces_in_video(
325
- source_image,
326
- tmp_video_path,
327
- proc_res=proc_res,
328
- fps_cap=fps_cap,
329
- keep_original_res=keep_original_res,
330
- max_faces=max_faces,
331
- progress=progress_bar
332
- )
333
-
334
- if output_video_path:
335
- st.success("✅ Face swapping completed!")
336
-
337
- st.subheader("📺 Output Video Preview")
338
- st.video(output_video_path)
339
-
340
- # Download button
341
- try:
342
- with open(output_video_path, "rb") as f:
343
- st.download_button(
344
- label="⬇️ Download Processed Video",
345
- data=f,
346
- file_name="output_swapped_video.mp4",
347
- mime="video/mp4"
348
- )
349
- except Exception:
350
- st.warning("⚠️ Could not open the output file for download.")
351
-
352
- # Cleanup temp input video; keep output so it can be downloaded
353
- try:
354
- os.remove(tmp_video_path)
355
- except Exception:
356
- pass
357
-
358
- # -------------
359
- # Diagnostics
360
- # -------------
361
- with st.expander("🩺 Diagnostics"):
362
- st.write(
363
- "- If you see **SessionInfo** errors: this app defers heavy imports via `@st.cache_resource` "
364
- "so Streamlit initializes first. If errors persist, restart the runtime.\n"
365
- "- If output is jumpy/stutters: lower **Target FPS** or choose **480p** processing.\n"
366
- "- If video fails to open: re-encode your input to **MP4 (H.264, AAC)**."
367
- )
 
1
  # =========================
2
+ # app.py (production-ready, safer)
3
  # =========================
4
  import os
5
 
 
13
  import tempfile
14
  import traceback
15
 
16
+ # -------------------------
17
+ # VERY EARLY: initialize session state
18
+ # -------------------------
19
+ # This prevents the "SessionInfo before it was initialized" glitch on some boots
20
+ for key, default in {
21
+ "uploaded_image": None,
22
+ "uploaded_video": None,
23
+ "output_video": None,
24
+ }.items():
25
+ if key not in st.session_state:
26
+ st.session_state[key] = default
27
+
28
+ # -------------------------
29
+ # GPU check (optional torch import)
30
+ # -------------------------
31
  def _has_cuda():
32
  try:
33
  import torch
34
  return torch.cuda.is_available()
35
  except Exception:
36
+ # If torch isn't installed, just say no CUDA
37
  return False
38
 
39
  # -----------------------------------
 
81
  """
82
  Load InsightFace detectors and the inswapper model once.
83
  Auto-select GPU if available, else CPU.
84
+ Be tolerant of insightface versions (providers kwarg may not exist).
85
  """
 
86
  import insightface
87
  from insightface.app import FaceAnalysis
88
 
89
+ # Desired providers for ORT
90
+ wants_cuda = _has_cuda()
91
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if wants_cuda else ["CPUExecutionProvider"]
92
 
93
  # Face detector/landmarks (retinaface + arcface in buffalo_l)
94
+ ctx_id = 0 if wants_cuda else -1
95
  app = FaceAnalysis(name="buffalo_l")
 
 
96
  app.prepare(ctx_id=ctx_id, det_size=(640, 640))
97
 
98
  # Face swapper (inswapper_128)
99
+ # Some insightface versions accept providers=..., some don't.
100
+ swapper = None
101
+ try:
102
+ swapper = insightface.model_zoo.get_model(
103
+ "inswapper_128.onnx",
104
+ download=True,
105
+ download_zip=False,
106
+ providers=providers
107
+ )
108
+ except TypeError:
109
+ # Fallback path: older insightface without providers kwarg
110
+ swapper = insightface.model_zoo.get_model(
111
+ "inswapper_128.onnx",
112
+ download=True,
113
+ download_zip=False
114
+ )
115
+ except Exception as e:
116
+ # Last resort: surface a helpful error
117
+ raise RuntimeError(f"Failed to load inswapper_128.onnx: {e}")
118
 
119
  return app, swapper, providers, ctx_id
120
 
121
  # Initialize models
122
  with st.spinner("Loading models…"):
123
+ try:
124
+ app, swapper, providers, ctx_id = load_models()
125
+ except Exception as e:
126
+ st.error("❌ Model loading failed. See logs for details.")
127
+ raise
128
 
129
  st.caption(
130
  f"Device: {'GPU (CUDA)' if ctx_id == 0 else 'CPU'} • ORT Providers: {', '.join(providers)}"
 
153
  if not original_fps or original_fps <= 0:
154
  original_fps = 25.0
155
  if cap_choice == "Original":
156
+ return max(1.0, original_fps), 1 # write_fps, frame_step
157
  try:
158
  tgt = float(cap_choice)
159
+ tgt = max(1.0, tgt)
160
  step = max(1, int(round(original_fps / tgt)))
161
+ write_fps = max(1.0, original_fps / step)
162
  return write_fps, step
163
  except Exception:
164
+ return max(1.0, original_fps), 1
165
 
166
  def _safe_imdecode(file_bytes):
167
  arr = np.frombuffer(file_bytes, np.uint8)
 
181
  progress
182
  ):
183
  # Validate source image
184
+ try:
185
+ source_faces = app.get(image_bgr)
186
+ except Exception as e:
187
+ st.error(f"❌ FaceAnalysis failed on source image: {e}")
188
+ return None
189
+
190
  if not source_faces:
191
  st.error("❌ No face detected in the source image.")
192
  return None
 
194
  # Use the largest detected face if there are multiple
195
  source_face = max(
196
  source_faces,
197
+ key=lambda f: (f.bbox[2]-f.bbox[0]) * (f.bbox[1]-f.bbox[3]) # absolute area doesn't depend on sign but keep positive
198
+ if hasattr(f, "bbox") else 0
199
+ )
200
+ # (safer area) re-compute properly
201
+ source_face = max(
202
+ source_faces,
203
+ key=lambda f: max(1, int((f.bbox[2]-f.bbox[0]) * (f.bbox[3]-f.bbox[1])))
204
  )
205
 
206
  # Open video
 
213
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
214
  orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
215
  orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
216
+ orig_fps = float(cap.get(cv2.CAP_PROP_FPS))
217
+ if orig_fps <= 0 or np.isnan(orig_fps):
218
+ orig_fps = 25.0
219
 
220
  # Decide processing size & FPS behavior
221
  proc_w, proc_h = _get_proc_size_choice(orig_w, orig_h, proc_res)
 
226
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_out:
227
  output_path = tmp_out.name
228
 
 
229
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
230
  out = cv2.VideoWriter(output_path, fourcc, write_fps, (out_w, out_h))
231
  if not out.isOpened():
232
  cap.release()
233
+ st.error(
234
+ "❌ Failed to open VideoWriter. "
235
+ "Try setting Processing Resolution to 480p or Target FPS to 24."
236
+ )
237
  return None
238
 
239
  st.info(
 
254
  break
255
 
256
  # FPS cap by skipping frames
257
+ if frame_step > 1 and (re_