focustiki committed on
Commit
c4e209e
·
1 Parent(s): 1efead4

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +59 -6
streamlit_app.py CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
12
  import streamlit as st
13
  from PIL import Image, ImageOps, ImageEnhance
14
  from supabase import create_client, Client
15
-
16
  # ------------------------ Page ------------------------
17
  st.set_page_config(page_title="Care Count Inventory", layout="centered")
18
  st.title("📦 Care Count Inventory")
@@ -33,19 +33,72 @@ sb: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
33
 
34
  # ---- VQA model config (free serverless endpoint) ----
35
  HF_TOKEN = get_secret("HF_TOKEN") # optional but helps cold-starts
36
- PRIMARY_MODEL = os.getenv("VQA_MODEL", "Salesforce/blip-vqa-capfilt-large")
37
- FALLBACK_MODELS = [
38
- PRIMARY_MODEL, # 1) what you set in Space → Variables
39
- "Salesforce/blip-vqa-capfilt-large", # 2) strong default
40
- "Salesforce/blip-vqa-base", # 3) smaller fallback
41
  ]
42
 
 
 
 
 
 
43
  # ------------------------ Tiny image utils ------------------------
44
def _to_png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the resulting byte string."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def preprocess_for_label(img: Image.Image) -> Image.Image:
50
  """Lighten/contrast + gentle resize for mobile, improves label legibility."""
51
  img = img.convert("RGB")
 
12
  import streamlit as st
13
  from PIL import Image, ImageOps, ImageEnhance
14
  from supabase import create_client, Client
15
+ import json
16
  # ------------------------ Page ------------------------
17
  st.set_page_config(page_title="Care Count Inventory", layout="centered")
18
  st.title("📦 Care Count Inventory")
 
33
 
34
  # ---- VQA model config (free serverless endpoint) ----
35
  HF_TOKEN = get_secret("HF_TOKEN") # optional but helps cold-starts
36
# Candidate VQA endpoints, tried in order. The Space variable VQA_MODEL
# (when set and non-empty) overrides the first entry.
_primary_vqa = os.getenv("VQA_MODEL") or "Salesforce/blip-vqa-capfilt-large"
VQA_MODELS = [
    _primary_vqa,
    "Salesforce/blip-vqa-base",
    "dandelin/vilt-b32-finetuned-vqa",
]

# OCR fallback chain; OCR_MODEL (when set and non-empty) overrides the first.
_primary_ocr = os.getenv("OCR_MODEL") or "microsoft/trocr-large-printed"
OCR_MODELS = [
    _primary_ocr,
    "microsoft/trocr-base-printed",
]
48
  # ------------------------ Tiny image utils ------------------------
49
def _to_png_bytes(img: Image.Image) -> bytes:
    """Serialize *img* to PNG and return the raw bytes."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()
53
 
54
+
55
def _hf_post_form(model_id: str, files: dict, data: dict | None = None):
    """Issue a multipart POST to the HF serverless Inference API for *model_id*.

    Adds the bearer token when HF_TOKEN is configured; 60-second timeout.
    Returns the raw ``requests.Response``.
    """
    endpoint = "https://api-inference.huggingface.co/models/" + model_id
    hdrs = {"Accept": "application/json"}
    if HF_TOKEN:
        hdrs["Authorization"] = f"Bearer {HF_TOKEN}"
    return requests.post(endpoint, headers=hdrs, files=files, data=(data or {}), timeout=60)
62
+
63
def vqa_http(img: Image.Image, question: str) -> tuple[str, str | None, list[str]]:
    """Ask *question* about *img* via the free HF serverless Inference API.

    Each model in ``VQA_MODELS`` is tried in order until one yields a
    non-empty answer.

    Args:
        img: Image to query (encoded to PNG before upload).
        question: Natural-language question about the image.

    Returns:
        ``(answer, model_used, errors)`` — ``answer`` is ``""`` and
        ``model_used`` is ``None`` when every model failed; ``errors``
        collects one short diagnostic per failed attempt.
    """
    img_bytes = _to_png_bytes(img)
    errors: list[str] = []

    for mid in VQA_MODELS:
        short = mid.split("/")[-1]
        # Fix: on 503/524 ("model loading") the previous code slept 1 s and
        # then abandoned the model, wasting the wait. Retry the SAME model a
        # few times before falling through to the next one.
        for _attempt in range(3):
            try:
                files = {"image": ("image.png", img_bytes, "image/png")}
                data = {"inputs": json.dumps({"question": question})}
                r = _hf_post_form(mid, files=files, data=data)
            except Exception as e:
                errors.append(f"{short} error: {e}")
                break  # transport failure — try the next model

            if r.status_code in (503, 524):
                errors.append(f"{short} loading ({r.status_code})")
                time.sleep(1.0)
                continue  # retry this model after the warm-up wait
            if r.status_code == 404:
                errors.append(f"{short} not found (404)")
                break
            if r.status_code != 200:
                errors.append(f"{short} HTTP {r.status_code}: {r.text[:160]}")
                break

            try:
                out = r.json()
            except Exception as e:
                errors.append(f"{short} error: {e}")
                break

            # BLIP-style VQA returns [{"answer": ...}]; generation endpoints
            # return [{"generated_text": ...}]; some bodies are plain dicts.
            # Guard isinstance: a 200 with a non-dict list element must not
            # blow up the extraction.
            ans = ""
            if isinstance(out, list) and out and isinstance(out[0], dict):
                ans = out[0].get("answer") or out[0].get("generated_text") or ""
            elif isinstance(out, dict):
                ans = out.get("answer") or out.get("generated_text") or ""
            if ans:
                return ans.strip(), mid, errors
            break  # 200 but no usable answer — fall through to next model

    return "", None, errors
102
  def preprocess_for_label(img: Image.Image) -> Image.Image:
103
  """Lighten/contrast + gentle resize for mobile, improves label legibility."""
104
  img = img.convert("RGB")