# AtlasOCR-demo / app.py
# Hugging Face page header captured together with this file:
#   "imomayiz's picture" / "Update app.py" / "5e418cb verified"
import gradio as gr
import torch
from PIL import Image
import os
from dotenv import load_dotenv
import spaces
load_dotenv()
# Disable torch compilation issues
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True
IS_CUDA = torch.cuda.is_available()
IS_ZEROGPU = True if os.getenv("SPACES_ZERO_GPU", None) else False
if IS_ZEROGPU:
torch.compiler.set_stance("force_eager")
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
MODEL_NAME ="atlasia/AtlasOCR"
MAX_TOKENS = 4096
@spaces.GPU()
@torch.inference_mode()
def predict(image: Image.Image) -> str:
    """Run the AtlasOCR model on one image and return the decoded text.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the
            caller supplied nothing.

    Returns:
        The generated text with surrounding whitespace stripped, or an
        empty string when ``image`` is ``None``.

    Raises:
        gr.Error: If the model/processor cannot be loaded (including a
            missing ``HF_API_KEY`` environment variable). The original
            exception is chained as the cause.
    """
    # Validate the input first so a missing image never pays for a model
    # load and never reaches the processor.
    if image is None:
        gr.Warning("Please upload an image.")
        return ""

    try:
        # Imported inside the call rather than at module import time.
        from unsloth import FastVisionModel

        model, processor = FastVisionModel.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            load_in_4bit=True,
            use_gradient_checkpointing="unsloth",
            # A missing key raises KeyError, handled by the except below.
            token=os.environ["HF_API_KEY"],
        )
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        # gr.Error is an Exception subclass; chaining keeps the traceback.
        raise gr.Error(f"❌ Model failed to load: {e}") from e

    # Build prompt: a single user turn holding the image and an empty text
    # instruction.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": ""}],
        }
    ]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        image,
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Gradients are already disabled by @torch.inference_mode(), so no extra
    # no_grad context is needed. Decoding is greedy (do_sample=False); no
    # temperature is passed because it only applies when sampling.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=MAX_TOKENS,
        do_sample=False,
        pad_token_id=processor.tokenizer.eos_token_id,
    )

    # Trim input ids from output so only newly generated tokens are decoded.
    new_token_ids = [
        full[len(prompt_ids):]
        for prompt_ids, full in zip(inputs["input_ids"], output_ids)
    ]
    decoded = processor.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0].strip()
# Input widget: hands the uploaded picture to `predict` as a PIL image.
image_input = gr.Image(type="pil", label="Upload Image", height=400)

# Output widget: multi-line text box with a copy button for the OCR result.
text_output = gr.Textbox(
    label="Extracted Text",
    lines=20,
    show_copy_button=True,
    placeholder="Extracted text will appear here...",
)

# Single-function interface wiring the image input to `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=text_output,
    title="AtlasOCR - Darija Document OCR",
    description="Upload an image to extract Darija text.",
    examples=[["i3.png"], ["i6.png"]],
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )