import gradio as gr
import torch
from threading import Thread
import numpy as np
from openrec.postprocess.unirec_postprocess import clean_special_tokens
from openrec.preprocess import create_operators, transform
from tools.engine.config import Config
from tools.utils.ckpt import load_ckpt
from tools.infer_rec import build_rec_process
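
# --- 1. Config, model, and preprocessing setup ---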
def set_device(device):
    if device == 'gpu' and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device
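
# Parse the recognition config (YAML); runtime options live under 'Global'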
cfg = Config('configs/rec/unirec/focalsvtr_ardecoder_unirec.yml')
cfg = cfg.cfg
global_config = cfg['Global']
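
# UniRec model classes and the Hugging Face tokenizer/streamer utilities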
from openrec.modeling.transformers_modeling.modeling_unirec import UniRecForConditionalGenerationNew
from openrec.modeling.transformers_modeling.configuration_unirec import UniRecConfig
from transformers import AutoTokenizer, TextIteratorStreamer
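
# Tokenizer and model config are read from the pretrained directory in 'vlm_ocr_config'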
tokenizer = AutoTokenizer.from_pretrained(global_config['vlm_ocr_config'])
cfg_model = UniRecConfig.from_pretrained(global_config['vlm_ocr_config'])
# cfg_model._attn_implementation = "flash_attention_2"
cfg_model._attn_implementation = 'eager'
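
# Instantiate the model, load checkpoint weights, and move it to the target device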
model = UniRecForConditionalGenerationNew(config=cfg_model)
load_ckpt(model, cfg)
device = set_device(cfg['Global']['device'])
model.eval()
model.to(device=device)
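
# Build the image preprocessing pipeline from the transform list in the config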
transforms, ratio_resize_flag = build_rec_process(cfg)
ops = create_operators(transforms, global_config)
# --- 2. Streaming generation function ---
def stream_chat_with_image(input_image, history):
    if input_image is None:
        yield history + [('🖼️(empty)', 'Please upload an image first.')]
        return
    # Create TextIteratorStreamer
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False
    )
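    # Preprocess the uploaded image with the configured ops; the first op
    # (file decoding) is skipped since Gradio already supplies a PIL image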
    data = {'image': input_image}
    batch = transform(data, ops[1:])
    images = np.expand_dims(batch[0], axis=0)
    images = torch.from_numpy(images).to(device=device)
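    # Image-only prompt: input_ids/attention_mask are None, so generation is
    # conditioned on pixel_values alone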
    inputs = {
        'pixel_values': images,
        'input_ids': None,
        'attention_mask': None
    }
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    # Run generation in a background thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Stream output
    generated_text = ''
    history = history + [('🖼️(image)', '')]
    for new_text in streamer:
        new_text = clean_special_tokens(new_text)
        generated_text += new_text
        history[-1] = ('🖼️(image)', generated_text)
        yield history
# --- 3. Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <h1 style='text-align: center;'>
            <a href="https://github.com/Topdu/OpenOCR">
                UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters
            </a>
        </h1>
        <p style='text-align: center;'>
            An ultralight unified text and formula recognition model
            (Created by <a href="https://fvl.fudan.edu.cn">FVL Lab</a>,
            <a href="https://github.com/Topdu/OpenOCR">OCR Team</a>)
        </p>
        <p style='text-align: center;'>
            <a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[Local GPU Deployment]</a>
            for a fast recognition experience
        </p>"""
    )
    gr.Markdown('Upload an image, and the system will automatically recognize text and formulas.')
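    # Two-column layout: image input on the left, streaming result on the right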
    with gr.Row():
        with gr.Column(scale=1):  # Left column: image + clear button
            image_input = gr.Image(label='Upload Image or Paste Screenshot', type='pil')
            clear = gr.ClearButton([image_input], value='Clear')
        with gr.Column(scale=2):  # Right column: recognition result
            chatbot = gr.Chatbot(
                label='Result (Use LaTeX renderer to display formulas)',
                show_copy_button=True,
                height='auto'
            )
    clear.add([chatbot])
    # Trigger recognition after an image upload
    image_input.upload(stream_chat_with_image, [image_input, chatbot], chatbot)
# --- 4. Launch app ---
if __name__ == '__main__':
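    # queue() enables streamed (generator) outputs; share=True creates a public link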
    demo.queue().launch(share=True)