bla commited on
Commit
ee69d03
·
verified ·
1 Parent(s): 81803f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import logging
5
  import sys
6
  import tempfile
7
  import re
 
8
  from pathlib import Path
9
  from typing import Optional
10
 
@@ -13,7 +14,7 @@ import numpy as np
13
  import requests
14
  import torch
15
  import torchvision
16
- from PIL import Image
17
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
18
  from fastapi.responses import HTMLResponse, JSONResponse
19
  from paddleocr import PaddleOCR
@@ -82,6 +83,7 @@ async def startup_event():
82
  logger.error(f"Failed to load DocLayout-YOLO model: {e}", exc_info=True)
83
  raise RuntimeError("Could not load layout model") from e
84
 
 
85
  # --- Pydantic Request Models ---
86
  class URLRequest(BaseModel):
87
  url: HttpUrl
@@ -274,11 +276,24 @@ def process_document(file_path: str, target_width: Optional[int] = None):
274
  "confidence": table["confidence"]
275
  })
276
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  results.append({
278
  "page_number": page.number + 1,
279
  "figures": image_entries,
280
  "tables": table_entries,
281
- "image_dimensions": {"width": processed_img.width, "height": processed_img.height}
 
282
  })
283
 
284
  logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables")
@@ -295,6 +310,74 @@ def process_document(file_path: str, target_width: Optional[int] = None):
295
  logger.error(f"Error in process_document: {e}", exc_info=True)
296
  raise
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  # --- API Endpoints ---
299
  @app.get("/", response_class=HTMLResponse)
300
  async def read_root():
@@ -390,9 +473,21 @@ async def read_root():
390
  <!-- Results -->
391
  <div id="results" class="hidden mt-8">
392
  <h3 class="text-xl font-bold text-gray-300 mb-4">Analysis Results</h3>
 
 
 
 
 
393
  <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
394
- <pre id="resultsContent" class="text-sm text-gray-300 overflow-x-auto"></pre>
 
 
 
 
 
 
395
  </div>
 
396
  <button onclick="downloadJSON()" class="mt-4 px-6 py-3 rounded-lg bg-emerald-500/10 text-emerald-400 font-semibold hover:bg-emerald-500/20 transition-all ring-1 ring-emerald-500/30">
397
  Download JSON
398
  </button>
@@ -493,7 +588,20 @@ async def read_root():
493
  }
494
 
495
  analysisResults = responseData;
496
- document.getElementById('resultsContent').textContent = JSON.stringify(analysisResults, null, 2);
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  resultsDiv.classList.remove('hidden');
498
 
499
  } catch (err) {
@@ -506,10 +614,75 @@ async def read_root():
506
  }
507
  }
508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  function downloadJSON() {
510
  if (!analysisResults) return;
511
 
512
- const blob = new Blob([JSON.stringify(analysisResults, null, 2)], { type: 'application/json' });
 
 
 
 
 
 
 
 
 
513
  const url = URL.createObjectURL(blob);
514
  const a = document.createElement('a');
515
  a.href = url;
 
5
  import sys
6
  import tempfile
7
  import re
8
+ import base64
9
  from pathlib import Path
10
  from typing import Optional
11
 
 
14
  import requests
15
  import torch
16
  import torchvision
17
+ from PIL import Image, ImageDraw, ImageFont
18
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
19
  from fastapi.responses import HTMLResponse, JSONResponse
20
  from paddleocr import PaddleOCR
 
83
  logger.error(f"Failed to load DocLayout-YOLO model: {e}", exc_info=True)
84
  raise RuntimeError("Could not load layout model") from e
85
 
86
+
87
  # --- Pydantic Request Models ---
88
  class URLRequest(BaseModel):
89
  url: HttpUrl
 
276
  "confidence": table["confidence"]
277
  })
278
 
279
+ # Create annotated image
280
+ annotated_img = create_annotated_image(
281
+ processed_img,
282
+ image_entries,
283
+ table_entries
284
+ )
285
+
286
+ # Convert annotated image to base64
287
+ buffered = io.BytesIO()
288
+ annotated_img.save(buffered, format="PNG")
289
+ img_str = base64.b64encode(buffered.getvalue()).decode()
290
+
291
  results.append({
292
  "page_number": page.number + 1,
293
  "figures": image_entries,
294
  "tables": table_entries,
295
+ "image_dimensions": {"width": processed_img.width, "height": processed_img.height},
296
+ "annotated_image": f"data:image/png;base64,{img_str}"
297
  })
298
 
299
  logger.info(f"Page {page_num + 1} processed: {len(image_entries)} figures, {len(table_entries)} tables")
 
310
  logger.error(f"Error in process_document: {e}", exc_info=True)
311
  raise
312
 
313
+ def create_annotated_image(img: Image.Image, figures: list, tables: list) -> Image.Image:
314
+ """Create an annotated image with bounding boxes."""
315
+ # Create a copy to draw on
316
+ annotated = img.copy()
317
+ draw = ImageDraw.Draw(annotated)
318
+
319
+ # Try to load a font
320
+ try:
321
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
322
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
323
+ except:
324
+ font = ImageFont.load_default()
325
+ small_font = ImageFont.load_default()
326
+
327
+ # Draw tables (green boxes)
328
+ for table in tables:
329
+ bbox = table["bbox"]
330
+ caption_bbox = table.get("caption_bbox")
331
+ table_num = table.get("table_number", "?")
332
+ conf = table.get("confidence", 0)
333
+
334
+ # Draw table content box
335
+ draw.rectangle(bbox, outline="green", width=3)
336
+ draw.text(
337
+ (bbox[0] + 5, bbox[1] + 5),
338
+ f"Table {table_num} ({conf:.2f})",
339
+ fill="green",
340
+ font=font
341
+ )
342
+
343
+ # Draw caption box
344
+ if caption_bbox:
345
+ draw.rectangle(caption_bbox, outline="blue", width=2)
346
+ draw.text(
347
+ (caption_bbox[0], caption_bbox[1] - 20),
348
+ "Caption",
349
+ fill="blue",
350
+ font=small_font
351
+ )
352
+
353
+ # Draw figures (red boxes)
354
+ for figure in figures:
355
+ bbox = figure["figure_bbox"]
356
+ caption_bbox = figure.get("caption_bbox")
357
+ fig_num = figure.get("figure_number", "?")
358
+ conf = figure.get("confidence", 0)
359
+
360
+ # Draw figure content box
361
+ draw.rectangle(bbox, outline="red", width=3)
362
+ draw.text(
363
+ (bbox[0] + 5, bbox[1] + 5),
364
+ f"Figure {fig_num} ({conf:.2f})",
365
+ fill="red",
366
+ font=font
367
+ )
368
+
369
+ # Draw caption box
370
+ if caption_bbox:
371
+ draw.rectangle(caption_bbox, outline="blue", width=2)
372
+ draw.text(
373
+ (caption_bbox[0], caption_bbox[1] - 20),
374
+ "Caption",
375
+ fill="blue",
376
+ font=small_font
377
+ )
378
+
379
+ return annotated
380
+
381
  # --- API Endpoints ---
382
  @app.get("/", response_class=HTMLResponse)
383
  async def read_root():
 
473
  <!-- Results -->
474
  <div id="results" class="hidden mt-8">
475
  <h3 class="text-xl font-bold text-gray-300 mb-4">Analysis Results</h3>
476
+
477
+ <!-- Annotated Images -->
478
+ <div id="annotatedImages" class="mb-6 space-y-6"></div>
479
+
480
+ <!-- JSON Results -->
481
  <div class="rounded-2xl bg-black/30 p-8 ring-1 ring-white/10 backdrop-blur-sm">
482
+ <div class="flex justify-between items-center mb-4">
483
+ <h4 class="text-lg font-semibold text-gray-300">JSON Output</h4>
484
+ <button onclick="toggleJSON()" class="px-4 py-2 rounded-lg bg-gray-500/10 text-gray-400 text-sm hover:bg-gray-500/20 transition-all">
485
+ <span id="toggleText">Show JSON</span>
486
+ </button>
487
+ </div>
488
+ <pre id="resultsContent" class="hidden text-sm text-gray-300 overflow-x-auto max-h-96"></pre>
489
  </div>
490
+
491
  <button onclick="downloadJSON()" class="mt-4 px-6 py-3 rounded-lg bg-emerald-500/10 text-emerald-400 font-semibold hover:bg-emerald-500/20 transition-all ring-1 ring-emerald-500/30">
492
  Download JSON
493
  </button>
 
588
  }
589
 
590
  analysisResults = responseData;
591
+
592
+ // Display annotated images
593
+ displayAnnotatedImages(responseData.results);
594
+
595
+ // Prepare JSON without base64 images for display
596
+ const jsonForDisplay = {
597
+ ...responseData,
598
+ results: responseData.results.map(r => {
599
+ const {annotated_image, ...rest} = r;
600
+ return rest;
601
+ })
602
+ };
603
+
604
+ document.getElementById('resultsContent').textContent = JSON.stringify(jsonForDisplay, null, 2);
605
  resultsDiv.classList.remove('hidden');
606
 
607
  } catch (err) {
 
614
  }
615
  }
616
 
617
+ function displayAnnotatedImages(results) {
618
+ const container = document.getElementById('annotatedImages');
619
+ container.innerHTML = '';
620
+
621
+ results.forEach((page, idx) => {
622
+ if (page.annotated_image) {
623
+ const pageDiv = document.createElement('div');
624
+ pageDiv.className = 'rounded-2xl bg-black/30 p-6 ring-1 ring-white/10 backdrop-blur-sm';
625
+
626
+ const title = document.createElement('h4');
627
+ title.className = 'text-lg font-semibold text-gray-300 mb-4';
628
+ title.textContent = `Page ${page.page_number}`;
629
+
630
+ const stats = document.createElement('div');
631
+ stats.className = 'text-sm text-gray-400 mb-4 flex gap-6';
632
+ stats.innerHTML = `
633
+ <span class="flex items-center gap-2">
634
+ <span class="inline-block w-3 h-3 bg-red-500 rounded"></span>
635
+ ${page.figures.length} Figure${page.figures.length !== 1 ? 's' : ''}
636
+ </span>
637
+ <span class="flex items-center gap-2">
638
+ <span class="inline-block w-3 h-3 bg-green-500 rounded"></span>
639
+ ${page.tables.length} Table${page.tables.length !== 1 ? 's' : ''}
640
+ </span>
641
+ <span class="flex items-center gap-2">
642
+ <span class="inline-block w-3 h-3 bg-blue-500 rounded"></span>
643
+ Captions
644
+ </span>
645
+ `;
646
+
647
+ const img = document.createElement('img');
648
+ img.src = page.annotated_image;
649
+ img.className = 'w-full rounded-lg border border-white/10';
650
+ img.alt = `Annotated page ${page.page_number}`;
651
+
652
+ pageDiv.appendChild(title);
653
+ pageDiv.appendChild(stats);
654
+ pageDiv.appendChild(img);
655
+ container.appendChild(pageDiv);
656
+ }
657
+ });
658
+ }
659
+
660
+ function toggleJSON() {
661
+ const jsonContent = document.getElementById('resultsContent');
662
+ const toggleText = document.getElementById('toggleText');
663
+
664
+ if (jsonContent.classList.contains('hidden')) {
665
+ jsonContent.classList.remove('hidden');
666
+ toggleText.textContent = 'Hide JSON';
667
+ } else {
668
+ jsonContent.classList.add('hidden');
669
+ toggleText.textContent = 'Show JSON';
670
+ }
671
+ }
672
+
673
  function downloadJSON() {
674
  if (!analysisResults) return;
675
 
676
+ // Remove base64 images from download to reduce file size
677
+ const downloadData = {
678
+ ...analysisResults,
679
+ results: analysisResults.results.map(r => {
680
+ const {annotated_image, ...rest} = r;
681
+ return rest;
682
+ })
683
+ };
684
+
685
+ const blob = new Blob([JSON.stringify(downloadData, null, 2)], { type: 'application/json' });
686
  const url = URL.createObjectURL(blob);
687
  const a = document.createElement('a');
688
  a.href = url;