# File size: 4,187 Bytes
# 5de2f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import torch
from threading import Thread

import numpy as np
from openrec.postprocess.unirec_postprocess import clean_special_tokens
from openrec.preprocess import create_operators, transform
from tools.engine.config import Config
from tools.utils.ckpt import load_ckpt
from tools.infer_rec import build_rec_process


def set_device(device):
    """Resolve a configured device name to a ``torch.device``.

    Args:
        device: Device name from the config. Only the exact string ``'gpu'``
            requests CUDA; every other value selects the CPU.

    Returns:
        ``torch.device('cuda:0')`` when ``'gpu'`` is requested and CUDA is
        available, otherwise ``torch.device('cpu')``.
    """
    wants_cuda = device == 'gpu'
    # CUDA availability is only probed when the GPU was actually requested.
    if wants_cuda and torch.cuda.is_available():
        return torch.device('cuda:0')
    # Any other name, or a missing CUDA runtime, falls back to the CPU.
    return torch.device('cpu')


# --- 1. Model and preprocessing setup (runs once at import time) ---
# Load the recognition config; the ``.cfg`` attribute is the mapping that the
# rest of this script indexes (e.g. its 'Global' section).
cfg = Config('configs/rec/unirec/focalsvtr_ardecoder_unirec.yml')
cfg = cfg.cfg
global_config = cfg['Global']

# These imports sit below the config loading in the original script.
from openrec.modeling.transformers_modeling.modeling_unirec import UniRecForConditionalGenerationNew
from openrec.modeling.transformers_modeling.configuration_unirec import UniRecConfig
from transformers import AutoTokenizer, TextIteratorStreamer

# Tokenizer and model config are both read from the same 'vlm_ocr_config' path.
tokenizer = AutoTokenizer.from_pretrained(global_config['vlm_ocr_config'])
cfg_model = UniRecConfig.from_pretrained(global_config['vlm_ocr_config'])
# Alternative attention backend, currently disabled:
# cfg_model._attn_implementation = "flash_attention_2"
cfg_model._attn_implementation = 'eager'

# Build the model from the config, then load the checkpoint.
# NOTE(review): load_ckpt's return value is ignored, so it presumably loads
# the weights into ``model`` in place - confirm in tools.utils.ckpt.
model = UniRecForConditionalGenerationNew(config=cfg_model)
load_ckpt(model, cfg)
# Resolve the device named in the 'Global' config section (CPU fallback).
device = set_device(cfg['Global']['device'])
# Evaluation mode for inference, then move the model to the chosen device.
model.eval()
model.to(device=device)

# Build the preprocessing operators applied to each input image.
# ``ratio_resize_flag`` is not used in the visible part of this file.
transforms, ratio_resize_flag = build_rec_process(cfg)
ops = create_operators(transforms, global_config)


# --- 2. Streaming generation function ---
def stream_chat_with_image(input_image, history):
    """Stream the recognition result for an image into the chat history.

    Generator used as a Gradio event handler. Every yielded value is the
    complete chat history, a list of ``(user, bot)`` pairs, whose last entry
    holds the text recognized so far.

    Args:
        input_image: Image delivered by the Gradio image input, or ``None``
            when no image is available.
        history: Existing chat history as a list of ``(user, bot)`` pairs.
            ``None`` is treated as an empty history. The passed-in list is
            never mutated.

    Yields:
        The updated chat history after each decoded chunk of text, or a
        single history with a hint message when ``input_image`` is ``None``.

    Raises:
        Exception: Whatever ``model.generate`` raised in the background
            thread, re-raised here after the partial output was streamed.
    """
    # Gradio may pass None for a chatbot that has not been populated yet;
    # normalise to a private copy so the list operations below cannot fail.
    history = list(history) if history else []

    if input_image is None:
        yield history + [('🖼️(empty)', 'Please upload an image first.')]
        return

    # The streamer is filled by model.generate() and drained by the loop below.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    # Apply the preprocessing operators, skipping the first one.
    # NOTE(review): presumably ops[0] decodes raw image bytes, which is not
    # needed for an already-loaded image - confirm against the config.
    batch = transform({'image': input_image}, ops[1:])
    # Add a leading axis of size 1, then move the data to the model's device.
    images = torch.from_numpy(np.expand_dims(batch[0], axis=0)).to(device=device)
    generation_kwargs = {
        'pixel_values': images,
        'input_ids': None,
        'attention_mask': None,
        'streamer': streamer,
        'max_new_tokens': 1024,
    }

    errors = []

    def _generate():
        """Run generation; on failure record the error and end the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the consumer
            errors.append(exc)
            # Without the end-of-stream signal the consumer loop below would
            # block forever waiting for the next chunk.
            streamer.end()

    # Generation runs in the background so text can be yielded as it arrives.
    worker = Thread(target=_generate)
    worker.start()

    generated_text = ''
    history.append(('🖼️(image)', ''))
    for new_text in streamer:
        generated_text += clean_special_tokens(new_text)
        history[-1] = ('🖼️(image)', generated_text)
        yield history

    # The stream has ended, so the worker is finished or about to finish.
    worker.join()
    if errors:
        raise errors[0]


# --- 3. Gradio UI ---
# Page layout: an HTML header and a short instruction line, followed by one
# row with the image input (left) and the chatbot showing the result (right).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Static header with project links; rendered as raw HTML.
    gr.HTML("""
            <h1 style='text-align: center;'>
                <a href="https://github.com/Topdu/OpenOCR">
                    UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters
                </a>
            </h1>
            <p style='text-align: center;'>
               A ultralight unified text and formula recognition model 
                (Created by <a href="https://fvl.fudan.edu.cn">FVL Lab</a>, 
                <a href="https://github.com/Topdu/OpenOCR">OCR Team</a>)
            </p>
            <p style='text-align: center;'>
                <a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[Local GPU Deployment]</a>
                for fast recognition experience
            </p>"""
            )
    gr.Markdown('Upload an image, and the system will automatically recognize text and formulas.')
    with gr.Row():
        with gr.Column(scale=1):  # Left column: image + clear button
            # type='pil' makes the handler receive a PIL image.
            image_input = gr.Image(label='Upload Image or Paste Screenshot', type='pil')
            clear = gr.ClearButton([image_input], value='Clear')
        with gr.Column(scale=2):  # Right column: recognition result
            chatbot = gr.Chatbot(
                label='Result (Use LaTeX renderer to display formulas)',
                show_copy_button=True,
                height='auto'
            )
    # The chatbot is created after the button, so it is registered for
    # clearing here rather than in the ClearButton constructor.
    clear.add([chatbot])

    # Trigger after upload: the handler receives (image, chat history) and
    # every value it yields replaces the chatbot content.
    image_input.upload(stream_chat_with_image, [image_input, chatbot], chatbot)

# --- 4. Launch app ---
if __name__ == '__main__':
    # Enable request queuing first, then start the server; share=True asks
    # Gradio to also create a public share link.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)