Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import logging
|
|
| 5 |
import sys
|
| 6 |
import tempfile
|
| 7 |
import re
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Optional
|
| 10 |
|
|
@@ -13,7 +14,7 @@ import numpy as np
|
|
| 13 |
import requests
|
| 14 |
import torch
|
| 15 |
import torchvision
|
| 16 |
-
from PIL import Image
|
| 17 |
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 18 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 19 |
from paddleocr import PaddleOCR
|
|
@@ -82,6 +83,7 @@ async def startup_event():
|
|
| 82 |
logger.error(f"Failed to load DocLayout-YOLO model: {e}", exc_info=True)
|
| 83 |
raise RuntimeError("Could not load layout model") from e
|
| 84 |
|
|
|
|
| 85 |
# --- Pydantic Request Models ---
|
| 86 |
class URLRequest(BaseModel):
|
| 87 |
url: HttpUrl
|
|
@@ -274,11 +276,24 @@ def process_document(file_path: str, target_width: Optional[int] = None):
|
|
| 274 |
"confidence": table["confidence"]
|
| 275 |
})
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
results.append({
|
| 278 |
"page_number": page.number + 1,
|
| 279 |
"figures": image_entries,
|
| 280 |
"tables": table_entries,
|
| 281 |
-
"image_dimensions": {"width": processed_img.width, "height": processed_img.height}
|
|
|
|
| 282 |
})
|
| 283 |
|
| 284 |
logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables")
|
|
@@ -295,6 +310,74 @@ def process_document(file_path: str, target_width: Optional[int] = None):
|
|
| 295 |
logger.error(f"Error in process_document: {e}", exc_info=True)
|
| 296 |
raise
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
# --- API Endpoints ---
|
| 299 |
@app.get("/", response_class=HTMLResponse)
|
| 300 |
async def read_root():
|
|
@@ -390,9 +473,21 @@ async def read_root():
|
|
| 390 |
<!-- Results -->
|
| 391 |
<div id="results" class="hidden mt-8">
|
| 392 |
<h3 class="text-xl font-bold text-gray-300 mb-4">Analysis Results</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
<div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
|
| 394 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
</div>
|
|
|
|
| 396 |
<button onclick="downloadJSON()" class="mt-4 px-6 py-3 rounded-lg bg-emerald-500/10 text-emerald-400 font-semibold hover:bg-emerald-500/20 transition-all ring-1 ring-emerald-500/30">
|
| 397 |
Download JSON
|
| 398 |
</button>
|
|
@@ -493,7 +588,20 @@ async def read_root():
|
|
| 493 |
}
|
| 494 |
|
| 495 |
analysisResults = responseData;
|
| 496 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
resultsDiv.classList.remove('hidden');
|
| 498 |
|
| 499 |
} catch (err) {
|
|
@@ -506,10 +614,75 @@ async def read_root():
|
|
| 506 |
}
|
| 507 |
}
|
| 508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
function downloadJSON() {
|
| 510 |
if (!analysisResults) return;
|
| 511 |
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
const url = URL.createObjectURL(blob);
|
| 514 |
const a = document.createElement('a');
|
| 515 |
a.href = url;
|
|
|
|
| 5 |
import sys
|
| 6 |
import tempfile
|
| 7 |
import re
|
| 8 |
+
import base64
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional
|
| 11 |
|
|
|
|
| 14 |
import requests
|
| 15 |
import torch
|
| 16 |
import torchvision
|
| 17 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 18 |
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 19 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 20 |
from paddleocr import PaddleOCR
|
|
|
|
| 83 |
logger.error(f"Failed to load DocLayout-YOLO model: {e}", exc_info=True)
|
| 84 |
raise RuntimeError("Could not load layout model") from e
|
| 85 |
|
| 86 |
+
|
| 87 |
# --- Pydantic Request Models ---
|
| 88 |
class URLRequest(BaseModel):
|
| 89 |
url: HttpUrl
|
|
|
|
| 276 |
"confidence": table["confidence"]
|
| 277 |
})
|
| 278 |
|
| 279 |
+
# Create annotated image
|
| 280 |
+
annotated_img = create_annotated_image(
|
| 281 |
+
processed_img,
|
| 282 |
+
image_entries,
|
| 283 |
+
table_entries
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Convert annotated image to base64
|
| 287 |
+
buffered = io.BytesIO()
|
| 288 |
+
annotated_img.save(buffered, format="PNG")
|
| 289 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 290 |
+
|
| 291 |
results.append({
|
| 292 |
"page_number": page.number + 1,
|
| 293 |
"figures": image_entries,
|
| 294 |
"tables": table_entries,
|
| 295 |
+
"image_dimensions": {"width": processed_img.width, "height": processed_img.height},
|
| 296 |
+
"annotated_image": f"data:image/png;base64,{img_str}"
|
| 297 |
})
|
| 298 |
|
| 299 |
logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables")
|
|
|
|
| 310 |
logger.error(f"Error in process_document: {e}", exc_info=True)
|
| 311 |
raise
|
| 312 |
|
| 313 |
+
def create_annotated_image(img: Image.Image, figures: list, tables: list) -> Image.Image:
|
| 314 |
+
"""Create an annotated image with bounding boxes."""
|
| 315 |
+
# Create a copy to draw on
|
| 316 |
+
annotated = img.copy()
|
| 317 |
+
draw = ImageDraw.Draw(annotated)
|
| 318 |
+
|
| 319 |
+
# Try to load a font
|
| 320 |
+
try:
|
| 321 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
|
| 322 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
|
| 323 |
+
except:
|
| 324 |
+
font = ImageFont.load_default()
|
| 325 |
+
small_font = ImageFont.load_default()
|
| 326 |
+
|
| 327 |
+
# Draw tables (green boxes)
|
| 328 |
+
for table in tables:
|
| 329 |
+
bbox = table["bbox"]
|
| 330 |
+
caption_bbox = table.get("caption_bbox")
|
| 331 |
+
table_num = table.get("table_number", "?")
|
| 332 |
+
conf = table.get("confidence", 0)
|
| 333 |
+
|
| 334 |
+
# Draw table content box
|
| 335 |
+
draw.rectangle(bbox, outline="green", width=3)
|
| 336 |
+
draw.text(
|
| 337 |
+
(bbox[0] + 5, bbox[1] + 5),
|
| 338 |
+
f"Table {table_num} ({conf:.2f})",
|
| 339 |
+
fill="green",
|
| 340 |
+
font=font
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Draw caption box
|
| 344 |
+
if caption_bbox:
|
| 345 |
+
draw.rectangle(caption_bbox, outline="blue", width=2)
|
| 346 |
+
draw.text(
|
| 347 |
+
(caption_bbox[0], caption_bbox[1] - 20),
|
| 348 |
+
"Caption",
|
| 349 |
+
fill="blue",
|
| 350 |
+
font=small_font
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Draw figures (red boxes)
|
| 354 |
+
for figure in figures:
|
| 355 |
+
bbox = figure["figure_bbox"]
|
| 356 |
+
caption_bbox = figure.get("caption_bbox")
|
| 357 |
+
fig_num = figure.get("figure_number", "?")
|
| 358 |
+
conf = figure.get("confidence", 0)
|
| 359 |
+
|
| 360 |
+
# Draw figure content box
|
| 361 |
+
draw.rectangle(bbox, outline="red", width=3)
|
| 362 |
+
draw.text(
|
| 363 |
+
(bbox[0] + 5, bbox[1] + 5),
|
| 364 |
+
f"Figure {fig_num} ({conf:.2f})",
|
| 365 |
+
fill="red",
|
| 366 |
+
font=font
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Draw caption box
|
| 370 |
+
if caption_bbox:
|
| 371 |
+
draw.rectangle(caption_bbox, outline="blue", width=2)
|
| 372 |
+
draw.text(
|
| 373 |
+
(caption_bbox[0], caption_bbox[1] - 20),
|
| 374 |
+
"Caption",
|
| 375 |
+
fill="blue",
|
| 376 |
+
font=small_font
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
return annotated
|
| 380 |
+
|
| 381 |
# --- API Endpoints ---
|
| 382 |
@app.get("/", response_class=HTMLResponse)
|
| 383 |
async def read_root():
|
|
|
|
| 473 |
<!-- Results -->
|
| 474 |
<div id="results" class="hidden mt-8">
|
| 475 |
<h3 class="text-xl font-bold text-gray-300 mb-4">Analysis Results</h3>
|
| 476 |
+
|
| 477 |
+
<!-- Annotated Images -->
|
| 478 |
+
<div id="annotatedImages" class="mb-6 space-y-6"></div>
|
| 479 |
+
|
| 480 |
+
<!-- JSON Results -->
|
| 481 |
<div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
|
| 482 |
+
<div class="flex justify-between items-center mb-4">
|
| 483 |
+
<h4 class="text-lg font-semibold text-gray-300">JSON Output</h4>
|
| 484 |
+
<button onclick="toggleJSON()" class="px-4 py-2 rounded-lg bg-gray-500/10 text-gray-400 text-sm hover:bg-gray-500/20 transition-all">
|
| 485 |
+
<span id="toggleText">Show JSON</span>
|
| 486 |
+
</button>
|
| 487 |
+
</div>
|
| 488 |
+
<pre id="resultsContent" class="hidden text-sm text-gray-300 overflow-x-auto max-h-96"></pre>
|
| 489 |
</div>
|
| 490 |
+
|
| 491 |
<button onclick="downloadJSON()" class="mt-4 px-6 py-3 rounded-lg bg-emerald-500/10 text-emerald-400 font-semibold hover:bg-emerald-500/20 transition-all ring-1 ring-emerald-500/30">
|
| 492 |
Download JSON
|
| 493 |
</button>
|
|
|
|
| 588 |
}
|
| 589 |
|
| 590 |
analysisResults = responseData;
|
| 591 |
+
|
| 592 |
+
// Display annotated images
|
| 593 |
+
displayAnnotatedImages(responseData.results);
|
| 594 |
+
|
| 595 |
+
// Prepare JSON without base64 images for display
|
| 596 |
+
const jsonForDisplay = {
|
| 597 |
+
...responseData,
|
| 598 |
+
results: responseData.results.map(r => {
|
| 599 |
+
const {annotated_image, ...rest} = r;
|
| 600 |
+
return rest;
|
| 601 |
+
})
|
| 602 |
+
};
|
| 603 |
+
|
| 604 |
+
document.getElementById('resultsContent').textContent = JSON.stringify(jsonForDisplay, null, 2);
|
| 605 |
resultsDiv.classList.remove('hidden');
|
| 606 |
|
| 607 |
} catch (err) {
|
|
|
|
| 614 |
}
|
| 615 |
}
|
| 616 |
|
| 617 |
+
function displayAnnotatedImages(results) {
|
| 618 |
+
const container = document.getElementById('annotatedImages');
|
| 619 |
+
container.innerHTML = '';
|
| 620 |
+
|
| 621 |
+
results.forEach((page, idx) => {
|
| 622 |
+
if (page.annotated_image) {
|
| 623 |
+
const pageDiv = document.createElement('div');
|
| 624 |
+
pageDiv.className = 'rounded-2xl bg-black/30 p-6 ring-1 ring-white/10 backdrop-blur-sm';
|
| 625 |
+
|
| 626 |
+
const title = document.createElement('h4');
|
| 627 |
+
title.className = 'text-lg font-semibold text-gray-300 mb-4';
|
| 628 |
+
title.textContent = `Page ${page.page_number}`;
|
| 629 |
+
|
| 630 |
+
const stats = document.createElement('div');
|
| 631 |
+
stats.className = 'text-sm text-gray-400 mb-4 flex gap-6';
|
| 632 |
+
stats.innerHTML = `
|
| 633 |
+
<span class="flex items-center gap-2">
|
| 634 |
+
<span class="inline-block w-3 h-3 bg-red-500 rounded"></span>
|
| 635 |
+
${page.figures.length} Figure${page.figures.length !== 1 ? 's' : ''}
|
| 636 |
+
</span>
|
| 637 |
+
<span class="flex items-center gap-2">
|
| 638 |
+
<span class="inline-block w-3 h-3 bg-green-500 rounded"></span>
|
| 639 |
+
${page.tables.length} Table${page.tables.length !== 1 ? 's' : ''}
|
| 640 |
+
</span>
|
| 641 |
+
<span class="flex items-center gap-2">
|
| 642 |
+
<span class="inline-block w-3 h-3 bg-blue-500 rounded"></span>
|
| 643 |
+
Captions
|
| 644 |
+
</span>
|
| 645 |
+
`;
|
| 646 |
+
|
| 647 |
+
const img = document.createElement('img');
|
| 648 |
+
img.src = page.annotated_image;
|
| 649 |
+
img.className = 'w-full rounded-lg border border-white/10';
|
| 650 |
+
img.alt = `Annotated page ${page.page_number}`;
|
| 651 |
+
|
| 652 |
+
pageDiv.appendChild(title);
|
| 653 |
+
pageDiv.appendChild(stats);
|
| 654 |
+
pageDiv.appendChild(img);
|
| 655 |
+
container.appendChild(pageDiv);
|
| 656 |
+
}
|
| 657 |
+
});
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
function toggleJSON() {
|
| 661 |
+
const jsonContent = document.getElementById('resultsContent');
|
| 662 |
+
const toggleText = document.getElementById('toggleText');
|
| 663 |
+
|
| 664 |
+
if (jsonContent.classList.contains('hidden')) {
|
| 665 |
+
jsonContent.classList.remove('hidden');
|
| 666 |
+
toggleText.textContent = 'Hide JSON';
|
| 667 |
+
} else {
|
| 668 |
+
jsonContent.classList.add('hidden');
|
| 669 |
+
toggleText.textContent = 'Show JSON';
|
| 670 |
+
}
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
function downloadJSON() {
|
| 674 |
if (!analysisResults) return;
|
| 675 |
|
| 676 |
+
// Remove base64 images from download to reduce file size
|
| 677 |
+
const downloadData = {
|
| 678 |
+
...analysisResults,
|
| 679 |
+
results: analysisResults.results.map(r => {
|
| 680 |
+
const {annotated_image, ...rest} = r;
|
| 681 |
+
return rest;
|
| 682 |
+
})
|
| 683 |
+
};
|
| 684 |
+
|
| 685 |
+
const blob = new Blob([JSON.stringify(downloadData, null, 2)], { type: 'application/json' });
|
| 686 |
const url = URL.createObjectURL(blob);
|
| 687 |
const a = document.createElement('a');
|
| 688 |
a.href = url;
|