Spaces:
Sleeping
Sleeping
Update streamlit_app.py
Browse files- streamlit_app.py +59 -6
streamlit_app.py
CHANGED
|
@@ -12,7 +12,7 @@ import pandas as pd
|
|
| 12 |
import streamlit as st
|
| 13 |
from PIL import Image, ImageOps, ImageEnhance
|
| 14 |
from supabase import create_client, Client
|
| 15 |
-
|
| 16 |
# ------------------------ Page ------------------------
|
| 17 |
st.set_page_config(page_title="Care Count Inventory", layout="centered")
|
| 18 |
st.title("📦 Care Count Inventory")
|
|
@@ -33,19 +33,72 @@ sb: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
|
| 33 |
|
| 34 |
# ---- VQA model config (free serverless endpoint) ----
|
| 35 |
HF_TOKEN = get_secret("HF_TOKEN") # optional but helps cold-starts
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
"Salesforce/blip-vqa-
|
| 40 |
-
"
|
| 41 |
]
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# ------------------------ Tiny image utils ------------------------
|
| 44 |
def _to_png_bytes(img: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes (for upload to the inference API)."""
    b = io.BytesIO()
    img.save(b, format="PNG")
    return b.getvalue()
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def preprocess_for_label(img: Image.Image) -> Image.Image:
|
| 50 |
"""Lighten/contrast + gentle resize for mobile, improves label legibility."""
|
| 51 |
img = img.convert("RGB")
|
|
|
|
| 12 |
import streamlit as st
|
| 13 |
from PIL import Image, ImageOps, ImageEnhance
|
| 14 |
from supabase import create_client, Client
|
| 15 |
+
import json
|
| 16 |
# ------------------------ Page ------------------------
|
| 17 |
st.set_page_config(page_title="Care Count Inventory", layout="centered")
|
| 18 |
st.title("📦 Care Count Inventory")
|
|
|
|
| 33 |
|
| 34 |
# ---- VQA model config (free serverless endpoint) ----
|
| 35 |
HF_TOKEN = get_secret("HF_TOKEN") # optional but helps cold-starts
|
| 36 |
+
# VQA endpoints, attempted in order; a VQA_MODEL entry under Variables
# replaces the first candidate.
VQA_MODELS = [
    os.environ.get("VQA_MODEL") or "Salesforce/blip-vqa-capfilt-large",
    "Salesforce/blip-vqa-base",
    "dandelin/vilt-b32-finetuned-vqa",
]
|
| 42 |
|
| 43 |
+
# OCR fallback endpoints; an OCR_MODEL entry under Variables replaces the
# first candidate.
OCR_MODELS = [
    os.environ.get("OCR_MODEL") or "microsoft/trocr-large-printed",
    "microsoft/trocr-base-printed",
]
|
| 48 |
# ------------------------ Tiny image utils ------------------------
|
| 49 |
def _to_png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the raw bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
|
| 53 |
|
| 54 |
+
|
| 55 |
+
def _hf_post_form(model_id: str, files: dict, data: dict | None = None):
    """Issue a multipart POST to the HF Inference API for *model_id*.

    Attaches the bearer token when HF_TOKEN is configured and returns the
    raw requests.Response (caller inspects status/body).
    """
    request_headers = {"Accept": "application/json"}
    if HF_TOKEN:
        request_headers["Authorization"] = f"Bearer {HF_TOKEN}"
    endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
    return requests.post(
        endpoint,
        headers=request_headers,
        files=files,
        data=(data or {}),
        timeout=60,
    )
|
| 62 |
+
|
| 63 |
+
def vqa_http(img: Image.Image, question: str) -> tuple[str, str | None, list[str]]:
    """Ask each configured free VQA endpoint *question* about *img*, in order.

    Returns (answer, model_used, errors). On total failure the answer is ""
    and model_used is None; errors collects one short note per failed model.
    """
    png = _to_png_bytes(img)
    errors: list[str] = []

    for model_name in VQA_MODELS:
        short = model_name.split('/')[-1]
        try:
            resp = _hf_post_form(
                model_name,
                files={"image": ("image.png", png, "image/png")},
                data={"inputs": json.dumps({"question": question})},
            )

            # Serverless model still warming up — record it and try the next one.
            if resp.status_code in (503, 524):
                errors.append(f"{short} loading ({resp.status_code})")
                time.sleep(1.0)
                continue
            if resp.status_code == 404:
                errors.append(f"{short} not found (404)")
                continue
            if resp.status_code != 200:
                errors.append(f"{short} HTTP {resp.status_code}: {resp.text[:160]}")
                continue

            payload = resp.json()
            # BLIP/ViLT reply either with a list of dicts or a single dict.
            answer = ""
            if isinstance(payload, list) and payload:
                answer = payload[0].get("answer") or payload[0].get("generated_text") or ""
            elif isinstance(payload, dict):
                answer = payload.get("answer") or payload.get("generated_text") or ""
            if answer:
                return answer.strip(), model_name, errors
        except Exception as e:
            errors.append(f"{short} error: {e}")

    return "", None, errors
|
| 102 |
def preprocess_for_label(img: Image.Image) -> Image.Image:
|
| 103 |
"""Lighten/contrast + gentle resize for mobile, improves label legibility."""
|
| 104 |
img = img.convert("RGB")
|