import os import io import json import logging import sys import tempfile import re import base64 from pathlib import Path from typing import Optional import fitz # PyMuPDF import numpy as np import requests import torch import torchvision from PIL import Image, ImageDraw, ImageFont from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.responses import HTMLResponse, JSONResponse from paddleocr import PaddleOCR from pydantic import BaseModel, HttpUrl # --- Configure Logging --- logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger(__name__) logger.info("Starting application initialization...") # --- Configuration --- MODEL_PATH = "/content/layout-model.pt" # --- Global Variables --- ocr: Optional[PaddleOCR] = None layout_model = None device: str = 'cuda' if torch.cuda.is_available() else 'cpu' # Label mapping id_to_names = { 0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption', 5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula', 9: 'formula_caption' } # --- FastAPI Application --- app = FastAPI(title="Document Layout Analysis API", version="1.0.0") # --- FastAPI Startup Event --- @app.on_event("startup") async def startup_event(): global ocr, layout_model lang="en" try: logger.info("Initializing PaddleOCR...") ocr = PaddleOCR( use_angle_cls=True, lang=lang, use_gpu=False, show_log=False, det_model_dir=f'/app/models/det/{lang}/en_PP-OCRv3_det_infer', rec_model_dir=f'/app/models/rec/{lang}/en_PP-OCRv4_rec_infer', cls_model_dir=f'/app/models/cls/{lang}/ch_ppocr_mobile_v2.0_cls_infer' ) logger.info("✓ PaddleOCR initialized successfully") except Exception as e: logger.error(f"Failed to initialize PaddleOCR: {e}", exc_info=True) raise RuntimeError("Could not initialize PaddleOCR") from e try: logger.info(f"Loading DocLayout-YOLO model from {MODEL_PATH}...") if not os.path.exists(MODEL_PATH): logger.error(f"Model file not found at {MODEL_PATH}") raise FileNotFoundError(f"Model file not found at {MODEL_PATH}") # Import YOLOv10 from doclayout_yolo from doclayout_yolo import YOLOv10 layout_model = YOLOv10(MODEL_PATH) logger.info(f"✓ DocLayout-YOLO model loaded successfully on device: {device}") except Exception as e: logger.error(f"Failed to load DocLayout-YOLO model: {e}", exc_info=True) raise RuntimeError("Could not load layout model") from e # --- Pydantic Request Models --- class URLRequest(BaseModel): url: HttpUrl resolution: Optional[int] = None # --- Helper Functions --- def extract_number_from_caption(caption_text: str) -> Optional[str]: """Extract the number from a caption like 'Table 3' or 'Figure 2.1'""" if not caption_text: return None NUMBER_PATTERN = re.compile(r"(?:Table|Figure)\s*([\d\.]+)", re.IGNORECASE) match = NUMBER_PATTERN.search(caption_text) return match.group(1) if match else None def detect_layout_regions(page: fitz.Page, target_width: Optional[int] = None, conf_threshold=0.25, iou_threshold=0.3): """Use DocLayout-YOLO to detect document elements.""" if layout_model is None: raise RuntimeError("Layout model is not initialized.") logger.debug(f"Detecting layout regions with target_width={target_width}") try: pix = page.get_pixmap(dpi=150) img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) if target_width: aspect_ratio = img.height / img.width target_height = int(target_width * aspect_ratio) img = img.resize((target_width, target_height), Image.LANCZOS) logger.debug(f"Resized image to {target_width}x{target_height}") logger.debug(f"Running model prediction on image size: {img.width}x{img.height}") # Run prediction using YOLOv10 results = layout_model.predict( img, imgsz=1280, conf=conf_threshold, device=device ) # Get first result det_res = results[0] # Access boxes using the correct API boxes = det_res.boxes.xyxy.cpu().numpy() classes = det_res.boxes.cls.cpu().numpy() scores = det_res.boxes.conf.cpu().numpy() logger.debug(f"Detected {len(boxes)} boxes before NMS") if len(boxes) == 0: logger.info("No objects detected") return [], img # Apply NMS boxes_tensor = torch.from_numpy(boxes) scores_tensor = torch.from_numpy(scores) indices = torchvision.ops.nms(boxes_tensor, scores_tensor, iou_threshold) boxes = boxes[indices.numpy()] scores = scores[indices.numpy()] classes = classes[indices.numpy()] logger.debug(f"Detected {len(boxes)} boxes after NMS") detected_regions = [] for box, score, cls in zip(boxes, scores, classes): box = [float(coord) for coord in box] label_name = id_to_names.get(int(cls), 'unknown') detected_regions.append({ "bbox": box, "type": label_name, "confidence": float(score) }) logger.debug(f"Returning {len(detected_regions)} detected regions") return detected_regions, img except Exception as e: logger.error(f"Error in detect_layout_regions: {e}", exc_info=True) raise def extract_text_from_bbox(img: Image.Image, bbox: list, padding: int = 5) -> str: """Run OCR on a specific bounding box region of a PIL Image.""" if ocr is None: raise RuntimeError("OCR model is not initialized.") logger.debug(f"Extracting text from bbox: {bbox}") try: x0, y0, x1, y1 = [int(coord) for coord in bbox] x0 = max(0, x0 - padding) y0 = max(0, y0 - padding) x1 = min(img.width, x1 + padding) y1 = min(img.height, y1 + padding) if x0 >= x1 or y0 >= y1: logger.debug("Invalid bbox dimensions") return "" region = img.crop((x0, y0, x1, y1)) region_np = np.array(region) ocr_result = ocr.ocr(region_np, cls=True) if not ocr_result or not ocr_result[0]: logger.debug("No OCR results") return "" text_parts = [line[1][0] for line in ocr_result[0]] result_text = " ".join(text_parts) logger.debug(f"Extracted text: {result_text[:100]}...") return result_text except Exception as e: logger.error(f"Error in extract_text_from_bbox: {e}", exc_info=True) return "" def process_document(file_path: str, target_width: Optional[int] = None): """Process a document and extract layout information.""" logger.info(f"Processing document: {file_path}") try: doc = fitz.open(file_path) logger.info(f"Document opened successfully. Pages: {len(doc)}") results = [] for page_num, page in enumerate(doc): logger.info(f"Processing page {page_num + 1}/{len(doc)}") try: detected_regions, processed_img = detect_layout_regions(page, target_width=target_width) # Group regions by type figures = [r for r in detected_regions if r["type"] == 'figure'] figure_captions = [r for r in detected_regions if r["type"] == 'figure_caption'] tables = [r for r in detected_regions if r["type"] == 'table'] table_captions = [r for r in detected_regions if r["type"] == 'table_caption'] logger.debug(f"Found {len(figures)} figures, {len(figure_captions)} figure captions, {len(tables)} tables, {len(table_captions)} table captions") image_entries = [] table_entries = [] # Match figures with their captions (caption usually BELOW figure) for idx, figure in enumerate(figures, start=1): figure_bbox = figure["bbox"] best_caption = None min_distance = float('inf') for caption in figure_captions: cap_bbox = caption["bbox"] distance = cap_bbox[1] - figure_bbox[3] if 0 <= distance < min_distance: min_distance = distance best_caption = caption caption_text = extract_text_from_bbox(processed_img, best_caption["bbox"]) if best_caption else None figure_number = extract_number_from_caption(caption_text) or str(idx) image_entries.append({ "figure_number": figure_number, "figure_bbox": figure_bbox, "caption": caption_text, "caption_bbox": best_caption["bbox"] if best_caption else None, "confidence": figure["confidence"] }) # Match tables with their captions (caption usually ABOVE table) for idx, table in enumerate(tables, start=1): table_bbox = table["bbox"] best_caption = None min_distance = float('inf') for caption in table_captions: cap_bbox = caption["bbox"] distance = table_bbox[1] - cap_bbox[3] if 0 <= distance < min_distance: min_distance = distance best_caption = caption caption_text = extract_text_from_bbox(processed_img, best_caption["bbox"]) if best_caption else None table_number = extract_number_from_caption(caption_text) or str(idx) table_entries.append({ "table_number": table_number, "bbox": table_bbox, "caption": caption_text, "caption_bbox": best_caption["bbox"] if best_caption else None, "confidence": table["confidence"] }) # Create annotated image annotated_img = create_annotated_image( processed_img, image_entries, table_entries ) # Convert annotated image to base64 buffered = io.BytesIO() annotated_img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() results.append({ "page_number": page.number + 1, "figures": image_entries, "tables": table_entries, "image_dimensions": {"width": processed_img.width, "height": processed_img.height}, "annotated_image": f"data:image/png;base64,{img_str}" }) logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables") except Exception as e: logger.error(f"Error processing page {page_num + 1}: {e}", exc_info=True) raise doc.close() logger.info(f"Document processing completed. Total pages: {len(results)}") return results except Exception as e: logger.error(f"Error in process_document: {e}", exc_info=True) raise def create_annotated_image(img: Image.Image, figures: list, tables: list) -> Image.Image: """Create an annotated image with bounding boxes.""" # Create a copy to draw on annotated = img.copy() draw = ImageDraw.Draw(annotated) # Try to load a font try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20) small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14) except: font = ImageFont.load_default() small_font = ImageFont.load_default() # Draw tables (green boxes) for table in tables: bbox = table["bbox"] caption_bbox = table.get("caption_bbox") table_num = table.get("table_number", "?") conf = table.get("confidence", 0) # Draw table content box draw.rectangle(bbox, outline="green", width=3) draw.text( (bbox[0] + 5, bbox[1] + 5), f"Table {table_num} ({conf:.2f})", fill="green", font=font ) # Draw caption box if caption_bbox: draw.rectangle(caption_bbox, outline="blue", width=2) draw.text( (caption_bbox[0], caption_bbox[1] - 20), "Caption", fill="blue", font=small_font ) # Draw figures (red boxes) for figure in figures: bbox = figure["figure_bbox"] caption_bbox = figure.get("caption_bbox") fig_num = figure.get("figure_number", "?") conf = figure.get("confidence", 0) # Draw figure content box draw.rectangle(bbox, outline="red", width=3) draw.text( (bbox[0] + 5, bbox[1] + 5), f"Figure {fig_num} ({conf:.2f})", fill="red", font=font ) # Draw caption box if caption_bbox: draw.rectangle(caption_bbox, outline="blue", width=2) draw.text( (caption_bbox[0], caption_bbox[1] - 20), "Caption", fill="blue", font=small_font ) return annotated # --- API Endpoints --- @app.get("/", response_class=HTMLResponse) async def read_root(): """Serve the frontend UI""" html_content = """
Extract tables, figures, and captions from PDFs and images with precision
Processing document...
/analyze
Analyze a document by uploading a file
curl -X POST "http://your-api-url/analyze" \\
-F "file=@document.pdf" \\
-F "resolution=1280"
/analyze-url
Analyze a document from a URL
curl -X POST "http://your-api-url/analyze-url" \\
-H "Content-Type: application/json" \\
-d '{"url": "https://example.com/doc.pdf", "resolution": 1280}'