# NOTE(review): the three lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# Hugging Face Spaces page residue picked up during extraction, not source code;
# commented out so the file parses as Python.
# app.py — MCP server (single-file)

# Standard library
import gc
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

# Third-party
import gradio as gr
import requests

# MCP framework
from mcp.server.fastmcp import FastMCP
# --- Import OCR Engine & Prompts ---
# Fall back to no-op stubs when the OCR/prompt modules are absent so the
# server can still start; extraction then simply yields empty results.
try:
    from ocr_engine import extract_text_from_file
    from prompts import get_ocr_extraction_prompt, get_agent_prompt
except ImportError:
    def extract_text_from_file(path):
        """Stub: no OCR backend available — return empty text."""
        return ""

    def get_ocr_extraction_prompt(txt):
        """Stub: pass the raw OCR text through unchanged."""
        return txt

    def get_agent_prompt(h, u):
        """Stub: ignore history, return the user message as-is."""
        return u

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp_server")
# --- Load Config ---
try:
    from config import (
        CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_BASE,
        INVOICE_API_BASE, ORGANIZATION_ID, LOCAL_MODEL,
    )
except Exception as exc:
    # Abort startup: every Zoho tool below needs these credentials.
    # Chain the original exception so the actual import failure is visible
    # in the traceback instead of being swallowed.
    raise SystemExit("Config missing.") from exc

mcp = FastMCP("ZohoCRMAgent")

# --- Globals ---
# Lazily populated by init_local_model(); None until the model is loaded.
LLM_PIPELINE = None
TOKENIZER = None
# --- Helpers ---
def extract_json_safely(text: str) -> Optional[Any]:
    """Best-effort JSON extraction from raw LLM output.

    Tries to parse *text* directly; failing that, greedily grabs the first
    ``{...}`` or ``[...]`` span and parses it. Returns None when no valid
    JSON can be recovered — never raises (the original version re-raised
    when the matched brace span was itself invalid JSON).
    """
    try:
        return json.loads(text)
    except (json.JSONDecodeError, ValueError):
        match = re.search(r'(\{.*\}|\[.*\])', text, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group(0))
        except (json.JSONDecodeError, ValueError):
            # Matched a brace span that still isn't valid JSON.
            return None
| def _normalize_local_path_args(args: Any) -> Any: | |
| if not isinstance(args, dict): return args | |
| fp = args.get("file_path") or args.get("path") | |
| if isinstance(fp, str) and fp.startswith("/mnt/data/") and os.path.exists(fp): | |
| args["file_url"] = f"file://{fp}" | |
| return args | |
# --- Model Loading ---
def init_local_model():
    """Lazily load the local HF text-generation pipeline (idempotent).

    Populates the LLM_PIPELINE / TOKENIZER module globals; on any failure
    the error is logged and the globals are left as None.
    """
    global LLM_PIPELINE, TOKENIZER
    if LLM_PIPELINE is not None:
        return  # already initialised — nothing to do
    try:
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

        logger.info(f"Loading lighter model: {LOCAL_MODEL}...")
        TOKENIZER = AutoTokenizer.from_pretrained(LOCAL_MODEL)
        causal_lm = AutoModelForCausalLM.from_pretrained(
            LOCAL_MODEL,
            device_map="auto",
            torch_dtype="auto",
        )
        LLM_PIPELINE = pipeline(
            "text-generation", model=causal_lm, tokenizer=TOKENIZER
        )
        logger.info("Model loaded.")
    except Exception as e:
        logger.error(f"Model load error: {e}")
def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
    """Run greedy (deterministic) generation on the local pipeline.

    Returns ``{"text": generated_text, "raw": pipeline_output}``. When the
    model cannot be loaded or generation fails, ``"text"`` carries an error
    message and ``"raw"`` is None.
    """
    if LLM_PIPELINE is None:
        init_local_model()  # lazy one-time load
        if LLM_PIPELINE is None:
            return {"text": "Model not loaded.", "raw": None}
    try:
        outputs = LLM_PIPELINE(
            prompt,
            max_new_tokens=max_tokens,
            return_full_text=False,
            do_sample=False,
        )
        generated = outputs[0]["generated_text"] if outputs else ""
        return {"text": generated, "raw": outputs}
    except Exception as e:
        return {"text": f"Error: {e}", "raw": None}
# --- Tools (Zoho) ---
def _get_valid_token_headers() -> dict:
    """Exchange the configured refresh token for a Zoho access token.

    Returns an ``Authorization`` header dict on success, or ``{}`` when the
    refresh fails (non-200 response, timeout, or network error). Previously
    a timeout/connection error would propagate and crash the caller.
    """
    try:
        r = requests.post(
            "https://accounts.zoho.in/oauth/v2/token",
            params={
                "refresh_token": REFRESH_TOKEN,
                "client_id": CLIENT_ID,
                "client_secret": CLIENT_SECRET,
                "grant_type": "refresh_token",
            },
            timeout=10,
        )
    except requests.RequestException as e:
        # Treat network failure the same as an auth failure: empty headers.
        logger.error(f"Token refresh failed: {e}")
        return {}
    if r.status_code == 200:
        return {"Authorization": f"Zoho-oauthtoken {r.json().get('access_token')}"}
    return {}
def create_record(module_name: str, record_data: dict) -> str:
    """Create a CRM record in *module_name* via the Zoho CRM v-API.

    Returns a JSON string ``{"status": "success", "id": ..., "zoho_response": ...}``
    on success, the raw response body for other JSON payloads, the raw
    response text on HTTP failure, or ``"Auth Failed"`` when no token could
    be obtained.
    """
    headers = _get_valid_token_headers()
    if not headers:
        return "Auth Failed"
    r = requests.post(
        f"{API_BASE}/{module_name}",
        headers=headers,
        json={"data": [record_data]},
        timeout=30,  # was missing — an unbounded request could hang the agent
    )
    if r.status_code in (200, 201):
        try:
            details = r.json().get("data", [{}])[0].get("details", {})
            return json.dumps(
                {"status": "success", "id": details.get("id"), "zoho_response": r.json()}
            )
        except (ValueError, KeyError, IndexError, AttributeError):
            # Unexpected response shape — fall back to the raw JSON body.
            return json.dumps(r.json())
    return r.text
def create_invoice(data: dict) -> str:
    """Create an invoice in Zoho Books for the configured organization.

    Returns the JSON response as a string on success (200/201), the raw
    response text on failure, or ``"Auth Failed"`` when no token could be
    obtained.
    """
    headers = _get_valid_token_headers()
    if not headers:
        return "Auth Failed"
    r = requests.post(
        f"{INVOICE_API_BASE}/invoices",
        headers=headers,
        params={"organization_id": ORGANIZATION_ID},
        json=data,
        timeout=30,  # was missing — avoid hanging indefinitely on a stalled API
    )
    return json.dumps(r.json()) if r.status_code in (200, 201) else r.text
def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
    """OCR a document and extract structured fields with the local LLM.

    Returns ``{"status": "success", "file": ..., "extracted_data": ...}`` or
    an ``{"error": ...}`` dict. ``target_module`` is currently unused; it is
    kept for interface compatibility with callers.
    """
    if not os.path.exists(file_path):
        return {"error": f"File not found at path: {file_path}"}

    # 1. OCR
    raw_text = extract_text_from_file(file_path)
    if not raw_text:
        return {"error": "OCR empty"}

    # 2. LLM Extraction
    extraction_prompt = get_ocr_extraction_prompt(raw_text)
    generation = local_llm_generate(extraction_prompt, max_tokens=300)
    parsed = extract_json_safely(generation["text"])
    return {
        "status": "success",
        "file": os.path.basename(file_path),
        # Fall back to the raw model text when no JSON could be recovered.
        "extracted_data": parsed if parsed else {"raw": generation["text"]},
    }
# --- Executor ---
def parse_and_execute(model_text: str, history: list) -> str:
    """Parse tool-call JSON out of *model_text* and run the matching tools.

    Accepts a single ``{"tool": ..., "args": ...}`` object or a list of
    them. A contact id returned by ``create_record`` is threaded into a
    subsequent ``create_invoice`` call as ``customer_id`` when the LLM
    omitted it. Returns one result line per executed tool.
    """
    payload = extract_json_safely(model_text)
    if not payload:
        return "No valid tool call found."
    commands = [payload] if isinstance(payload, dict) else payload

    results = []
    last_contact_id = None
    for cmd in commands:
        if not isinstance(cmd, dict):
            continue
        tool = cmd.get("tool")
        args = _normalize_local_path_args(cmd.get("args", {}))
        if tool == "create_record":
            res = create_record(args.get("module_name", "Contacts"), args.get("record_data", {}))
            results.append(f"Record: {res}")
            # Remember the new record id so a follow-up invoice can use it.
            try:
                parsed = json.loads(res)
                if isinstance(parsed, dict) and "id" in parsed:
                    last_contact_id = parsed["id"]
            except (json.JSONDecodeError, ValueError, TypeError):
                # Non-JSON result (e.g. "Auth Failed") — nothing to remember.
                pass
        elif tool == "create_invoice":
            # Auto-fill customer_id from the contact just created when the
            # LLM left it out or empty. (Previously duplicated as two
            # slightly different checks; collapsed into one equivalent test.)
            if last_contact_id and not args.get("customer_id"):
                args["customer_id"] = last_contact_id
            res = create_invoice(args)
            results.append(f"Invoice: {res}")
    return "\n".join(results)
# --- Chat Core ---
def chat_logic(message: str, file_path: str, history: list) -> str:
    """Two-phase chat core.

    Phase 1 (file attached): OCR + extraction only, no Zoho calls — the
    user must explicitly confirm before anything is pushed.
    Phase 2 (text only): run the agent prompt; execute a tool call if the
    model emits JSON, otherwise return the model's reply verbatim.
    """
    # PHASE 1: File Upload -> Extraction Only (No Zoho Auth yet)
    if file_path:
        logger.info(f"Processing file: {file_path}")
        doc = process_document(file_path)
        if doc.get("status") != "success":
            return f"OCR Failed: {doc.get('error')}"
        extracted_json = json.dumps(doc["extracted_data"], indent=2)
        # This reply lands in history; the user then has to say
        # "push"/"create" to trigger Phase 2.
        return (
            f"I extracted the following data from **{doc['file']}**:\n\n"
            f"```json\n{extracted_json}\n```\n\n"
            "Please review it. If it looks correct, type **'Create Invoice'** or **'Push to Zoho'**."
        )

    # PHASE 2: Text Interaction (Check History for JSON + Intent)
    hist_txt = "\n".join(f"U: {turn[0]}\nA: {turn[1]}" for turn in history)
    # The agent prompt inspects history for extracted JSON and waits for
    # explicit "save"/"push" keywords before emitting a tool call.
    agent_prompt = get_agent_prompt(hist_txt, message)
    generation = local_llm_generate(agent_prompt, max_tokens=256)
    logger.info(f"LLM Decision: {generation['text']}")

    if extract_json_safely(generation["text"]):
        # User confirmed -> Execute Tool (Triggers Zoho Auth)
        return parse_and_execute(generation["text"], history)
    # Just chat / clarification
    return generation["text"]
# --- UI ---
def chat_handler(msg, hist):
    """Gradio multimodal adapter: unpack ``{text, files}`` and delegate.

    Only the first attached file (if any) is forwarded to chat_logic.
    """
    text = msg.get("text", "")
    uploads = msg.get("files", [])
    first_file = uploads[0] if uploads else None
    return chat_logic(text, first_file, hist)
if __name__ == "__main__":
    # Reclaim any import-time garbage before the long-running server starts.
    gc.collect()
    # Multimodal chat UI: messages arrive as {"text": ..., "files": [...]}.
    demo = gr.ChatInterface(fn=chat_handler, multimodal=True)
    # Bind on all interfaces at 7860 (standard Hugging Face Spaces port).
    demo.launch(server_name="0.0.0.0", server_port=7860)