# dots-ocr / app.py
# Last updated by redhairedshanks1 in commit 37de944 ("Update app.py", verified).
# Alternative EasyOCR-based demo (disabled), kept for reference:
# import gradio as gr
# import easyocr
# reader = easyocr.Reader(["en"])
# def ocr_image(image):
#     results = reader.readtext(image)
#     return "\n".join([text for _, text, _ in results])
# demo = gr.Interface(fn=ocr_image, inputs="image", outputs="text")
# if __name__ == "__main__":
#     demo.launch()
import os
from PIL import Image
import torch
import gradio as gr
# The local model folder (e.g. "DotsOCR") must contain all files from the
# model repo: the weights, the tokenizer/preprocessor files, and the
# modeling_dots_ocr.py / configuration_dots.py sources imported below.
# It is resolved next to this file rather than against the current working
# directory, so the app starts correctly whatever directory it is launched
# from. Set DOTS_OCR_MODEL_PATH to use a folder somewhere else.
LOCAL_MODEL_PATH = os.environ.get(
    "DOTS_OCR_MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "DotsOCR"),
)

# The model and config classes ship as plain .py files inside the model
# folder, so that folder must be on sys.path before importing them.
import sys

sys.path.append(LOCAL_MODEL_PATH)

# NOTE(review): confirm that modeling_dots_ocr.py really defines
# DotsOCRForVision2Text; the class name cannot be verified from this file.
from modeling_dots_ocr import DotsOCRForVision2Text
from configuration_dots import DotsOCRConfig
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    PreTrainedTokenizerFast,
)

# Load tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained(LOCAL_MODEL_PATH)
# Load model configuration
config = DotsOCRConfig.from_pretrained(LOCAL_MODEL_PATH)

# Load the model onto the GPU when one is available and switch it to eval
# mode, since this app only runs inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DotsOCRForVision2Text.from_pretrained(LOCAL_MODEL_PATH, config=config)
model.to(device)
model.eval()

# Load only the image processor. AutoImageProcessor is the dedicated loader
# for image processors. If it cannot recognise the preprocessor config
# (ValueError), fall back to the generic feature-extractor loader this app
# used originally.
try:
    image_processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
except ValueError:
    image_processor = AutoFeatureExtractor.from_pretrained(LOCAL_MODEL_PATH)
def parse_document(image: Image.Image, *, max_new_tokens: int = 1024) -> str:
    """Run dots.ocr on a single image and return the decoded text.

    Args:
        image: The image to parse, as a PIL image (the Gradio input is
            ``gr.Image(type="pil")``). Gradio passes ``None`` when the form
            is submitted without an upload.
        max_new_tokens: Upper bound on the number of generated tokens.
            Keyword-only so the single-input Gradio wiring is unaffected.

    Returns:
        The decoded text of the first (and only) generated sequence, with
        special tokens stripped.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of an opaque processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to parse.")

    # Normalise to 3-channel RGB so greyscale, palette or RGBA images do not
    # trip up the image processor.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and move the tensors to the model's device.
    inputs = image_processor(images=image, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # batch_decode returns one string per sequence; a single image was sent.
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# Gradio UI: a single PIL image goes in, and the text returned by
# parse_document comes out.
_document_input = gr.Image(type="pil")

demo = gr.Interface(
    fn=parse_document,
    inputs=_document_input,
    outputs="text",
    title="dots.ocr Document Parser",
    description="Parse text from images using dots.ocr",
)


def _main():
    """Start the Gradio server when this file is executed as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()