Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -1,79 +1,23 @@
 import gradio as gr
-import time
 import spaces
 from PIL import Image
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
-import torch
-import uuid
-import os
-import numpy as np
+from .atlasocr_model import AtlasOCR
 
 # Load model and processor
-
-
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    model_name,
-    torch_dtype="auto",
-    device_map="cuda"
-)
-processor = AutoProcessor.from_pretrained(model_name)
-max_tokens = 2000
+
+atlas_ocr=AtlasOCR()
 
 
 
 @spaces.GPU
 def perform_ocr(image):
-
-    if inputArray == False:
-        return "Error Processing"
-    """Process image and extract text using OCR model"""
-    image = Image.fromarray(image)
-    src = str(uuid.uuid4()) + ".png"
-    prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
-    image.save(src)
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": f"file://{src}"},
-                {"type": "text", "text": prompt},
-            ],
-        }
-    ]
-
-    # Process inputs
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to("cuda")
-
-    # Generate text
-    generated_ids = model.generate(**inputs, max_new_tokens=max_tokens, use_cache=True)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-
-    # Cleanup
-    os.remove(src)
+    output_text = atlas_ocr(image)
     return output_text
 
 # Create Gradio interface
-with gr.Blocks(title="Qari Arabic OCR") as demo:
-    gr.Markdown("# …
-    gr.Markdown("Upload an image to extract …
+with gr.Blocks(title="AtlasOCR") as demo:
+    gr.Markdown("# AtlasOCR")
+    gr.Markdown("Upload an image to extract Darija text in real-time. This model is specialized for Darija document OCR.")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -101,9 +45,9 @@ with gr.Blocks(title="Qari Arabic OCR") as demo:
     # Model details
     with gr.Accordion("Model Information", open=False):
        gr.Markdown("""
-        **Model:** …
-        **Description:** …
-        **Size:** …
+        **Model:** AtlasOCR-v0
+        **Description:** Darija OCR model
+        **Size:** 3B parameters
        **Context window:** Supports up to 2000 output tokens
        """)
 
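The refactor moves all model handling out of app.py and behind an AtlasOCR class imported from the Space's atlasocr_model module, which this diff does not show. Below is a minimal sketch of what such a wrapper could look like, assuming it reuses the Qwen2-VL loading and generation flow of the removed inline code; the class name AtlasOCR comes from the diff, while the checkpoint id, the prompt, and every other detail here are assumptions rather than the Space's actual implementation.

```python
# atlasocr_model.py -- hypothetical sketch only; the real module is not included in this diff.
# Assumes AtlasOCR wraps a Qwen2-VL-style checkpoint the same way the removed inline code did.
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class AtlasOCR:
    def __init__(self, model_name="atlasia/AtlasOCR-v0", max_tokens=2000):
        # model_name is a placeholder id, not confirmed by the diff
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype="auto", device_map="cuda"
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.max_tokens = max_tokens

    def __call__(self, image):
        # gr.Image hands perform_ocr a numpy array (or None if nothing was uploaded)
        if image is None:
            return "Error Processing"
        pil_image = Image.fromarray(image)

        # Hypothetical prompt; the production prompt is not visible in the diff
        prompt = "Extract the plain text contained in this document image."
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": prompt},
            ],
        }]

        # Build model inputs the same way the old inline code did
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Generate, then strip the prompt tokens before decoding
        generated_ids = self.model.generate(**inputs, max_new_tokens=self.max_tokens)
        trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
        return self.processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
```

With the model isolated behind a single callable like this, perform_ocr in app.py reduces to the one-line call shown in the diff, and the Gradio UI code no longer depends on transformers or qwen_vl_utils directly.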