Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| from dotenv import load_dotenv | |
| import spaces | |
| load_dotenv() | |
# Runtime configuration for torch, followed by the model constants used by
# the inference function below.
#
# NOTE(review): the indentation of this section was lost in the copy this
# edit was made from. Only the `set_stance` call is assumed to sit inside
# `if IS_ZEROGPU:`; the two TF32-related lines after it are assumed to be
# top-level. Confirm against the deployed file.

# Disable torch compilation issues
torch._dynamo.config.disable = True
# Allow TF32 in cuDNN operations.
torch.backends.cudnn.allow_tf32 = True
# Whether a CUDA device is visible to torch at import time.
IS_CUDA = torch.cuda.is_available()
# True when the SPACES_ZERO_GPU environment variable is set to a non-empty
# value (presumably set by the Hugging Face ZeroGPU runtime — confirm).
IS_ZEROGPU = True if os.getenv("SPACES_ZERO_GPU", None) else False
if IS_ZEROGPU:
    # Force eager execution in torch.compiler when running on ZeroGPU.
    torch.compiler.set_stance("force_eager")
# Float32 matmul precision "high" plus the explicit TF32 matmul flag.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
# Hub repository id of the OCR model, and the generation budget (new tokens).
MODEL_NAME ="atlasia/AtlasOCR"
MAX_TOKENS = 4096
# Process-wide cache for the (model, processor) pair so the weights are loaded
# at most once per process instead of once per request.
_MODEL_CACHE = None


def _load_model():
    """Return the ``(model, processor)`` pair, loading it on first use.

    Only a successful load is cached; if loading raises, the next call
    retries from scratch.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        # Imported lazily (as the original code did) so that importing this
        # module does not import unsloth.
        from unsloth import FastVisionModel

        _MODEL_CACHE = FastVisionModel.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            load_in_4bit=True,
            use_gradient_checkpointing="unsloth",
            # getenv instead of environ[...]: a missing HF_API_KEY becomes
            # token=None rather than a KeyError before the load is attempted.
            token=os.getenv("HF_API_KEY"),
        )
    return _MODEL_CACHE


# spaces.GPU requests a GPU for the duration of the call on ZeroGPU hardware.
# NOTE(review): on ZeroGPU the decorated function may execute in a worker
# process, in which case the cache above may not persist between requests
# there — confirm.
@spaces.GPU
def predict(image: Image.Image) -> str:
    """Run OCR on ``image`` and return the extracted text.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when nothing was uploaded.

    Returns:
        The decoded model output with surrounding whitespace stripped, or an
        empty string (after showing a UI warning) when ``image`` is ``None``.

    Raises:
        gr.Error: If the model or processor cannot be loaded. The original
            exception is attached as ``__cause__``.
    """
    # Validate the input before paying for a model load, and stop here:
    # the processor cannot handle a missing image.
    if image is None:
        gr.Warning("Please upload an image.")
        return ""

    try:
        model, processor = _load_model()
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        raise gr.Error(f"❌ Model failed to load: {e}") from e

    # Build prompt: a single user turn holding the image and an empty text
    # instruction.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": ""}],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        image,
        text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        # Greedy decoding. No temperature is passed: with do_sample=False
        # sampling parameters are unused.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Trim input ids from output so only newly generated tokens are decoded.
    generated_ids = [
        out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)
    ]
    text_out = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return text_out[0].strip()
# Gradio UI: one image input, one text output, wired to `predict`.
_image_input = gr.Image(type="pil", label="Upload Image", height=400)
_text_output = gr.Textbox(
    label="Extracted Text",
    lines=20,
    show_copy_button=True,
    placeholder="Extracted text will appear here...",
)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_text_output,
    title="AtlasOCR - Darija Document OCR",
    description="Upload an image to extract Darija text.",
    # Example inputs offered in the UI; the paths are presumably relative to
    # the app's working directory — confirm the files ship with the app.
    examples=[["i3.png"], ["i6.png"]],
)
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )