# ai_manga_translator / visualization.py
# Author: jzhang533 — commit 4700e59 ("minor")
"""
Visualization utilities for drawing text detection boxes on images
"""
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from typing import List, Dict, Tuple
import os
import math
def generate_random_color() -> Tuple[int, int, int]:
    """
    Pick a random RGB color for a bounding box.

    Red and green are drawn from [0, 200) and blue from [0, 255), in that
    order, using NumPy's global random state (so the sequence is
    reproducible after ``np.random.seed``).

    Returns:
        RGB color tuple
    """
    red = np.random.randint(0, 200)
    green = np.random.randint(0, 200)
    blue = np.random.randint(0, 255)
    return red, green, blue
def _load_label_font(font_size: int):
    """Return the first TrueType font from a CJK-capable candidate list.

    Fonts bundled in the ``fonts`` folder next to this module are tried
    first (for portability), then common macOS and Linux locations. If none
    can be opened, Pillow's built-in default font is returned.
    """
    here = os.path.dirname(__file__)
    candidates = [
        # Local fonts (project/fonts/) - prioritize slim/light fonts
        os.path.join(here, "fonts", "NotoSansCJK-Light.ttc"),
        os.path.join(here, "fonts", "NotoSansCJK-Regular.ttc"),
        os.path.join(here, "fonts", "STHeiti-Light.ttc"),
        # macOS fonts
        "/System/Library/Fonts/STHeiti Light.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/Hiragino Sans GB.ttc",
        # Linux fonts
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Light.ttc",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
    ]
    for path in candidates:
        try:
            return ImageFont.truetype(path, font_size)
        except (IOError, OSError):
            continue
    return ImageFont.load_default()


def _font_line_height(draw, font):
    """Return the height of one text line (ascent + descent) for *font*.

    TrueType fonts provide ``getmetrics()``. The fallback default font may be
    a bitmap font without it; in that case the bounding-box height of a
    sample string is used instead.
    """
    getmetrics = getattr(font, "getmetrics", None)
    if getmetrics is not None:
        ascent, descent = getmetrics()
        return ascent + descent
    _, top, _, bottom = draw.textbbox((0, 0), "Ag", font=font)
    return bottom - top


def _wrap_text(draw, text, font, max_width):
    """Wrap *text* character by character so each line is narrower than *max_width*.

    Explicit newlines are kept as hard breaks; empty lines are dropped.
    Wrapping is per character (not per word), so it also works for scripts
    without spaces between words.

    Returns:
        (lines, total_height, widest_line_width), where total_height uses
        1.2x line spacing.
    """
    lines = []
    for raw_line in text.split('\n'):
        current = ""
        for char in raw_line:
            candidate = current + char
            left, _, right, _ = draw.textbbox((0, 0), candidate, font=font)
            if right - left < max_width:
                current = candidate
            else:
                if current:
                    lines.append(current)
                # A single character wider than max_width still gets its own line.
                current = char
        if current:
            lines.append(current)

    if not lines:
        return [], 0, 0

    total_height = len(lines) * _font_line_height(draw, font) * 1.2  # 1.2 line spacing
    widest = 0
    for line in lines:
        left, _, right, _ = draw.textbbox((0, 0), line, font=font)
        widest = max(widest, right - left)
    return lines, total_height, widest


def draw_detection_boxes(
    image: Image.Image,
    detections: List[Dict],
    box_width: int = 2,
    font_size: int = 12,
    show_text: bool = True,
    merge_boxes: bool = True
) -> Image.Image:
    """
    Paint each detection's text onto a soft background box over the image.

    For every detection, a FloralWhite box (no border) is filled over its
    region. The box may grow by up to 20% in width and height to fit the
    wrapped text, but never shrinks below the detected size. The text is then
    drawn left-aligned (4 px padding) and vertically centered. A detection
    that fails to draw is reported with ``print`` and skipped.

    Args:
        image: PIL Image to draw on (not modified; a copy is used)
        detections: List of detection dicts with 'text', 'x1', 'y1', 'x2', 'y2'
        box_width: Currently unused (boxes are drawn without a border); kept
            for API compatibility
        font_size: Font size for text labels
        show_text: Whether to draw the text labels inside the boxes
        merge_boxes: Whether to merge close boxes first via merge_detections
            (default: True)
    Returns:
        New RGB image with boxes and labels drawn
    """
    if merge_boxes:
        detections = merge_detections(detections)

    canvas = image.copy().convert('RGBA')
    draw = ImageDraw.Draw(canvas)
    # The font does not depend on the detection, so load it once.
    font = _load_label_font(font_size)

    for detection in detections:
        try:
            text = detection['text']
            x1, y1 = detection['x1'], detection['y1']
            x2, y2 = detection['x2'], detection['y2']
            box_w = x2 - x1
            box_h = y2 - y1

            # The background box may grow by at most 20% to fit the label.
            max_allowed_w = int(box_w * 1.2)
            max_allowed_h = int(box_h * 1.2)
            # Wrap against the widest allowed box minus 4 px padding per side.
            lines, total_h, max_line_w = _wrap_text(draw, text, font, max_allowed_w - 8)

            # Never shrink below the detected box; never exceed the 20% cap.
            box_w = max(box_w, min(max_line_w + 8, max_allowed_w))
            box_h = max(box_h, min(total_h + 4, max_allowed_h))
            x2 = x1 + box_w
            y2 = y1 + box_h

            # Opaque soft background (FloralWhite) covering the region, no border.
            draw.rectangle([x1, y1, x2, y2], fill=(255, 250, 240), outline=None)

            if not show_text:
                continue

            # Same 1.2x spacing as _wrap_text so the vertical centering is consistent.
            line_height = _font_line_height(draw, font) * 1.2
            start_y = y1 + (box_h - total_h) / 2
            text_color = (150, 0, 0)  # dark red
            for j, line in enumerate(lines):
                draw.text((x1 + 4, start_y + j * line_height), line, font=font, fill=text_color)
        except Exception as e:
            # Best effort: one malformed detection must not abort the whole image.
            print(f"Error drawing detection box: {str(e)}")
            continue

    return canvas.convert('RGB')
def create_side_by_side_comparison(
    original: Image.Image,
    annotated: Image.Image,
    spacing: int = 20
) -> Image.Image:
    """
    Place the original and annotated images next to each other.

    The images are pasted top-aligned onto a white RGB canvas, separated by
    ``spacing`` pixels. Each is captioned in black at 10 px from its top-left
    corner ("Original" / "Detected Text"), using the first loadable font from
    a candidate list at size 24, or Pillow's default font if none loads.

    Args:
        original: Original image (placed on the left)
        annotated: Annotated image with boxes (placed on the right)
        spacing: Space between images in pixels
    Returns:
        Combined image showing both versions
    """
    left_w, left_h = original.size
    right_w, right_h = annotated.size
    right_x = left_w + spacing

    canvas = Image.new('RGB', (right_x + right_w, max(left_h, right_h)), (255, 255, 255))
    canvas.paste(original, (0, 0))
    canvas.paste(annotated, (right_x, 0))

    # Candidate label fonts (CJK-capable first), tried in order.
    font_candidates = (
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/Hiragino Sans GB.ttc",
        "/System/Library/Fonts/STHeiti Light.ttc",
        "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
        "/System/Library/Fonts/Supplemental/Arial.ttf",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    )
    for candidate in font_candidates:
        try:
            label_font = ImageFont.truetype(candidate, 24)
            break
        except (IOError, OSError):
            pass
    else:
        label_font = ImageFont.load_default()

    pen = ImageDraw.Draw(canvas)
    black = (0, 0, 0)
    pen.text((10, 10), "Original", font=label_font, fill=black)
    pen.text((right_x + 10, 10), "Detected Text", font=label_font, fill=black)
    return canvas
def get_detection_summary(detections: List[Dict]) -> str:
    """
    Create a human-readable text summary of detection results.

    Args:
        detections: List of detection dicts with 'text', 'x1', 'y1', 'x2',
            'y2'. When a dict also has 'original_text' that differs from
            'text', both are listed as "Original" / "Translated".
    Returns:
        Formatted summary string, or a fixed message when the list is empty.
    """
    if not detections:
        return "No text detected in the image."

    # Collect pieces and join once instead of repeated string concatenation.
    parts = [f"Detected {len(detections)} text region(s):\n\n"]
    for i, det in enumerate(detections, 1):
        if 'original_text' in det and det['original_text'] != det['text']:
            parts.append(f"{i}. Original: \"{det['original_text']}\"\n")
            parts.append(f" Translated: \"{det['text']}\"\n")
        else:
            parts.append(f"{i}. \"{det['text']}\"\n")
        parts.append(f" Location: ({det['x1']}, {det['y1']}) → ({det['x2']}, {det['y2']})\n\n")
    return "".join(parts)
def merge_detections(detections: List[Dict], threshold: int = 30) -> List[Dict]:
    """
    Merge detection boxes that lie close together into single boxes.

    Two boxes are linked when the first, grown by ``threshold`` pixels on
    every side, touches or overlaps the second. Linking is transitive: each
    connected group of boxes becomes one merged box covering all members.

    Args:
        detections: List of detection dicts with 'text', 'x1', 'y1', 'x2', 'y2'
        threshold: Distance threshold for merging
    Returns:
        List of merged detection dicts, one per group, ordered by each group's
        first member in ``detections``. Each dict carries only 'text', 'x1',
        'y1', 'x2', 'y2' (other keys are dropped). 'text' is the members'
        texts ordered right-to-left then top-to-bottom and concatenated, with
        every space removed (single-box groups included).
    """
    if not detections:
        return []

    def touches(box_a, box_b, margin):
        # Inflate box_a by margin, then test for intersection with box_b.
        ax1, ay1 = box_a['x1'] - margin, box_a['y1'] - margin
        ax2, ay2 = box_a['x2'] + margin, box_a['y2'] + margin
        return not (ax2 < box_b['x1'] or ax1 > box_b['x2'] or
                    ay2 < box_b['y1'] or ay1 > box_b['y2'])

    count = len(detections)
    links = [[] for _ in range(count)]
    for first in range(count):
        for second in range(first + 1, count):
            if touches(detections[first], detections[second], threshold):
                links[first].append(second)
                links[second].append(first)

    seen = [False] * count
    merged = []
    for root in range(count):
        if seen[root]:
            continue
        # Depth-first walk (LIFO stack) collecting every box reachable from root.
        seen[root] = True
        todo = [root]
        group = []
        while todo:
            node = todo.pop()
            group.append(detections[node])
            for other in links[node]:
                if not seen[other]:
                    seen[other] = True
                    todo.append(other)

        bounds = {
            'x1': min(d['x1'] for d in group),
            'y1': min(d['y1'] for d in group),
            'x2': max(d['x2'] for d in group),
            'y2': max(d['y2'] for d in group),
        }
        # Manga reading order: right-to-left (descending x1), then top-to-bottom.
        group.sort(key=lambda d: (-d['x1'], d['y1']))
        joined = "".join(d['text'] for d in group).replace(" ", "")
        merged.append({'text': joined, **bounds})
    return merged