|
|
import os |
|
|
import hashlib |
|
|
import sqlite3 |
|
|
from datetime import datetime, timezone |
|
|
from typing import List, Optional, Tuple
|
|
from pathlib import Path |
|
|
import json |
|
|
import numpy as np |
|
|
|
|
|
from flask import Flask, request, jsonify, send_from_directory, render_template, url_for |
|
|
from werkzeug.utils import secure_filename |
|
|
|
|
|
|
|
|
|
|
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip() |
|
|
|
|
|
|
|
|
client = None |
|
|
if OPENAI_API_KEY: |
|
|
try: |
|
|
from openai import OpenAI |
|
|
client = OpenAI(api_key=OPENAI_API_KEY) |
|
|
except Exception as e: |
|
|
print("Warning: OpenAI client init failed:", e) |
|
|
client = None |
|
|
|
|
|
|
|
|
FAST_WHISPER = None |
|
|
try: |
|
|
from faster_whisper import WhisperModel |
|
|
|
|
|
FAST_WHISPER = WhisperModel( |
|
|
"base", |
|
|
device="cpu", |
|
|
compute_type="int8" |
|
|
) |
|
|
except Exception as e: |
|
|
print("faster-whisper unavailable:", e) |
|
|
FAST_WHISPER = None |
|
|
|
|
|
|
|
|
import fitz |
|
|
import docx |
|
|
import mammoth |
|
|
from PIL import Image |
|
|
import pytesseract |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
|
|
|
|
|
|
AUDIO_DURATION_ENABLED = True |
|
|
try: |
|
|
from pydub import AudioSegment |
|
|
except Exception: |
|
|
AUDIO_DURATION_ENABLED = False |
|
|
|
|
|
|
|
|
|
|
|
BASE_DIR = Path(__file__).parent.resolve() |
|
|
|
|
|
|
|
|
DATA_ROOT = Path(os.getenv("DATA_DIR", "/data")).resolve() |
|
|
os.makedirs(DATA_ROOT, exist_ok=True) |
|
|
|
|
|
DB_PATH = str(DATA_ROOT / "rag_local.db") |
|
|
MEDIA_DIR = str(DATA_ROOT / "ingested_media") |
|
|
INDEX_DIR = str(DATA_ROOT / "index_store") |
|
|
os.makedirs(MEDIA_DIR, exist_ok=True) |
|
|
os.makedirs(INDEX_DIR, exist_ok=True) |
|
|
|
|
|
|
|
|
DOC_EXT = {'.pdf', '.docx', '.doc', '.txt'} |
|
|
IMG_EXT = {'.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.webp'} |
|
|
AUDIO_EXT = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.webm'} |
|
|
ALLOWED_EXT = DOC_EXT | IMG_EXT | AUDIO_EXT |
|
|
|
|
|
|
|
|
conn = sqlite3.connect(DB_PATH, check_same_thread=False) |
|
|
c = conn.cursor() |
|
|
|
|
|
c.execute('''CREATE TABLE IF NOT EXISTS documents ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
file_name TEXT, |
|
|
file_type TEXT, |
|
|
file_hash TEXT UNIQUE, |
|
|
created_at TEXT, |
|
|
orig_path TEXT |
|
|
)''') |
|
|
|
|
|
c.execute('''CREATE TABLE IF NOT EXISTS text_chunks ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
document_id INTEGER, |
|
|
chunk_text TEXT, |
|
|
chunk_meta TEXT, |
|
|
FOREIGN KEY(document_id) REFERENCES documents(id) |
|
|
)''') |
|
|
|
|
|
c.execute('''CREATE TABLE IF NOT EXISTS images ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
document_id INTEGER, |
|
|
image_path TEXT, |
|
|
page_num INTEGER, |
|
|
FOREIGN KEY(document_id) REFERENCES documents(id) |
|
|
)''') |
|
|
|
|
|
|
|
|
c.execute('''CREATE TABLE IF NOT EXISTS audios ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
document_id INTEGER, |
|
|
audio_path TEXT, |
|
|
duration_sec REAL, |
|
|
sample_rate INTEGER, |
|
|
FOREIGN KEY(document_id) REFERENCES documents(id) |
|
|
)''') |
|
|
|
|
|
conn.commit() |
|
|
|
|
|
|
|
|
def file_hash_bytes(b: bytes) -> str: |
|
|
return hashlib.sha256(b).hexdigest() |
|
|
|
|
|
def save_uploaded_file(file_storage, dest_dir=MEDIA_DIR) -> str: |
|
|
filename = secure_filename(file_storage.filename) |
|
|
dest = os.path.join(dest_dir, filename) |
|
|
count = 1 |
|
|
base, ext = os.path.splitext(filename) |
|
|
while os.path.exists(dest): |
|
|
filename = f"{base}_{count}{ext}" |
|
|
dest = os.path.join(dest_dir, filename) |
|
|
count += 1 |
|
|
file_storage.save(dest) |
|
|
return dest |
|
|
|
|
|
|
|
|
def process_pdf(file_path: str) -> Tuple[str, List[Tuple[str,int]]]: |
|
|
doc = fitz.open(file_path) |
|
|
text_pages, images = [], [] |
|
|
for page_num in range(len(doc)): |
|
|
page = doc[page_num] |
|
|
text_pages.append(page.get_text()) |
|
|
for img_index, img in enumerate(page.get_images(full=True)): |
|
|
xref = img[0] |
|
|
pix = fitz.Pixmap(doc, xref) |
|
|
img_name = f'{Path(file_path).stem}_page{page_num}_{img_index}.png' |
|
|
img_path = os.path.join(MEDIA_DIR, img_name) |
|
|
if pix.n - pix.alpha < 4: |
|
|
pix.save(img_path) |
|
|
else: |
|
|
pix0 = fitz.Pixmap(fitz.csRGB, pix) |
|
|
pix0.save(img_path) |
|
|
pix0 = None |
|
|
pix = None |
|
|
images.append((img_path, page_num)) |
|
|
return '\n'.join(text_pages), images |
|
|
|
|
|
def process_docx(file_path: str) -> Tuple[str, List[Tuple[str,int]]]: |
|
|
d = docx.Document(file_path) |
|
|
text = '\n'.join([p.text for p in d.paragraphs]) |
|
|
return text, [] |
|
|
|
|
|
def process_doc(file_path: str) -> Tuple[str, List[Tuple[str,int]]]: |
|
|
with open(file_path, 'rb') as f: |
|
|
result = mammoth.extract_raw_text(f) |
|
|
return result.value, [] |
|
|
|
|
|
def ocr_image(file_path: str) -> Tuple[str, List[Tuple[str,int]]]: |
|
|
try: |
|
|
img = Image.open(file_path) |
|
|
text = pytesseract.image_to_string(img) |
|
|
return text, [(file_path, 0)] |
|
|
except Exception as e: |
|
|
print("OCR error:", e) |
|
|
return "", [(file_path, 0)] |
|
|
|
|
|
|
|
|
def _audio_duration(file_path: str) -> float: |
|
|
if not AUDIO_DURATION_ENABLED: |
|
|
return None |
|
|
try: |
|
|
seg = AudioSegment.from_file(file_path) |
|
|
return round(len(seg) / 1000.0, 3) |
|
|
except Exception: |
|
|
return None |
|
|
|
|
|
def _transcribe_offline_faster_whisper(file_path: str) -> Tuple[str, int]: |
|
|
""" |
|
|
Returns (text, sample_rate) using faster-whisper if available. |
|
|
""" |
|
|
if FAST_WHISPER is None: |
|
|
return "", None |
|
|
try: |
|
|
segments, info = FAST_WHISPER.transcribe(file_path, beam_size=1) |
|
|
text = " ".join([s.text.strip() for s in segments if s.text]) |
|
|
sr = getattr(info, "sample_rate", None) |
|
|
return text.strip(), sr |
|
|
except Exception as e: |
|
|
print("faster-whisper failed:", e) |
|
|
return "", None |
|
|
|
|
|
def _transcribe_openai(file_path: str) -> Tuple[str, int]: |
|
|
""" |
|
|
Returns (text, sample_rate). OpenAI Whisper returns only text; sample_rate unknown. |
|
|
""" |
|
|
if client is None: |
|
|
return "", None |
|
|
try: |
|
|
with open(file_path, "rb") as af: |
|
|
|
|
|
resp = client.audio.transcriptions.create( |
|
|
model="whisper-1", |
|
|
file=af |
|
|
) |
|
|
text = getattr(resp, "text", None) or (resp.get("text") if isinstance(resp, dict) else "") |
|
|
return (text or "").strip(), None |
|
|
except Exception as e: |
|
|
print("OpenAI Whisper failed:", e) |
|
|
return "", None |
|
|
|
|
|
def process_audio(file_path: str) -> Tuple[str, float, int]: |
|
|
""" |
|
|
Returns (transcript_text, duration_sec, sample_rate) |
|
|
Tries faster-whisper (offline) first, then OpenAI Whisper. If both fail, returns empty text. |
|
|
""" |
|
|
duration = _audio_duration(file_path) |
|
|
|
|
|
text, sr = _transcribe_offline_faster_whisper(file_path) |
|
|
if not text: |
|
|
|
|
|
t2, sr2 = _transcribe_openai(file_path) |
|
|
text = t2 |
|
|
sr = sr2 |
|
|
return text, duration, (sr or None) |
|
|
|
|
|
|
|
|
def ingest_file_disk(file_path: str): |
|
|
ext = Path(file_path).suffix.lower() |
|
|
with open(file_path, 'rb') as f: |
|
|
b = f.read() |
|
|
fh = file_hash_bytes(b) |
|
|
c.execute('SELECT id FROM documents WHERE file_hash=?', (fh,)) |
|
|
if c.fetchone(): |
|
|
return {"status": "exists", "file_hash": fh} |
|
|
|
|
|
text, images = "", [] |
|
|
audio_meta = None |
|
|
|
|
|
try: |
|
|
if ext == '.pdf': |
|
|
text, images = process_pdf(file_path) |
|
|
elif ext == '.docx': |
|
|
text, images = process_docx(file_path) |
|
|
elif ext == '.doc': |
|
|
text, images = process_doc(file_path) |
|
|
elif ext in IMG_EXT: |
|
|
text, images = ocr_image(file_path) |
|
|
elif ext == '.txt': |
|
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: |
|
|
text = f.read() |
|
|
elif ext in AUDIO_EXT: |
|
|
|
|
|
t, duration, sr = process_audio(file_path) |
|
|
text = t or "" |
|
|
audio_meta = (duration, sr) |
|
|
else: |
|
|
return {"status": "unsupported", "ext": ext} |
|
|
except Exception as e: |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
now = datetime.now(timezone.utc).isoformat() |
|
|
c.execute('INSERT INTO documents (file_name, file_type, file_hash, created_at, orig_path) VALUES (?, ?, ?, ?, ?)', |
|
|
(Path(file_path).name, ext, fh, now, file_path)) |
|
|
doc_id = c.lastrowid |
|
|
conn.commit() |
|
|
|
|
|
|
|
|
if text and text.strip(): |
|
|
meta = {"source": Path(file_path).name} |
|
|
if ext in AUDIO_EXT: |
|
|
meta["modality"] = "audio_transcript" |
|
|
c.execute('INSERT INTO text_chunks (document_id, chunk_text, chunk_meta) VALUES (?, ?, ?)', |
|
|
(doc_id, text, json.dumps(meta))) |
|
|
conn.commit() |
|
|
|
|
|
|
|
|
for img_path, page_num in images: |
|
|
c.execute('INSERT INTO images (document_id, image_path, page_num) VALUES (?, ?, ?)', |
|
|
(doc_id, img_path, page_num)) |
|
|
conn.commit() |
|
|
|
|
|
|
|
|
if ext in AUDIO_EXT: |
|
|
dur, sr = (audio_meta or (None, None)) |
|
|
c.execute('INSERT INTO audios (document_id, audio_path, duration_sec, sample_rate) VALUES (?, ?, ?, ?)', |
|
|
(doc_id, file_path, dur, sr)) |
|
|
conn.commit() |
|
|
|
|
|
return {"status": "ingested", "doc_id": doc_id, "file_hash": fh, "file_type": ext} |
|
|
|
|
|
|
|
|
EMBED_MODEL = SentenceTransformer('intfloat/e5-small-v2') |
|
|
INDEX_FILE = os.path.join(INDEX_DIR, 'faiss_e5_small.index') |
|
|
ID_MAP_FILE = os.path.join(INDEX_DIR, 'id_mapping.jsonl') |
|
|
|
|
|
def _chunk_text(text: str, max_chars=1000, overlap=200) -> List[str]: |
|
|
text = text.strip() |
|
|
if not text: return [] |
|
|
chunks, start, L = [], 0, len(text) |
|
|
while start < L: |
|
|
end = min(start + max_chars, L) |
|
|
chunks.append(text[start:end].strip()) |
|
|
if end == L: break |
|
|
start = end - overlap |
|
|
return chunks |
|
|
|
|
|
def build_faiss_index(): |
|
|
c.execute('SELECT id, document_id, chunk_text, chunk_meta FROM text_chunks') |
|
|
rows = c.fetchall() |
|
|
pieces = [] |
|
|
for row in rows: |
|
|
parent_id, doc_id, text, meta = row |
|
|
meta_dict = json.loads(meta) if meta else {} |
|
|
for piece in _chunk_text(text): |
|
|
pieces.append((parent_id, doc_id, piece, meta_dict)) |
|
|
if not pieces: return {"status": "no_text"} |
|
|
|
|
|
texts = [p[2] for p in pieces] |
|
|
embeddings = EMBED_MODEL.encode(texts, show_progress_bar=True, convert_to_numpy=True, normalize_embeddings=True) |
|
|
dim = embeddings.shape[1] |
|
|
index = faiss.IndexFlatIP(dim) |
|
|
index.add(embeddings) |
|
|
faiss.write_index(index, INDEX_FILE) |
|
|
|
|
|
mapping = [] |
|
|
for i, p in enumerate(pieces): |
|
|
mapping.append({ |
|
|
"faiss_idx": i, |
|
|
"parent_chunk_id": p[0], |
|
|
"document_id": p[1], |
|
|
"text": p[2], |
|
|
"meta": p[3] |
|
|
}) |
|
|
with open(ID_MAP_FILE, 'w', encoding='utf-8') as f: |
|
|
for m in mapping: |
|
|
f.write(json.dumps(m, ensure_ascii=False) + '\n') |
|
|
return {"status": "built", "num_pieces": len(pieces)} |
|
|
|
|
|
def load_faiss_index(): |
|
|
if os.path.exists(INDEX_FILE) and os.path.exists(ID_MAP_FILE): |
|
|
index = faiss.read_index(INDEX_FILE) |
|
|
id_map = [json.loads(line) for line in open(ID_MAP_FILE, 'r', encoding='utf-8')] |
|
|
return index, id_map |
|
|
return None, None |
|
|
|
|
|
def semantic_search(query: str, top_k=5): |
|
|
index, id_map = load_faiss_index() |
|
|
if index is None: return [] |
|
|
q_emb = EMBED_MODEL.encode([query], convert_to_numpy=True, normalize_embeddings=True) |
|
|
D, I = index.search(q_emb, top_k) |
|
|
results = [] |
|
|
for score, idx in zip(D[0], I[0]): |
|
|
if idx < 0: continue |
|
|
entry = id_map[idx] |
|
|
doc_id = entry['document_id'] |
|
|
c.execute('SELECT file_name FROM documents WHERE id=?', (doc_id,)) |
|
|
r = c.fetchone() |
|
|
fn = r[0] if r else None |
|
|
results.append({ |
|
|
"faiss_idx": int(idx), |
|
|
"score": float(score), |
|
|
"document_id": int(doc_id), |
|
|
"file_name": fn, |
|
|
"text": entry['text'], |
|
|
"meta": entry.get('meta', {}) |
|
|
}) |
|
|
return results |
|
|
|
|
|
|
|
|
def generate_grounded_summary(query: str, hits: List[dict], max_tokens=300) -> str: |
|
|
if OPENAI_API_KEY and client: |
|
|
prompt_snippets = [] |
|
|
for i, h in enumerate(hits, start=1): |
|
|
src = h.get('file_name') or f"doc_{h['document_id']}" |
|
|
snippet = h.get('text', '').strip().replace('\n', ' ') |
|
|
if len(snippet) > 400: snippet = snippet[:400].rsplit(' ', 1)[0] + "…" |
|
|
prompt_snippets.append(f"[{i}] Source: {src}\nText: {snippet}") |
|
|
|
|
|
system = ( |
|
|
"You are a helpful assistant that answers user queries using only the provided source snippets. " |
|
|
"Always include numbered citations like [1], [2]. If info not found, say so clearly." |
|
|
) |
|
|
user = f"User query: {query}\n\nSources:\n" + "\n\n".join(prompt_snippets) |
|
|
|
|
|
try: |
|
|
resp = client.chat.completions.create( |
|
|
model="gpt-4o-mini", |
|
|
messages=[ |
|
|
{"role": "system", "content": system}, |
|
|
{"role": "user", "content": user} |
|
|
], |
|
|
max_tokens=max_tokens, |
|
|
temperature=0.0 |
|
|
) |
|
|
return resp.choices[0].message.content.strip() |
|
|
except Exception as e: |
|
|
print("OpenAI summarization failed:", e) |
|
|
|
|
|
|
|
|
lines = [] |
|
|
for i, h in enumerate(hits, start=1): |
|
|
snippet = h.get('text', '').strip() |
|
|
if len(snippet) > 300: snippet = snippet[:300].rsplit(' ', 1)[0] + "…" |
|
|
lines.append(f"[{i}] {snippet}") |
|
|
return "Extractive snippets (top results):\n" + "\n\n".join(lines) |
|
|
|
|
|
|
|
|
app = Flask(__name__, static_folder=str(BASE_DIR / "static"), template_folder=str(BASE_DIR / "templates")) |
|
|
app.config['MAX_CONTENT_LENGTH'] = 200 * 1024 * 1024 |
|
|
|
|
|
@app.route('/') |
|
|
def home(): |
|
|
return render_template('index.html') |
|
|
|
|
|
@app.route('/media/<path:filename>') |
|
|
def media(filename): |
|
|
return send_from_directory(MEDIA_DIR, filename) |
|
|
|
|
|
@app.route('/api/upload', methods=['POST']) |
|
|
def api_upload(): |
|
|
if 'file' not in request.files: |
|
|
return jsonify({"error": "no file part"}), 400 |
|
|
f = request.files['file'] |
|
|
if f.filename == '': |
|
|
return jsonify({"error": "no selected file"}), 400 |
|
|
ext = Path(f.filename).suffix.lower() |
|
|
if ext not in ALLOWED_EXT: |
|
|
return jsonify({"error": "unsupported file type", "ext": ext}), 400 |
|
|
saved = save_uploaded_file(f, dest_dir=MEDIA_DIR) |
|
|
result = ingest_file_disk(saved) |
|
|
return jsonify(result) |
|
|
|
|
|
@app.route('/api/build_index', methods=['POST']) |
|
|
def api_build_index(): |
|
|
res = build_faiss_index() |
|
|
return jsonify(res) |
|
|
|
|
|
def json_safe(o): |
|
|
if isinstance(o, np.integer): return int(o) |
|
|
if isinstance(o, np.floating): return float(o) |
|
|
if isinstance(o, np.ndarray): return o.tolist() |
|
|
raise TypeError(f"Object {type(o)} not serializable") |
|
|
|
|
|
@app.route('/api/query', methods=['POST']) |
|
|
def api_query(): |
|
|
data = request.json or {} |
|
|
q = (data.get('query') or "").strip() |
|
|
if not q: |
|
|
return jsonify({"error": "must provide 'query'"}), 400 |
|
|
|
|
|
|
|
|
hits = semantic_search(q, top_k=1) |
|
|
if not hits: |
|
|
|
|
|
return jsonify({"hits": [], "summary": "", "message": "no match"}), 200 |
|
|
|
|
|
h = hits[0] |
|
|
doc_id = h['document_id'] |
|
|
|
|
|
|
|
|
c.execute('SELECT file_name FROM documents WHERE id=?', (doc_id,)) |
|
|
r = c.fetchone() |
|
|
file_name = (r[0] if r else h.get('file_name')) or "" |
|
|
file_url = url_for('media', filename=file_name, _external=False) if file_name else None |
|
|
|
|
|
|
|
|
minimal_hit = { |
|
|
"document_id": doc_id, |
|
|
"file_name": file_name, |
|
|
"file_url": file_url, |
|
|
"text": h.get("text", "") |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
return jsonify({ |
|
|
"hits": [minimal_hit], |
|
|
"summary": minimal_hit["text"] |
|
|
}) |
|
|
|
|
|
@app.route('/api/list_docs', methods=['GET']) |
|
|
def api_list_docs(): |
|
|
c.execute('SELECT id, file_name, file_type, created_at FROM documents ORDER BY id DESC') |
|
|
rows = c.fetchall() |
|
|
out = [{"id": r[0], "file_name": r[1], "file_type": r[2], "created_at": r[3], |
|
|
"url": url_for('media', filename=r[1], _external=False)} for r in rows] |
|
|
return jsonify(out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
import os |
|
|
port = int(os.environ.get("PORT", "7860")) |
|
|
print(f"🚀 Starting Flask app — open http://0.0.0.0:{port}") |
|
|
app.run(host='0.0.0.0', port=port, debug=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|