redhairedshanks1 commited on
Commit
37de944
·
verified ·
1 Parent(s): 3ec4321

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -55
app.py CHANGED
@@ -12,66 +12,59 @@
12
  # if __name__ == "__main__":
13
  # demo.launch()
14
 
15
- import gradio as gr
 
16
  import torch
17
- from transformers import AutoProcessor, AutoModelForCausalLM
18
-
19
- # ------------------------------
20
- # Image-only Processor Wrapper
21
- # ------------------------------
22
- class ImageOnlyProcessor:
23
- def __init__(self, processor):
24
- self.image_processor = processor.image_processor
25
- self.tokenizer = processor.tokenizer
26
- self.feature_extractor = getattr(processor, "feature_extractor", None)
27
-
28
- def __getattr__(self, name):
29
- # Pass through to tokenizer if not found here
30
- return getattr(self.tokenizer, name)
31
-
32
- # ------------------------------
33
- # Load Model + Processor
34
- # ------------------------------
35
- model_id = "rednote/dots-ocr"
36
-
37
- print("🔄 Loading processor...")
38
- _base_proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
39
- processor = ImageOnlyProcessor(_base_proc)
40
-
41
- print("🔄 Loading model (may take time)...")
42
- model = AutoModelForCausalLM.from_pretrained(
43
- model_id,
44
- trust_remote_code=True,
45
- torch_dtype=torch.float16,
46
- device_map="auto"
47
- )
48
 
49
- # ------------------------------
50
- # OCR Function
51
- # ------------------------------
52
- def run_ocr(image):
53
- try:
54
- # Preprocess
55
- inputs = processor.image_processor(images=image, return_tensors="pt").to(model.device)
56
-
57
- # Generate
58
- generated_ids = model.generate(**inputs, max_new_tokens=256)
59
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
60
-
61
- return generated_text.strip()
62
- except Exception as e:
63
- return f"❌ Error: {str(e)}"
64
-
65
- # ------------------------------
66
- # Gradio UI
67
- # ------------------------------
 
 
 
 
 
 
 
68
  demo = gr.Interface(
69
- fn=run_ocr,
70
  inputs=gr.Image(type="pil"),
71
  outputs="text",
72
- title="📖 DOTS-OCR (Image Only)",
73
- description="Upload an image to extract text using the rednote/dots-ocr model (video support disabled)."
74
  )
75
 
76
  if __name__ == "__main__":
77
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
12
  # if __name__ == "__main__":
13
  # demo.launch()
14
 
15
+ import os
16
+ from PIL import Image
17
  import torch
18
+ import gradio as gr
19
+
20
+ # Make sure you have the local model folder (e.g., "DotsOCR") with all files from the repo
21
+ LOCAL_MODEL_PATH = "./DotsOCR"
22
+
23
+ # Import the model and processor code locally
24
+ import sys
25
+ sys.path.append(LOCAL_MODEL_PATH)
26
+
27
+ from modeling_dots_ocr import DotsOCRForVision2Text
28
+ from configuration_dots import DotsOCRConfig
29
+ from transformers import PreTrainedTokenizerFast
30
+
31
+ # Load tokenizer
32
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(LOCAL_MODEL_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Load model configuration
35
+ config = DotsOCRConfig.from_pretrained(LOCAL_MODEL_PATH)
36
+
37
+ # Load model
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+ model = DotsOCRForVision2Text.from_pretrained(LOCAL_MODEL_PATH, config=config)
40
+ model.to(device)
41
+ model.eval()
42
+
43
+ # Load only the image processor
44
+ from transformers import AutoFeatureExtractor
45
+ image_processor = AutoFeatureExtractor.from_pretrained(LOCAL_MODEL_PATH)
46
+
47
+ def parse_document(image: Image.Image):
48
+ # Preprocess the image
49
+ inputs = image_processor(images=image, return_tensors="pt").to(device)
50
+
51
+ # Forward pass
52
+ with torch.no_grad():
53
+ output_ids = model.generate(**inputs, max_new_tokens=1024)
54
+
55
+ # Decode output
56
+ text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
57
+ return text
58
+
59
+ # Gradio demo
60
  demo = gr.Interface(
61
+ fn=parse_document,
62
  inputs=gr.Image(type="pil"),
63
  outputs="text",
64
+ title="dots.ocr Document Parser",
65
+ description="Parse text from images using dots.ocr"
66
  )
67
 
68
  if __name__ == "__main__":
69
+ demo.launch()
70
+