import gradio as gr
import torch
from PIL import Image
import logging
from typing import Optional, Union
import os
import spaces
from dotenv import load_dotenv
load_dotenv()
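# NOTE: AtlasOCR.__init__ below reads HF_API_KEY from the environment and passes it
# as the Hugging Face token, so set it in a local .env file or as a Space secret.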
# Disable torch compilation to avoid dynamo issues
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AtlasOCR:
    def __init__(self, model_name: str = "atlasia/AtlasOCR", max_tokens: int = 2000):
        """Initialize the AtlasOCR model with proper error handling."""
        try:
            from unsloth import FastVisionModel

            logger.info(f"Loading model: {model_name}")
            # Disable compilation for the model
            with torch._dynamo.config.patch(disable=True):
                self.model, self.processor = FastVisionModel.from_pretrained(
                    model_name,
                    device_map="auto",
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth",
                    token=os.environ["HF_API_KEY"],
                )

            # Ensure model is not compiled
            if hasattr(self.model, '_dynamo_compile'):
                self.model._dynamo_compile = False

            self.max_tokens = max_tokens
            self.prompt = ""
            self.device = next(self.model.parameters()).device
            logger.info(f"Model loaded successfully on device: {self.device}")
        except ImportError:
            logger.error("unsloth not found. Please install it: pip install unsloth")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise
    def prepare_inputs(self, image: Image.Image) -> dict:
        """Prepare inputs for the model with proper error handling."""
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                        },
                        {"type": "text", "text": self.prompt},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = self.processor(
                image,
                text,
                add_special_tokens=False,
                return_tensors="pt",
            )
            return inputs
        except Exception as e:
            logger.error(f"Error preparing inputs: {e}")
            raise
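    # Note (assumption): for a Qwen2-VL-style vision processor the returned dict
    # typically holds input_ids, attention_mask and image tensors such as
    # pixel_values; the exact keys depend on the processor unsloth loads.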
    def predict(self, image: Image.Image) -> str:
        """Predict text from image with comprehensive error handling."""
        try:
            if image is None:
                return "Please upload an image."

            # Convert numpy array to PIL Image if needed
            if hasattr(image, 'shape'):  # numpy array
                image = Image.fromarray(image)

            inputs = self.prepare_inputs(image)

            # Move inputs to the same device as the model with explicit device handling
            device = self.device
            logger.info(f"Moving inputs to device: {device}")

            # Manually move each tensor to device
            for key in inputs:
                if hasattr(inputs[key], 'to'):
                    inputs[key] = inputs[key].to(device)

            # Ensure attention_mask is float32 and on the correct device
            if 'attention_mask' in inputs:
                inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.float32, device=device)

            logger.info(f"Generating text with max_tokens={self.max_tokens}")

            # Disable compilation during generation; decoding is greedy, so no
            # sampling temperature is set (it would be ignored with do_sample=False).
            with torch.no_grad(), torch._dynamo.config.patch(disable=True):
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_tokens,
                    use_cache=True,
                    do_sample=False,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            result = output_text[0].strip()
            logger.info(f"Generated text: {result[:100]}...")
            return result
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            return f"Error processing image: {str(e)}"
    def __call__(self, image: Union[Image.Image, str]) -> str:
        """Callable interface for the model."""
        if isinstance(image, str):
            return "Please upload an image file."
        return self.predict(image)
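# Minimal usage sketch (assumes a local test image "sample.png"; not part of the app flow):
#   ocr = AtlasOCR()
#   print(ocr(Image.open("sample.png")))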
# Global model instance
atlas_ocr = None


def load_model():
    """Load the model globally to avoid reloading."""
    global atlas_ocr
    if atlas_ocr is None:
        try:
            atlas_ocr = AtlasOCR()
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False
    return True
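# perform_ocr is decorated with @spaces.GPU, so Hugging Face ZeroGPU allocates a GPU
# for the duration of each call; load_model() caches the model in the module-level
# global above so repeated requests reuse the loaded weights.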
@spaces.GPU
def perform_ocr(image):
    """Main OCR function with proper error handling."""
    try:
        if not load_model():
            return "Error: Failed to load model. Please check the logs."
        if image is None:
            return "Please upload an image to extract text."
        result = atlas_ocr(image)
        return result
    except Exception as e:
        logger.error(f"Error in perform_ocr: {e}")
        return f"An error occurred: {str(e)}"
def process_with_status(image):
    """Process image and return result with status - kept at module level to avoid pickling issues."""
    if image is None:
        return "Please upload an image.", "No image provided"
    try:
        result = perform_ocr(image)
        return result, "Processing completed successfully"
    except Exception as e:
        return f"Error: {str(e)}", f"Error occurred: {str(e)}"
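# Returns a (text, status) tuple matching the [output, status] components wired up in create_interface() below.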
def create_interface():
    """Create the Gradio interface with proper configuration."""
    with gr.Blocks(
        title="AtlasOCR - Darija Document OCR",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        """
    ) as demo:
        gr.Markdown("""
        # AtlasOCR - Darija Document OCR
        Upload an image to extract Darija text in real-time. This model is specialized for Darija document OCR.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input image
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    height=400
                )

                # Submit button
                submit_btn = gr.Button(
                    "Extract Text",
                    variant="primary",
                    size="lg"
                )

                # Clear button
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column(scale=1):
                # Output text
                output = gr.Textbox(
                    label="Extracted Text",
                    lines=20,
                    show_copy_button=True,
                    placeholder="Extracted text will appear here..."
                )

                # Status indicator
                status = gr.Textbox(
                    label="Status",
                    value="Ready to process images",
                    interactive=False
                )
                # Model details
                with gr.Accordion("Model Information", open=False):
                    gr.Markdown("""
                    - **Model:** AtlasOCR-v0
                    - **Description:** Specialized Darija OCR model for Arabic dialect text extraction
                    - **Size:** 3B parameters
                    - **Max output:** up to 2000 generated tokens
                    - **Optimization:** 4-bit quantization for efficient inference
                    """)
        gr.Examples(
            examples=[
                ["i3.png"],
                ["i6.png"]
            ],
            inputs=image_input,
            outputs=[output, status],  # required when cache_examples=True
            fn=process_with_status,    # required when cache_examples=True
            label="Example Images",
            examples_per_page=4,
            cache_examples=True
        )
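        # Note: with cache_examples=True, Gradio runs process_with_status on the
        # example images at startup, so i3.png and i6.png must exist in the repo.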
        # Set up processing flow
        submit_btn.click(
            fn=process_with_status,
            inputs=image_input,
            outputs=[output, status]
        )

        image_input.change(
            fn=process_with_status,
            inputs=image_input,
            outputs=[output, status]
        )

        clear_btn.click(
            fn=lambda: (None, "", "Ready to process images"),
            outputs=[image_input, output, status]
        )

    return demo
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )