Spaces:
Sleeping
Sleeping
| # app.py — MCP server (single-file) | |
| from mcp.server.fastmcp import FastMCP | |
| from typing import Optional, List, Tuple, Any, Dict | |
| import requests | |
| import os | |
| import gradio as gr | |
| import json | |
| import re | |
| import logging | |
| import gc | |
# --- Import OCR Engine & Prompts ---
# Each optional dependency gets its own fallback: with a single shared
# try/except, a missing `prompts` module would also replace a perfectly
# working OCR engine with the empty stub (and vice versa).
try:
    from ocr_engine import extract_text_and_conf
except ImportError:
    logging.getLogger("mcp_server").warning(
        "ocr_engine could not be imported; OCR is disabled (returns empty text)."
    )

    def extract_text_and_conf(path: str) -> Tuple[str, float]:
        """Fallback OCR stub: no text and zero confidence for any path."""
        return "", 0.0

try:
    from prompts import get_ocr_extraction_prompt, get_agent_prompt
except ImportError:
    logging.getLogger("mcp_server").warning(
        "prompts could not be imported; using pass-through prompt builders."
    )

    def get_ocr_extraction_prompt(txt: str) -> str:
        """Fallback: use the raw OCR text itself as the extraction prompt."""
        return txt

    def get_agent_prompt(h: str, u: str) -> str:
        """Fallback: ignore the history ``h`` and send only the user message ``u``."""
        return u
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp_server")

# --- Load Config ---
# Zoho credentials/endpoints and the local model id all come from config.py;
# nothing in this server works without them, so fail fast at import time.
try:
    from config import (
        CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_BASE,
        INVOICE_API_BASE, ORGANIZATION_ID, LOCAL_MODEL,
    )
except Exception as exc:  # ImportError, or anything raised while executing config.py
    # Log the traceback and keep the cause in the exit message; otherwise a
    # typo or a missing name in config.py is indistinguishable from "no file".
    logger.exception("Failed to load config")
    raise SystemExit(f"Config missing: {exc}") from exc

# NOTE(review): no tools are registered on `mcp` in the code visible here and
# the entry point only launches the Gradio UI — confirm it is still needed.
mcp = FastMCP("ZohoCRMAgent")

# --- Globals ---
# Populated lazily by init_local_model(); both stay None until a load succeeds.
LLM_PIPELINE = None
TOKENIZER = None
| # --- Evaluation / KPI Logic (hybrid OCR + data-quality confidence) --- | |
def calculate_extraction_confidence(data: dict, ocr_score: float) -> dict:
    """
    Calculate a hybrid confidence score for one document extraction.

    The final score is 80% data quality (completeness of the LLM's JSON,
    0-100 points) plus 20% OCR signal (the engine's confidence, 0-100).

    Data-quality points:
      * 10 - structure (always awarded: we have a parsed object)
      * 30 - ``total_amount`` is a plain decimal after removing whitespace,
             thousands separators and the $ / € / £ / ₹ symbols
      * 20 - ``invoice_date`` has at least 8 characters
      * 30 - ``line_items`` is a non-empty list with at least one dict item
             that has a ``name`` (only 10 if no item has one)
      * 10 - ``contact_name`` is present

    Args:
        data: Parsed extraction. Anything that is not a dict (e.g. a JSON
            list) is scored as an empty extraction instead of raising.
        ocr_score: OCR engine confidence, expected on a 0-100 scale.

    Returns:
        dict with ``score`` (int, truncated), ``ocr_score``,
        ``semantic_score``, ``rating`` and ``issues`` (list of strings).
        ``rating`` is "High" above 80, "Medium" above 50, else "Low". It is
        judged on the reported integer ``score``, so equal displayed scores
        always get the same rating.
    """
    if not isinstance(data, dict):
        data = {}
    issues = []

    # 1. Structure check (base 10 pts)
    semantic_score = 10

    # 2. Total amount (30 pts)
    amt = re.sub(r"[\s,$\u20ac\u00a3\u20b9]", "", str(data.get("total_amount", "")))
    if re.fullmatch(r"\d+(\.\d+)?", amt):
        semantic_score += 30
    else:
        issues.append("Missing/Invalid Total Amount")

    # 3. Invoice date (20 pts) - only a length sanity check, not a date parse
    date_str = str(data.get("invoice_date", ""))
    if len(date_str) >= 8:
        semantic_score += 20
    else:
        issues.append("Missing Invoice Date")

    # 4. Line items (30 pts); the LLM may emit non-dict items, so guard .get()
    items = data.get("line_items", [])
    if isinstance(items, list) and items:
        if any(isinstance(item, dict) and item.get("name") for item in items):
            semantic_score += 30
        else:
            semantic_score += 10
            issues.append("Line Items missing descriptions")
    else:
        issues.append("No Line Items detected")

    # 5. Contact / vendor name (10 pts)
    if data.get("contact_name"):
        semantic_score += 10
    else:
        issues.append("Missing Vendor Name")

    # --- HYBRID CALCULATION: 80% data quality + 20% OCR quality ---
    score = int(semantic_score * 0.8 + ocr_score * 0.2)

    if ocr_score < 60:
        issues.append(f"Low OCR Confidence ({ocr_score}%) - Check image quality")

    return {
        "score": score,
        "ocr_score": ocr_score,
        "semantic_score": semantic_score,
        "rating": "High" if score > 80 else ("Medium" if score > 50 else "Low"),
        "issues": issues,
    }
| # --- Helpers --- | |
def extract_json_safely(text: str) -> Optional[Any]:
    """
    Parse JSON from (possibly chatty) model output without ever raising.

    First tries the whole string. If that fails, it scans for each ``{`` or
    ``[`` from left to right and returns the first complete JSON value that
    decodes at that position. Any prose before or after that value is
    ignored, as are any further JSON values.

    Args:
        text: Raw model output. Non-string input is tolerated.

    Returns:
        The decoded value, or None when no JSON value can be found.
    """
    try:
        return json.loads(text)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError
        pass
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[\[{]", text):
        try:
            value, _end = decoder.raw_decode(text, opener.start())
        except ValueError:
            continue  # not a JSON value here; try the next bracket
        return value
    return None
| def _normalize_local_path_args(args: Any) -> Any: | |
| if not isinstance(args, dict): return args | |
| fp = args.get("file_path") or args.get("path") | |
| if isinstance(fp, str) and fp.startswith("/mnt/data/") and os.path.exists(fp): | |
| args["file_url"] = f"file://{fp}" | |
| return args | |
| # --- Model Loading --- | |
def init_local_model():
    """
    Lazily load the local causal LM, its tokenizer and a text-generation pipeline.

    Does nothing if ``LLM_PIPELINE`` is already set. On any failure, the error
    is logged with its traceback and both globals are left untouched. Because
    ``LLM_PIPELINE`` stays None, the next caller will retry the load.
    """
    global LLM_PIPELINE, TOKENIZER
    if LLM_PIPELINE is not None:
        return
    try:
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
        logger.info("Loading lighter model: %s...", LOCAL_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL)
        model = AutoModelForCausalLM.from_pretrained(
            LOCAL_MODEL,
            device_map="auto",
            torch_dtype="auto",
        )
        text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception:
        # Keep serving (callers report "Model not loaded."), but keep the traceback.
        logger.exception("Model load error")
        return
    # Publish both globals only after everything succeeded, so a failed load
    # never leaves a tokenizer without its pipeline.
    TOKENIZER = tokenizer
    LLM_PIPELINE = text_pipeline
    logger.info("Model loaded.")
def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
    """
    Generate a greedy (non-sampled) continuation of ``prompt`` with the local model.

    Loads the model on first use. This function never raises. Failures come
    back as text so that the chat UI can display them.

    Args:
        prompt: Full prompt string.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        ``{"text": ..., "raw": ...}``. On success, ``text`` is the generated
        continuation only (the prompt is excluded) and ``raw`` is the pipeline
        output. If the model is unavailable, ``text`` is "Model not loaded.".
        On a generation error, ``text`` is "Error: <exception>". In both
        failure cases ``raw`` is None.
    """
    if LLM_PIPELINE is None:
        init_local_model()
        if LLM_PIPELINE is None:
            return {"text": "Model not loaded.", "raw": None}
    try:
        out = LLM_PIPELINE(
            prompt,
            max_new_tokens=max_tokens,
            return_full_text=False,
            do_sample=False,
        )
        text = out[0]["generated_text"] if out else ""
        return {"text": text, "raw": out}
    except Exception as e:
        # Previously swallowed silently; log it so failures are diagnosable.
        logger.exception("Local generation failed")
        return {"text": f"Error: {e}", "raw": None}
| # --- Tools (Zoho) --- | |
def _get_valid_token_headers() -> dict:
    """
    Exchange the configured refresh token for a fresh Zoho access token.

    A new token is requested on every call; nothing is cached.

    Returns:
        ``{"Authorization": "Zoho-oauthtoken <token>"}`` on success. Returns
        ``{}`` when the request fails, the status is not 200, the body is not
        a JSON object, or it has no ``access_token``. An empty dict is what
        callers treat as "Auth Failed". A 200 response without a token is
        treated as a failure rather than yielding "Zoho-oauthtoken None".
    """
    # NOTE(review): the accounts host is hard-coded to the .in domain; confirm
    # it matches the region of API_BASE / INVOICE_API_BASE.
    try:
        r = requests.post("https://accounts.zoho.in/oauth/v2/token", params={
            "refresh_token": REFRESH_TOKEN, "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET, "grant_type": "refresh_token"
        }, timeout=10)
    except requests.RequestException:
        logger.exception("Zoho token refresh request failed")
        return {}
    if r.status_code != 200:
        logger.error("Zoho token refresh failed: HTTP %s", r.status_code)
        return {}
    try:
        body = r.json()
    except ValueError:
        body = None
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        # Deliberately not logging the body: it belongs to an auth endpoint.
        logger.error("Zoho token refresh response contained no access_token")
        return {}
    return {"Authorization": f"Zoho-oauthtoken {token}"}
def create_record(module_name: str, record_data: dict) -> str:
    """
    Insert one record into the given module under ``API_BASE``.

    Args:
        module_name: Module path segment, e.g. "Contacts".
        record_data: Field values for the single record to create.

    Returns:
        Always a string:
          * "Auth Failed" if no access token could be obtained;
          * JSON ``{"status": "success", "id": ..., "zoho_response": ...}`` on
            HTTP 200/201, where ``id`` comes from ``data[0].details.id``
            (the raw JSON body is returned instead if that shape is absent);
          * JSON ``{"status": "error", "message": ...}`` on a network failure;
          * the raw response text for any other status or a non-JSON body.
    """
    h = _get_valid_token_headers()
    if not h:
        return "Auth Failed"
    try:
        # A timeout keeps a stalled API call from hanging the chat forever.
        r = requests.post(f"{API_BASE}/{module_name}", headers=h,
                          json={"data": [record_data]}, timeout=30)
    except requests.RequestException as exc:
        logger.exception("create_record request failed")
        return json.dumps({"status": "error", "message": str(exc)})
    if r.status_code not in (200, 201):
        return r.text
    try:
        body = r.json()
    except ValueError:
        return r.text
    try:
        record_id = body.get("data", [{}])[0].get("details", {}).get("id")
    except (AttributeError, IndexError, KeyError, TypeError):
        # Unexpected response shape: hand back the body untouched.
        return json.dumps(body)
    return json.dumps({"status": "success", "id": record_id, "zoho_response": body})
def create_invoice(data: dict) -> str:
    """
    Create an invoice via ``INVOICE_API_BASE/invoices`` for ``ORGANIZATION_ID``.

    Args:
        data: Invoice payload sent as the JSON body.

    Returns:
        Always a string: "Auth Failed" if no access token could be obtained;
        the re-serialized JSON body on HTTP 200/201; JSON
        ``{"status": "error", "message": ...}`` on a network failure;
        otherwise the raw response text (also used for a non-JSON 2xx body).
    """
    h = _get_valid_token_headers()
    if not h:
        return "Auth Failed"
    try:
        r = requests.post(f"{INVOICE_API_BASE}/invoices", headers=h,
                          params={"organization_id": ORGANIZATION_ID}, json=data,
                          timeout=30)
    except requests.RequestException as exc:
        logger.exception("create_invoice request failed")
        return json.dumps({"status": "error", "message": str(exc)})
    if r.status_code not in (200, 201):
        return r.text
    try:
        return json.dumps(r.json())
    except ValueError:
        return r.text
def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
    """
    OCR a document, extract structured fields with the local LLM, and score them.

    Args:
        file_path: Local path of the uploaded document image.
        target_module: Currently unused; kept for interface compatibility.

    Returns:
        ``{"error": ...}`` if the file does not exist or OCR produced no text.
        Otherwise ``{"status": "success", "file", "extracted_data", "kpis"}``.
        ``extracted_data`` is the parsed JSON object, or ``{"raw": <model text>}``
        when the model output did not contain one. ``kpis`` always has the keys
        ``score``, ``ocr_score``, ``semantic_score``, ``rating`` and ``issues``,
        including on a failed extraction, so callers can format it uniformly.
    """
    if not os.path.exists(file_path):
        return {"error": f"File not found at path: {file_path}"}

    # 1. OCR: returns the text and the engine's confidence score
    raw_text, ocr_score = extract_text_and_conf(file_path)
    if not raw_text:
        return {"error": "OCR empty"}

    # 2. LLM extraction
    prompt = get_ocr_extraction_prompt(raw_text)
    res = local_llm_generate(prompt, max_tokens=300)
    data = extract_json_safely(res["text"])
    # Only a JSON object can be scored field by field; lists/scalars count as failures.
    if not isinstance(data, dict):
        data = None

    # 3. Evaluation / KPI calculation (hybrid OCR + data-quality score)
    if data:
        kpis = calculate_extraction_confidence(data, ocr_score)
    else:
        kpis = {
            "score": 0,
            "ocr_score": ocr_score,
            "semantic_score": 0,
            "rating": "Fail",
            "issues": ["Extraction Failed"],
        }

    return {
        "status": "success",
        "file": os.path.basename(file_path),
        "extracted_data": data or {"raw": res["text"]},
        "kpis": kpis,
    }
| # --- Executor --- | |
def parse_and_execute(model_text: str, history: list) -> str:
    """
    Execute the tool call(s) found in the agent model's output.

    ``model_text`` should contain either a JSON object of the form
    ``{"tool": ..., "args": {...}}`` or a JSON array of such objects. The
    supported tools are ``create_record`` and ``create_invoice``. If an invoice
    has no ``customer_id``, the id returned by the most recent successful
    ``create_record`` in the same batch is filled in.

    Args:
        model_text: Raw model output.
        history: Chat history; currently unused.

    Returns:
        One "Record: ..." / "Invoice: ..." line per executed command, plus an
        "Unknown tool: ..." line for each unsupported tool. Returns
        "No valid tool call found." when nothing usable was present.
    """
    payload = extract_json_safely(model_text)
    if not payload:
        return "No valid tool call found."
    cmds = [payload] if isinstance(payload, dict) else payload
    if not isinstance(cmds, list):
        # e.g. the model emitted a bare number or string as JSON
        return "No valid tool call found."

    results = []
    # NOTE(review): this is the id of whatever record was just created (any
    # module); confirm it is valid as an invoice customer_id.
    last_contact_id = None
    for cmd in cmds:
        if not isinstance(cmd, dict):
            continue
        tool = cmd.get("tool")
        args = cmd.get("args")
        if not isinstance(args, dict):
            args = {}  # missing / null / malformed args must not crash the batch
        args = _normalize_local_path_args(args)

        if tool == "create_record":
            res = create_record(args.get("module_name", "Contacts"), args.get("record_data", {}))
            results.append(f"Record: {res}")
            try:
                rj = json.loads(res)
            except (TypeError, ValueError):
                rj = None  # e.g. "Auth Failed" or a non-JSON error body
            # Only a real id may replace one remembered from an earlier command.
            if isinstance(rj, dict) and rj.get("id"):
                last_contact_id = rj["id"]
        elif tool == "create_invoice":
            if not args.get("customer_id") and last_contact_id:
                args["customer_id"] = last_contact_id
            res = create_invoice(args)
            results.append(f"Invoice: {res}")
        else:
            results.append(f"Unknown tool: {tool}")

    return "\n".join(results) if results else "No valid tool call found."
| # --- Chat Core --- | |
def _format_history(history: Optional[list]) -> str:
    """Render chat history as "U: ..." / "A: ..." lines for the agent prompt.

    Accepts pair-style history (``[user, assistant]`` per turn) and
    message-style history (dicts with ``role`` / ``content``); None is empty.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            prefix = "U" if turn.get("role") == "user" else "A"
            lines.append(f"{prefix}: {turn.get('content')}")
        else:
            lines.append(f"U: {turn[0]}\nA: {turn[1]}")
    return "\n".join(lines)


def _format_extraction_report(doc: dict) -> str:
    """Build the Markdown KPI report shown after a successful process_document()."""
    kpi = doc.get("kpis", {})
    rating_emoji = {"High": "🟢", "Medium": "🟡"}.get(kpi.get("rating"), "🔴")
    issues = kpi.get("issues") or []
    issues_txt = "\n".join(f"- {i}" for i in issues) if issues else "None"
    # ensure_ascii=False keeps non-ASCII vendor names and currency symbols readable.
    extracted_json = json.dumps(doc.get("extracted_data"), indent=2, ensure_ascii=False)
    return (
        f"### 📄 Extraction Complete: **{doc.get('file')}**\n"
        f"**Combined Confidence:** {rating_emoji} {kpi.get('score', 0)}/100\n"
        f"*(OCR Signal: {kpi.get('ocr_score', 'n/a')}% | "
        f"Data Quality: {kpi.get('semantic_score', 'n/a')}%)*\n\n"
        f"**Issues Detected:**\n{issues_txt}\n\n"
        f"```json\n{extracted_json}\n```\n\n"
        "Type **'Create Invoice'** to push this to Zoho."
    )


def chat_logic(message: str, file_path: str, history: list) -> str:
    """
    Produce the assistant reply for one chat turn.

    Phase 1: if ``file_path`` is given, the document is processed and a
    Markdown KPI report is returned. In this case ``message`` is ignored.
    Phase 2: otherwise the agent model is prompted with the history and the
    message. If its output contains JSON, that JSON is executed as tool
    calls; if not, the raw output is returned as the reply.
    """
    # PHASE 1: File Upload -> Extraction -> KPI Report
    if file_path:
        logger.info("Processing file: %s", file_path)
        doc = process_document(file_path)
        if doc.get("status") != "success":
            return f"OCR Failed: {doc.get('error')}"
        return _format_extraction_report(doc)

    # PHASE 2: Text Interaction
    prompt = get_agent_prompt(_format_history(history), message)
    gen = local_llm_generate(prompt, max_tokens=256)
    if extract_json_safely(gen["text"]):
        return parse_and_execute(gen["text"], history)
    return gen["text"]
| # --- UI --- | |
def chat_handler(msg, hist):
    """
    Gradio ``ChatInterface`` callback (multimodal mode).

    Args:
        msg: Normally a dict with ``text`` and ``files`` keys. A bare string
            is treated as text with no attachments.
        hist: Chat history exactly as supplied by Gradio.

    Returns:
        The reply from chat_logic(). Only the first attached file is processed.
    """
    if isinstance(msg, dict):
        txt = msg.get("text") or ""
        files = msg.get("files") or []
    else:
        txt = msg or ""
        files = []
    path = files[0] if files else None
    # NOTE(review): depending on the Gradio version a file may arrive as a path
    # string or as a dict with a "path" key; accept both — confirm for deployment.
    if isinstance(path, dict):
        path = path.get("path")
    return chat_logic(txt, path, hist)
if __name__ == "__main__":
    # Run a full garbage collection before the UI starts serving.
    gc.collect()
    # Multimodal chat: each message carries text plus optional file uploads.
    demo = gr.ChatInterface(multimodal=True, fn=chat_handler)
    # Listen on every network interface, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")