Spaces:
Runtime error
Runtime error
| # import gradio as gr | |
| # import easyocr | |
| # reader = easyocr.Reader(["en"]) | |
| # def ocr_image(image): | |
| # results = reader.readtext(image) | |
| # return "\n".join([text for _, text, _ in results]) | |
| # demo = gr.Interface(fn=ocr_image, inputs="image", outputs="text") | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| import os | |
| from PIL import Image | |
| import torch | |
| import gradio as gr | |
import sys

# Folder holding the local copy of the model repo (weights, config, tokenizer
# files and the custom Python modules imported below).
LOCAL_MODEL_PATH = "./DotsOCR"

# The model and config classes are imported from Python files inside the
# local folder, so that folder has to be on sys.path before importing them.
sys.path.append(LOCAL_MODEL_PATH)

# NOTE(review): confirm these class names match what the local repo files
# actually export; a mismatch here fails at import time.
from modeling_dots_ocr import DotsOCRForVision2Text
from configuration_dots import DotsOCRConfig
from transformers import AutoFeatureExtractor, PreTrainedTokenizerFast

# Tokenizer used to turn generated token ids back into text.
tokenizer = PreTrainedTokenizerFast.from_pretrained(LOCAL_MODEL_PATH)

# Model configuration read from the same folder.
config = DotsOCRConfig.from_pretrained(LOCAL_MODEL_PATH)

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights, move them to the chosen device and switch to eval mode.
model = DotsOCRForVision2Text.from_pretrained(LOCAL_MODEL_PATH, config=config)
model.to(device)
model.eval()

# Only the image preprocessing component is loaded here.
# NOTE(review): transformers recommends AutoImageProcessor for vision
# checkpoints; confirm AutoFeatureExtractor resolves for this model folder.
image_processor = AutoFeatureExtractor.from_pretrained(LOCAL_MODEL_PATH)
def parse_document(image: "Image.Image | None") -> str:
    """Run the OCR model on one image and return the decoded text.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        The first decoded sequence with special tokens stripped, or an
        empty string when no image was provided.
    """
    # Gradio passes None for an empty image input; return early instead of
    # letting the image processor raise on a missing image.
    if image is None:
        return ""

    # Preprocess the image into PyTorch tensors and move them to `device`.
    inputs = image_processor(images=image, return_tensors="pt").to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=1024)

    # A single image was passed in, so only the first decoded entry is used.
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# Gradio UI: a single image input wired to parse_document, with the result
# shown as plain text.
demo = gr.Interface(
    parse_document,
    gr.Image(type="pil"),  # hand the callback a PIL image
    "text",
    title="dots.ocr Document Parser",
    description="Parse text from images using dots.ocr",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()