# Source: Hugging Face Space (Space status when captured: Paused).
import base64
import io
import json
import logging
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import fitz  # PyMuPDF
import numpy as np
import requests
import torch
import torchvision
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from paddleocr import PaddleOCR
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel, HttpUrl
# --- Configure Logging ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
logger.info("Starting application initialization...")

# --- Configuration ---
# Location of the DocLayout-YOLO weights. Can be overridden with the MODEL_PATH
# environment variable (e.g. to point at /app/models alongside the OCR models);
# the default is the original hard-coded path.
MODEL_PATH = os.environ.get("MODEL_PATH", "/content/layout-model.pt")

# --- Global Variables ---
# Both are assigned by startup_event(); they stay None until the models load.
ocr: Optional[PaddleOCR] = None
layout_model = None
# Device handed to the layout model's predict(). PaddleOCR is configured
# separately in startup_event (use_gpu=False).
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Maps the layout model's integer class ids to human-readable labels.
id_to_names = {
    0: "title",
    1: "plain text",
    2: "abandon",
    3: "figure",
    4: "figure_caption",
    5: "table",
    6: "table_caption",
    7: "table_footnote",
    8: "isolate_formula",
    9: "formula_caption",
}

# --- FastAPI Application ---
app = FastAPI(title="Document Layout Analysis API", version="1.0.0")
# --- FastAPI Startup Event ---
@app.on_event("startup")
async def startup_event():
    """Load the OCR engine and the layout model into the module globals.

    Registered as the application's startup handler so both models are loaded
    once, before requests are served. Sets the globals ``ocr`` and
    ``layout_model``.

    Raises:
        RuntimeError: if PaddleOCR cannot be constructed, or if the
            DocLayout-YOLO weights at ``MODEL_PATH`` are missing or fail to
            load. The original exception is chained as ``__cause__``.
    """
    global ocr, layout_model
    lang = "en"

    try:
        logger.info("Initializing PaddleOCR...")
        ocr = PaddleOCR(
            use_angle_cls=True,
            lang=lang,
            use_gpu=False,
            show_log=False,
            det_model_dir=f"/app/models/det/{lang}/en_PP-OCRv3_det_infer",
            rec_model_dir=f"/app/models/rec/{lang}/en_PP-OCRv4_rec_infer",
            cls_model_dir=f"/app/models/cls/{lang}/ch_ppocr_mobile_v2.0_cls_infer",
        )
        logger.info("PaddleOCR initialized successfully")
    except Exception as e:
        logger.exception(f"Failed to initialize PaddleOCR: {e}")
        raise RuntimeError("Could not initialize PaddleOCR") from e

    try:
        logger.info(f"Loading DocLayout-YOLO model from {MODEL_PATH}...")
        if not os.path.exists(MODEL_PATH):
            logger.error(f"Model file not found at {MODEL_PATH}")
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
        # Imported lazily so the package is only needed when the app starts.
        from doclayout_yolo import YOLOv10

        layout_model = YOLOv10(MODEL_PATH)
        logger.info(f"DocLayout-YOLO model loaded successfully on device: {device}")
    except Exception as e:
        logger.exception(f"Failed to load DocLayout-YOLO model: {e}")
        raise RuntimeError("Could not load layout model") from e
# --- Pydantic Request Models ---
class URLRequest(BaseModel):
    """JSON body for analysing a document that is fetched from a URL.

    Attributes:
        url: http(s) URL of the PDF or image to download and analyse.
        resolution: Optional target width in pixels for each rendered page;
            when omitted (None) pages are analysed at their rendered size.
    """

    url: HttpUrl
    resolution: Optional[int] = None
# --- Helper Functions ---
# Matches "Table 3", "Figure 2.1", "Fig. 4", "TABLE12" (case-insensitive) at a
# word boundary and captures a dotted number such as "2.1". The number must end
# in a digit, so a sentence period after it ("Table 1. Results") is not captured.
_CAPTION_NUMBER_RE = re.compile(
    r"\b(?:table|fig(?:ure)?\.?)\s*(\d+(?:\.\d+)*)",
    re.IGNORECASE,
)


def extract_number_from_caption(caption_text: Optional[str]) -> Optional[str]:
    """Extract the table/figure number from caption text.

    Args:
        caption_text: OCR'd caption such as ``"Table 3"``, ``"Figure 2.1: ..."``
            or ``"Fig. 4 ..."``. ``None`` or ``""`` is accepted.

    Returns:
        The first number that follows "Table"/"Figure"/"Fig." (e.g. ``"2.1"``),
        or ``None`` when the text is empty or contains no such label.
    """
    if not caption_text:
        return None
    match = _CAPTION_NUMBER_RE.search(caption_text)
    return match.group(1) if match else None
def _render_page(page: fitz.Page, dpi: int, target_width: Optional[int]) -> Image.Image:
    """Rasterise ``page`` at ``dpi`` into an RGB PIL image, optionally resized.

    When ``target_width`` is truthy the image is resized to that width with
    LANCZOS resampling, keeping the page's aspect ratio.
    """
    pix = page.get_pixmap(dpi=dpi)
    # "RGB" relies on get_pixmap()'s defaults (RGB colorspace, alpha=False).
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    if target_width:
        aspect_ratio = img.height / img.width
        # Clamp to 1px: a very small target width would otherwise produce a
        # zero height, which PIL's resize() rejects.
        target_height = max(1, int(target_width * aspect_ratio))
        img = img.resize((target_width, target_height), Image.LANCZOS)
        logger.debug(f"Resized image to {target_width}x{target_height}")
    return img


def detect_layout_regions(
    page: fitz.Page,
    target_width: Optional[int] = None,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.3,
    dpi: int = 150,
    imgsz: int = 1280,
):
    """Detect document layout elements on a page with DocLayout-YOLO.

    The page is rendered at ``dpi``, optionally resized to ``target_width``,
    run through the layout model, and the detections are filtered with a
    class-agnostic non-maximum suppression pass.

    Args:
        page: PyMuPDF page to analyse.
        target_width: Optional width in pixels to resize the rendered page to
            (aspect ratio preserved). ``None`` or ``0`` keeps the rendered size.
        conf_threshold: Minimum confidence passed to the model's ``predict``.
        iou_threshold: IoU threshold for NMS; an overlapping lower-scoring box
            is dropped whatever its class.
        dpi: Rendering resolution for the page.
        imgsz: Inference image size passed to the model's ``predict``.

    Returns:
        ``(regions, img)``: ``regions`` is a list of dicts with ``"bbox"``
        (``[x0, y0, x1, y1]`` floats in ``img`` pixel coordinates), ``"type"``
        (label from ``id_to_names``, or ``"unknown"``) and ``"confidence"``;
        ``img`` is the RGB PIL image the boxes refer to.

    Raises:
        RuntimeError: if the layout model has not been loaded.
        ValueError: if ``target_width`` is negative.
    """
    if layout_model is None:
        raise RuntimeError("Layout model is not initialized.")
    if target_width is not None and target_width < 0:
        raise ValueError(f"target_width must be a positive integer, got {target_width}")

    logger.debug(f"Detecting layout regions with target_width={target_width}")
    try:
        img = _render_page(page, dpi, target_width)
        logger.debug(f"Running model prediction on image size: {img.width}x{img.height}")

        # A single image was submitted, so only the first result is relevant.
        det_res = layout_model.predict(img, imgsz=imgsz, conf=conf_threshold, device=device)[0]
        boxes = det_res.boxes.xyxy.cpu().numpy()
        classes = det_res.boxes.cls.cpu().numpy()
        scores = det_res.boxes.conf.cpu().numpy()
        logger.debug(f"Detected {len(boxes)} boxes before NMS")

        if len(boxes) == 0:
            logger.info("No objects detected")
            return [], img

        # Class-agnostic NMS: boxes are suppressed across labels, not per label.
        keep = torchvision.ops.nms(
            torch.from_numpy(boxes), torch.from_numpy(scores), iou_threshold
        ).numpy()
        boxes, scores, classes = boxes[keep], scores[keep], classes[keep]
        logger.debug(f"Detected {len(boxes)} boxes after NMS")

        detected_regions = [
            {
                "bbox": [float(coord) for coord in box],
                "type": id_to_names.get(int(cls), "unknown"),
                "confidence": float(score),
            }
            for box, score, cls in zip(boxes, scores, classes)
        ]
        logger.debug(f"Returning {len(detected_regions)} detected regions")
        return detected_regions, img
    except Exception as e:
        logger.exception(f"Error in detect_layout_regions: {e}")
        raise
def extract_text_from_bbox(img: Image.Image, bbox: list, padding: int = 5) -> str:
    """Run OCR on a padded rectangular region of ``img`` and return its text.

    Args:
        img: Source PIL image; ``bbox`` must be in this image's pixel space.
        bbox: ``[x0, y0, x1, y1]``; coordinates are truncated to ints.
        padding: Pixels added on every side before cropping, clamped to the
            image bounds.

    Returns:
        The recognised text lines joined with single spaces. Returns ``""`` if
        the padded box is empty, nothing is recognised, or OCR raises. Errors
        are logged rather than propagated because captions are best-effort.

    Raises:
        RuntimeError: if the OCR engine has not been initialised.
    """
    if ocr is None:
        raise RuntimeError("OCR model is not initialized.")
    logger.debug(f"Extracting text from bbox: {bbox}")
    try:
        x0, y0, x1, y1 = (int(coord) for coord in bbox)
        x0 = max(0, x0 - padding)
        y0 = max(0, y0 - padding)
        x1 = min(img.width, x1 + padding)
        y1 = min(img.height, y1 + padding)
        if x0 >= x1 or y0 >= y1:
            logger.debug("Invalid bbox dimensions")
            return ""

        # Normalise to 3-channel RGB so RGBA/L/P inputs work too.
        region = img.crop((x0, y0, x1, y1)).convert("RGB")
        # PaddleOCR follows OpenCV conventions for ndarray input (BGR channel
        # order), so reverse PIL's RGB channels before passing the array.
        region_np = np.ascontiguousarray(np.asarray(region)[:, :, ::-1])
        ocr_result = ocr.ocr(region_np, cls=True)
        if not ocr_result or not ocr_result[0]:
            logger.debug("No OCR results")
            return ""

        # Each entry of ocr_result[0] has the form [box, (text, score)].
        result_text = " ".join(line[1][0] for line in ocr_result[0])
        logger.debug(f"Extracted text: {result_text[:100]}...")
        return result_text
    except Exception as e:
        logger.exception(f"Error in extract_text_from_bbox: {e}")
        return ""
def _match_caption(target_bbox: list, captions: list, caption_below: bool) -> Optional[dict]:
    """Return the caption nearest to ``target_bbox`` on the expected side, or None.

    With ``caption_below=True`` only captions whose top edge is at or below the
    target's bottom edge qualify (figures). With ``False`` only captions whose
    bottom edge is at or above the target's top edge qualify (tables).
    Candidates must also overlap the target horizontally, so a caption from a
    neighbouring column is never chosen.
    """
    tx0, ty0, tx1, ty1 = target_bbox
    best_caption = None
    min_distance = float("inf")
    for caption in captions:
        cx0, cy0, cx1, cy1 = caption["bbox"]
        if min(tx1, cx1) <= max(tx0, cx0):
            continue  # no horizontal overlap
        distance = cy0 - ty1 if caption_below else ty0 - cy1
        if 0 <= distance < min_distance:
            min_distance = distance
            best_caption = caption
    return best_caption


def _pair_regions_with_captions(
    img: Image.Image,
    regions: list,
    captions: list,
    caption_below: bool,
    number_key: str,
    bbox_key: str,
) -> list:
    """Build one result entry per region, with its matched caption OCR'd.

    The entry number is read from the caption text ("Table 3" -> "3"). It falls
    back to the region's 1-based position in ``regions`` when no number is found.
    """
    entries = []
    for idx, region in enumerate(regions, start=1):
        caption = _match_caption(region["bbox"], captions, caption_below)
        caption_text = extract_text_from_bbox(img, caption["bbox"]) if caption else None
        entries.append({
            number_key: extract_number_from_caption(caption_text) or str(idx),
            bbox_key: region["bbox"],
            "caption": caption_text,
            "caption_bbox": caption["bbox"] if caption else None,
            "confidence": region["confidence"],
        })
    return entries


def _png_data_uri(img: Image.Image) -> str:
    """Encode ``img`` as a ``data:image/png;base64,...`` URI."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"


def _process_page(page: fitz.Page, page_num: int, target_width: Optional[int]) -> dict:
    """Detect regions on one page, pair them with captions and annotate the page."""
    detected_regions, processed_img = detect_layout_regions(page, target_width=target_width)

    def of_type(label: str) -> list:
        return [r for r in detected_regions if r["type"] == label]

    figures = of_type("figure")
    figure_captions = of_type("figure_caption")
    tables = of_type("table")
    table_captions = of_type("table_caption")
    logger.debug(
        f"Found {len(figures)} figures, {len(figure_captions)} figure captions, "
        f"{len(tables)} tables, {len(table_captions)} table captions"
    )

    # Figure captions are searched below the figure, table captions above the
    # table. The bbox key names differ ("figure_bbox" vs "bbox") to match the
    # response schema the frontend and create_annotated_image read.
    image_entries = _pair_regions_with_captions(
        processed_img, figures, figure_captions, True, "figure_number", "figure_bbox"
    )
    table_entries = _pair_regions_with_captions(
        processed_img, tables, table_captions, False, "table_number", "bbox"
    )

    annotated_img = create_annotated_image(processed_img, image_entries, table_entries)
    page_result = {
        "page_number": page_num + 1,
        "figures": image_entries,
        "tables": table_entries,
        "image_dimensions": {"width": processed_img.width, "height": processed_img.height},
        "annotated_image": _png_data_uri(annotated_img),
    }
    logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables")
    return page_result


def process_document(file_path: str, target_width: Optional[int] = None):
    """Analyse every page of a document and return per-page layout results.

    Args:
        file_path: Path to a file that ``fitz.open`` can read (PDF or image).
        target_width: Optional width in pixels each rendered page is resized to
            before detection; ``None`` keeps the rendered size.

    Returns:
        A list with one dict per page, holding ``page_number`` (1-based),
        ``figures``, ``tables``, ``image_dimensions`` and ``annotated_image``
        (a PNG data URI with the detections drawn on it).

    Raises:
        Exception: any error from opening or processing a page is logged and
            re-raised; processing stops at the first failing page.
    """
    logger.info(f"Processing document: {file_path}")
    try:
        # The context manager closes the document even when a page fails.
        with fitz.open(file_path) as doc:
            page_count = len(doc)
            logger.info(f"Document opened successfully. Pages: {page_count}")
            results = []
            for page_num, page in enumerate(doc):
                logger.info(f"Processing page {page_num + 1}/{page_count}")
                try:
                    results.append(_process_page(page, page_num, target_width))
                except Exception as e:
                    logger.error(f"Error processing page {page_num + 1}: {e}", exc_info=True)
                    raise
        logger.info(f"Document processing completed. Total pages: {len(results)}")
        return results
    except Exception as e:
        logger.error(f"Error in process_document: {e}", exc_info=True)
        raise
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_CAPTION_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"


def _load_annotation_fonts():
    """Return ``(label_font, caption_font)``, falling back to PIL's default font.

    ``ImageFont.truetype`` raises ``OSError`` when the font file is missing and
    ``ImportError`` when Pillow lacks FreeType support. Only those errors are
    caught, so KeyboardInterrupt/SystemExit still propagate.
    """
    try:
        return (
            ImageFont.truetype(_LABEL_FONT_PATH, 20),
            ImageFont.truetype(_CAPTION_FONT_PATH, 14),
        )
    except (OSError, ImportError):
        default_font = ImageFont.load_default()
        return default_font, default_font


def _draw_region(draw, bbox, caption_bbox, label, color, label_font, caption_font):
    """Outline ``bbox`` in ``color`` with ``label`` inside it, and its caption in blue."""
    draw.rectangle(bbox, outline=color, width=3)
    draw.text((bbox[0] + 5, bbox[1] + 5), label, fill=color, font=label_font)
    if caption_bbox:
        draw.rectangle(caption_bbox, outline="blue", width=2)
        # Put the tag just above the caption box, but never above the image's
        # top edge, where it would be clipped away.
        draw.text(
            (caption_bbox[0], max(0, caption_bbox[1] - 20)),
            "Caption",
            fill="blue",
            font=caption_font,
        )


def create_annotated_image(img: Image.Image, figures: list, tables: list) -> Image.Image:
    """Return a copy of ``img`` with detected tables and figures outlined.

    Tables are drawn in green and figures in red, each labelled with its number
    and confidence. Matched captions are outlined in blue. ``img`` itself is not
    modified.

    Args:
        img: Page image that the bounding boxes refer to.
        figures: Entries with ``"figure_bbox"`` and optional ``"caption_bbox"``,
            ``"figure_number"`` and ``"confidence"``.
        tables: Entries with ``"bbox"`` and optional ``"caption_bbox"``,
            ``"table_number"`` and ``"confidence"``.
    """
    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    label_font, caption_font = _load_annotation_fonts()

    # Tables first, then figures, so figure outlines are drawn on top.
    for table in tables:
        table_num = table.get("table_number", "?")
        conf = table.get("confidence", 0)
        _draw_region(
            draw, table["bbox"], table.get("caption_bbox"),
            f"Table {table_num} ({conf:.2f})", "green", label_font, caption_font,
        )
    for figure in figures:
        fig_num = figure.get("figure_number", "?")
        conf = figure.get("confidence", 0)
        _draw_region(
            draw, figure["figure_bbox"], figure.get("caption_bbox"),
            f"Figure {fig_num} ({conf:.2f})", "red", label_font, caption_font,
        )
    return annotated
# --- API Endpoints ---
# Single-page frontend served at "/". It posts uploads to /analyze and URLs to
# /analyze-url, then renders the annotated page images and the JSON result.
_INDEX_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document Layout Analysis API</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        .card-grainy { filter: url(#grainy); }
    </style>
</head>
<body class="bg-[#09090B] min-h-screen">
    <svg class="absolute h-0 w-0">
        <filter id="grainy">
            <feTurbulence type="fractalNoise" baseFrequency="0.7" numOctaves="2" result="noise" />
            <feComponentTransfer>
                <feFuncA type="table" tableValues="0 0.15 0" />
            </feComponentTransfer>
        </filter>
    </svg>
    <div class="container mx-auto px-4 py-12">
        <!-- Header -->
        <div class="mb-12 text-center">
            <h3 class="text-sm font-semibold tracking-wider text-cyan-400/90 uppercase mb-4">AI-Powered Document Analysis</h3>
            <h1 class="text-5xl font-bold mb-4">
                <span class="bg-gradient-to-r from-gray-100 to-gray-300 bg-clip-text text-transparent">Document Layout</span>
                <span class="text-gray-600"> Detection API</span>
            </h1>
            <p class="text-gray-400 text-lg">Extract tables, figures, and captions from PDFs and images with precision</p>
        </div>
        <!-- Main Card -->
        <div class="relative isolate max-w-4xl mx-auto rounded-3xl border border-white/10 bg-gradient-to-br from-[#1A1D29] via-[#151821] to-[#0F1117] p-10">
            <div class="card-grainy absolute top-0 left-0 h-full w-full"></div>
            <div class="pointer-events-none absolute top-0 left-0 h-96 w-96 rounded-full bg-blue-500/5 blur-3xl"></div>
            <div class="relative">
                <!-- Upload Section -->
                <div class="mb-8">
                    <label class="block text-sm font-semibold text-gray-300 mb-4">Upload Document</label>
                    <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
                        <input type="file" id="fileInput" accept=".pdf,.png,.jpg,.jpeg"
                            class="block w-full text-sm text-gray-400 file:mr-4 file:py-3 file:px-6 file:rounded-lg file:border-0 file:text-sm file:font-semibold file:bg-cyan-500/10 file:text-cyan-400 hover:file:bg-cyan-500/20 cursor-pointer">
                    </div>
                </div>
                <!-- OR Divider -->
                <div class="flex items-center my-8">
                    <div class="flex-1 h-px bg-white/10"></div>
                    <span class="px-4 text-gray-500 text-sm font-semibold">OR</span>
                    <div class="flex-1 h-px bg-white/10"></div>
                </div>
                <!-- URL Section -->
                <div class="mb-8">
                    <label class="block text-sm font-semibold text-gray-300 mb-4">Document URL</label>
                    <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
                        <input type="url" id="urlInput" placeholder="https://example.com/document.pdf"
                            class="w-full bg-white/5 border border-white/10 rounded-lg px-4 py-3 text-gray-300 placeholder-gray-600 focus:outline-none focus:ring-2 focus:ring-cyan-500/50">
                    </div>
                </div>
                <!-- Resolution Section -->
                <div class="mb-8">
                    <label class="block text-sm font-semibold text-gray-300 mb-4">
                        Target Width (Optional)
                        <span class="text-gray-500 text-xs font-normal ml-2">Leave empty for original size</span>
                    </label>
                    <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
                        <input type="number" id="resolutionInput" placeholder="e.g., 1280" min="256" max="4096"
                            class="w-full bg-white/5 border border-white/10 rounded-lg px-4 py-3 text-gray-300 placeholder-gray-600 focus:outline-none focus:ring-2 focus:ring-cyan-500/50">
                    </div>
                </div>
                <!-- Analyze Button -->
                <button id="analyzeBtn" onclick="analyzeDocument()"
                    class="w-full py-4 rounded-lg bg-gradient-to-r from-cyan-500 to-blue-500 text-white font-semibold text-lg hover:from-cyan-600 hover:to-blue-600 transition-all shadow-lg hover:shadow-cyan-500/25">
                    Analyze Document
                </button>
                <!-- Loading -->
                <div id="loading" class="hidden mt-8 text-center">
                    <div class="inline-block animate-spin rounded-full h-12 w-12 border-4 border-cyan-500 border-t-transparent"></div>
                    <p class="text-gray-400 mt-4">Processing document...</p>
                </div>
                <!-- Results -->
                <div id="results" class="hidden mt-8">
                    <h3 class="text-xl font-bold text-gray-300 mb-4">Analysis Results</h3>
                    <!-- Annotated Images -->
                    <div id="annotatedImages" class="mb-6 space-y-6"></div>
                    <!-- JSON Results -->
                    <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
                        <div class="flex justify-between items-center mb-4">
                            <h4 class="text-lg font-semibold text-gray-300">JSON Output</h4>
                            <button onclick="toggleJSON()" class="px-4 py-2 rounded-lg bg-gray-500/10 text-gray-400 text-sm hover:bg-gray-500/20 transition-all">
                                <span id="toggleText">Show JSON</span>
                            </button>
                        </div>
                        <pre id="resultsContent" class="hidden text-sm text-gray-300 overflow-x-auto max-h-96"></pre>
                    </div>
                    <button onclick="downloadJSON()" class="mt-4 px-6 py-3 rounded-lg bg-emerald-500/10 text-emerald-400 font-semibold hover:bg-emerald-500/20 transition-all ring-1 ring-emerald-500/30">
                        Download JSON
                    </button>
                </div>
                <!-- Error -->
                <div id="error" class="hidden mt-8 rounded-2xl bg-rose-500/10 p-6 ring-1 ring-rose-500/30">
                    <p class="text-rose-400 font-semibold" id="errorMessage"></p>
                </div>
            </div>
        </div>
        <!-- API Documentation -->
        <div class="mt-16 max-w-4xl mx-auto">
            <h2 class="text-3xl font-bold text-gray-300 mb-8">API Documentation</h2>
            <div class="space-y-6">
                <!-- Endpoint 1 -->
                <div class="rounded-2xl border border-white/10 bg-gradient-to-br from-[#1A1D29] via-[#151821] to-[#0F1117] p-8">
                    <div class="flex items-center gap-3 mb-4">
                        <span class="inline-flex items-center rounded-lg bg-emerald-500/10 px-3 py-1.5 text-xs font-bold text-emerald-400 uppercase ring-1 ring-emerald-500/30">POST</span>
                        <code class="text-cyan-400 text-lg font-mono">/analyze</code>
                    </div>
                    <p class="text-gray-400 mb-4">Analyze a document by uploading a file</p>
                    <div class="bg-black/30 rounded-lg p-4 overflow-x-auto">
                        <pre class="text-sm text-gray-300"><code>curl -X POST "http://your-api-url/analyze" \\
  -F "file=@document.pdf" \\
  -F "resolution=1280"</code></pre>
                    </div>
                </div>
                <!-- Endpoint 2 -->
                <div class="rounded-2xl border border-white/10 bg-gradient-to-br from-[#1A1D29] via-[#151821] to-[#0F1117] p-8">
                    <div class="flex items-center gap-3 mb-4">
                        <span class="inline-flex items-center rounded-lg bg-emerald-500/10 px-3 py-1.5 text-xs font-bold text-emerald-400 uppercase ring-1 ring-emerald-500/30">POST</span>
                        <code class="text-cyan-400 text-lg font-mono">/analyze-url</code>
                    </div>
                    <p class="text-gray-400 mb-4">Analyze a document from a URL</p>
                    <div class="bg-black/30 rounded-lg p-4 overflow-x-auto">
                        <pre class="text-sm text-gray-300"><code>curl -X POST "http://your-api-url/analyze-url" \\
  -H "Content-Type: application/json" \\
  -d '{"url": "https://example.com/doc.pdf", "resolution": 1280}'</code></pre>
                    </div>
                </div>
            </div>
        </div>
    </div>
    <script>
        let analysisResults = null;

        async function analyzeDocument() {
            const fileInput = document.getElementById('fileInput');
            const urlInput = document.getElementById('urlInput');
            const resolutionInput = document.getElementById('resolutionInput');
            const loading = document.getElementById('loading');
            const resultsDiv = document.getElementById('results');
            const errorDiv = document.getElementById('error');
            const analyzeBtn = document.getElementById('analyzeBtn');

            resultsDiv.classList.add('hidden');
            errorDiv.classList.add('hidden');
            analyzeBtn.disabled = true;
            analyzeBtn.textContent = 'Analyzing...';

            const resolution = resolutionInput.value ? parseInt(resolutionInput.value) : null;

            try {
                loading.classList.remove('hidden');
                let response;
                if (fileInput.files.length > 0) {
                    const formData = new FormData();
                    formData.append('file', fileInput.files[0]);
                    if (resolution) formData.append('resolution', resolution);
                    response = await fetch('/analyze', {
                        method: 'POST',
                        body: formData
                    });
                } else if (urlInput.value) {
                    const body = { url: urlInput.value };
                    if (resolution) body.resolution = resolution;
                    response = await fetch('/analyze-url', {
                        method: 'POST',
                        headers: { 'Content-Type': 'application/json' },
                        body: JSON.stringify(body)
                    });
                } else {
                    throw new Error('Please provide a file or URL');
                }

                // Error pages from a proxy or server crash may not be JSON.
                let responseData;
                try {
                    responseData = await response.json();
                } catch (parseErr) {
                    throw new Error('Server returned a non-JSON response (status ' + response.status + ')');
                }

                if (!response.ok) {
                    // FastAPI validation errors (422) return "detail" as a list of objects.
                    const detail = responseData.detail;
                    const message = Array.isArray(detail)
                        ? detail.map(d => d.msg || JSON.stringify(d)).join('; ')
                        : detail;
                    throw new Error(message || 'Analysis failed with status ' + response.status);
                }

                analysisResults = responseData;

                // Display annotated images
                displayAnnotatedImages(responseData.results);

                // Prepare JSON without base64 images for display
                const jsonForDisplay = {
                    ...responseData,
                    results: responseData.results.map(r => {
                        const {annotated_image, ...rest} = r;
                        return rest;
                    })
                };
                document.getElementById('resultsContent').textContent = JSON.stringify(jsonForDisplay, null, 2);
                resultsDiv.classList.remove('hidden');
            } catch (err) {
                document.getElementById('errorMessage').textContent = err.message;
                errorDiv.classList.remove('hidden');
            } finally {
                loading.classList.add('hidden');
                analyzeBtn.disabled = false;
                analyzeBtn.textContent = 'Analyze Document';
            }
        }

        function displayAnnotatedImages(results) {
            const container = document.getElementById('annotatedImages');
            container.innerHTML = '';
            results.forEach((page, idx) => {
                if (page.annotated_image) {
                    const pageDiv = document.createElement('div');
                    pageDiv.className = 'rounded-2xl bg-black/30 p-6 ring-1 ring-white/10 backdrop-blur-sm';

                    const title = document.createElement('h4');
                    title.className = 'text-lg font-semibold text-gray-300 mb-4';
                    title.textContent = `Page ${page.page_number}`;

                    const stats = document.createElement('div');
                    stats.className = 'text-sm text-gray-400 mb-4 flex gap-6';
                    stats.innerHTML = `
                        <span class="flex items-center gap-2">
                            <span class="inline-block w-3 h-3 bg-red-500 rounded"></span>
                            ${page.figures.length} Figure${page.figures.length !== 1 ? 's' : ''}
                        </span>
                        <span class="flex items-center gap-2">
                            <span class="inline-block w-3 h-3 bg-green-500 rounded"></span>
                            ${page.tables.length} Table${page.tables.length !== 1 ? 's' : ''}
                        </span>
                        <span class="flex items-center gap-2">
                            <span class="inline-block w-3 h-3 bg-blue-500 rounded"></span>
                            Captions
                        </span>
                    `;

                    const img = document.createElement('img');
                    img.src = page.annotated_image;
                    img.className = 'w-full rounded-lg border border-white/10';
                    img.alt = `Annotated page ${page.page_number}`;

                    pageDiv.appendChild(title);
                    pageDiv.appendChild(stats);
                    pageDiv.appendChild(img);
                    container.appendChild(pageDiv);
                }
            });
        }

        function toggleJSON() {
            const jsonContent = document.getElementById('resultsContent');
            const toggleText = document.getElementById('toggleText');
            if (jsonContent.classList.contains('hidden')) {
                jsonContent.classList.remove('hidden');
                toggleText.textContent = 'Hide JSON';
            } else {
                jsonContent.classList.add('hidden');
                toggleText.textContent = 'Show JSON';
            }
        }

        function downloadJSON() {
            if (!analysisResults) return;
            // Remove base64 images from download to reduce file size
            const downloadData = {
                ...analysisResults,
                results: analysisResults.results.map(r => {
                    const {annotated_image, ...rest} = r;
                    return rest;
                })
            };
            const blob = new Blob([JSON.stringify(downloadData, null, 2)], { type: 'application/json' });
            const url = URL.createObjectURL(blob);
            const a = document.createElement('a');
            a.href = url;
            a.download = 'layout_analysis.json';
            a.click();
            URL.revokeObjectURL(url);
        }
    </script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the frontend UI (a static HTML page)."""
    return HTMLResponse(content=_INDEX_HTML)
@app.post("/analyze")
async def analyze_file(file: UploadFile = File(...), resolution: Optional[int] = Form(None)):
    """Analyse an uploaded PDF or image and return per-page layout results.

    Args:
        file: Uploaded document (multipart field ``file``).
        resolution: Optional target width in pixels for each rendered page
            (form field ``resolution``).

    Returns:
        JSON with ``status``, ``filename``, ``pages`` and the per-page
        ``results`` produced by ``process_document``.

    Raises:
        HTTPException: 400 if the upload is empty; 500 if analysis fails.
    """
    logger.info(f"Received file upload: {file.filename}, resolution: {resolution}")
    tmp_path = None
    try:
        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # Keep the upload's extension on the temp file; fitz.open() may rely on
        # it to recognise image formats. UploadFile.filename can be None.
        suffix = Path(file.filename or "").suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(content)
            tmp_path = tmp.name

        logger.info(f"Processing file: {tmp_path}")
        # NOTE(review): process_document is synchronous and runs on the event
        # loop, which serialises inference on the shared models but blocks other
        # requests while it runs. Moving it to a thread pool would need a lock
        # unless the models are known to be thread-safe.
        results = process_document(tmp_path, target_width=resolution)
        return JSONResponse(content={
            "status": "success",
            "filename": file.filename,
            "pages": len(results),
            "results": results,
        })
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error analyzing file {file.filename}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
            logger.debug(f"Cleaned up temporary file: {tmp_path}")
# Extensions a downloaded document may be saved under before fitz.open().
_DOWNLOAD_SUFFIXES = (".pdf", ".png", ".jpg", ".jpeg")


def _suffix_for_download(url: str, content_type: str) -> str:
    """Choose a temp-file extension for a downloaded document.

    The Content-Type header wins when it names PDF, PNG or JPEG. Otherwise a
    known extension on the URL path is used, so a PDF served as
    ``application/octet-stream`` keeps ``.pdf``. The fallback is ``.png``.
    """
    content_type = content_type.lower()
    if "pdf" in content_type:
        return ".pdf"
    if "png" in content_type:
        return ".png"
    if "jpeg" in content_type or "jpg" in content_type:
        return ".jpg"
    url_suffix = Path(urlparse(url).path).suffix.lower()
    if url_suffix in _DOWNLOAD_SUFFIXES:
        return url_suffix
    return ".png"


@app.post("/analyze-url")
async def analyze_url(request: URLRequest):
    """Download a document from ``request.url`` and return its layout analysis.

    Returns:
        JSON with ``status``, ``url``, ``pages`` and the per-page ``results``
        produced by ``process_document``.

    Raises:
        HTTPException: 400 if the download fails (network error or non-2xx
            status) or the body is empty; 500 if analysis fails.
    """
    logger.info(f"Received URL request: {request.url}, resolution: {request.resolution}")
    tmp_path = None
    try:
        logger.info("Downloading file from URL...")
        # NOTE(security): the URL is caller-supplied and fetched server-side, so
        # internal hosts are reachable (SSRF), and the whole body is buffered in
        # memory. Restrict hosts and response size before exposing this publicly.
        # requests is blocking, so the download runs in the thread pool to keep
        # the event loop responsive.
        response = await run_in_threadpool(requests.get, str(request.url), timeout=30)
        response.raise_for_status()
        logger.info(f"File downloaded. Size: {len(response.content)} bytes")
        if not response.content:
            raise HTTPException(status_code=400, detail="Downloaded file is empty")

        ext = _suffix_for_download(str(request.url), response.headers.get("content-type", ""))
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
            tmp.write(response.content)
            tmp_path = tmp.name

        logger.info(f"Processing file: {tmp_path}")
        results = process_document(tmp_path, target_width=request.resolution)
        return JSONResponse(content={
            "status": "success",
            "url": str(request.url),
            "pages": len(results),
            "results": results,
        })
    except HTTPException:
        raise
    except requests.RequestException as e:
        logger.error(f"Failed to download file from {request.url}: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Failed to download file: {str(e)}") from e
    except Exception as e:
        logger.error(f"Error analyzing URL {request.url}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
            logger.debug(f"Cleaned up temporary file: {tmp_path}")
# NOTE(review): the "/health" path is assumed (the route was not registered);
# confirm it matches the deployment's health-probe configuration.
@app.get("/health")
async def health_check():
    """Report liveness, the inference device, and whether each model is loaded."""
    models_loaded = {
        "ocr": ocr is not None,
        "layout_model": layout_model is not None,
    }
    return {
        "status": "healthy",
        "device": device,
        "models_loaded": models_loaded,
    }
if __name__ == "__main__":
    import uvicorn

    # Port 7860 by default; override with the PORT environment variable.
    port = int(os.environ.get("PORT", "7860"))
    logger.info("Starting Document Layout Analysis API server...")
    uvicorn.run(app, host="0.0.0.0", port=port)