Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from threading import Thread | |
| import numpy as np | |
| from openrec.postprocess.unirec_postprocess import clean_special_tokens | |
| from openrec.preprocess import create_operators, transform | |
| from tools.engine.config import Config | |
| from tools.utils.ckpt import load_ckpt | |
| from tools.infer_rec import build_rec_process | |
def set_device(device: str) -> torch.device:
    """Resolve a configured device name to a ``torch.device``.

    Args:
        device: Device name taken from the configuration. Only the exact
            string ``'gpu'`` requests CUDA.

    Returns:
        ``torch.device('cuda:0')`` when ``device == 'gpu'`` and CUDA is
        available; ``torch.device('cpu')`` in every other case.
    """
    if device == 'gpu' and torch.cuda.is_available():
        return torch.device('cuda:0')
    # Any other value, or a 'gpu' request on a machine without CUDA,
    # falls back to the CPU.
    return torch.device('cpu')
# --- 1. Model, tokenizer and preprocessing setup (runs once at import time) ---
# Keep only the parsed configuration mapping; `global_config` is its 'Global' section.
cfg = Config('configs/rec/unirec/focalsvtr_ardecoder_unirec.yml').cfg
global_config = cfg['Global']

from openrec.modeling.transformers_modeling.modeling_unirec import UniRecForConditionalGenerationNew
from openrec.modeling.transformers_modeling.configuration_unirec import UniRecConfig
from transformers import AutoTokenizer, TextIteratorStreamer

# The tokenizer and the model configuration are both loaded from the
# location named by the `vlm_ocr_config` entry.
tokenizer = AutoTokenizer.from_pretrained(global_config['vlm_ocr_config'])
cfg_model = UniRecConfig.from_pretrained(global_config['vlm_ocr_config'])
# Eager attention is selected; "flash_attention_2" is the alternative that
# was left disabled here.
cfg_model._attn_implementation = 'eager'

# Build the model from its configuration, then load the checkpoint weights.
model = UniRecForConditionalGenerationNew(config=cfg_model)
load_ckpt(model, cfg)

# Switch to evaluation mode and move the model to the configured device.
device = set_device(global_config['device'])
model.eval()
model.to(device)

# Preprocessing operators later applied to every uploaded image.
transforms, ratio_resize_flag = build_rec_process(cfg)
ops = create_operators(transforms, global_config)
| # --- 2. Streaming generation function --- | |
def stream_chat_with_image(input_image, history):
    """Recognize an uploaded image and stream the result into the chat history.

    This is a generator: generation runs in a background thread and the
    history is yielded again after every text chunk the streamer produces.

    Args:
        input_image: Image handed over by the Gradio image input, or ``None``
            when no image is present.
        history: Current chat history as a list of ``(user, assistant)``
            pairs. ``None`` is treated as an empty history.

    Yields:
        The history list with one extra entry appended: a hint asking for an
        upload when ``input_image`` is ``None``, otherwise the text
        recognized so far.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has ended.
    """
    # Work on a private copy so the caller's list is never mutated, and
    # tolerate a missing history instead of failing on ``None + list``.
    history = list(history) if history else []

    if input_image is None:
        yield history + [('🖼️(empty)', 'Please upload an image first.')]
        return

    # Special tokens are left in the decoded text; each chunk is passed
    # through clean_special_tokens below.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    # The first operator is skipped -- presumably the image-decoding step,
    # which is not needed for an already decoded image (TODO confirm).
    batch = transform({'image': input_image}, ops[1:])
    # Add a leading batch axis and move the array to the model's device.
    pixel_values = torch.from_numpy(np.expand_dims(batch[0], axis=0)).to(device=device)

    generation_kwargs = dict(
        pixel_values=pixel_values,
        input_ids=None,
        attention_mask=None,
        streamer=streamer,
        max_new_tokens=1024,
    )

    failures = []

    def _generate():
        """Run generation; on failure, unblock the consumer loop below."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # Without an end-of-stream signal the `for ... in streamer`
            # loop would wait forever and the request would hang.
            failures.append(exc)
            streamer.end()

    # Generation runs in the background so chunks can be yielded as they arrive.
    worker = Thread(target=_generate)
    worker.start()

    generated_text = ''
    history.append(('🖼️(image)', ''))
    for new_text in streamer:
        generated_text += clean_special_tokens(new_text)
        history[-1] = ('🖼️(image)', generated_text)
        yield history

    worker.join()
    if failures:
        # Surface the worker's error to the caller instead of ending silently.
        raise failures[0]
| # --- 3. Gradio UI --- | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: project title, credits and a link to the local deployment guide.
    gr.HTML("""
        <h1 style='text-align: center;'>
        <a href="https://github.com/Topdu/OpenOCR">
        UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters
        </a>
        </h1>
        <p style='text-align: center;'>
        An ultralight unified text and formula recognition model
        (Created by <a href="https://fvl.fudan.edu.cn">FVL Lab</a>,
        <a href="https://github.com/Topdu/OpenOCR">OCR Team</a>)
        </p>
        <p style='text-align: center;'>
        <a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[Local GPU Deployment]</a>
        for fast recognition experience
        </p>"""
    )
    gr.Markdown('Upload an image, and the system will automatically recognize text and formulas.')

    with gr.Row():
        # Left column: image input with its clear button.
        with gr.Column(scale=1):
            image_input = gr.Image(label='Upload Image or Paste Screenshot', type='pil')
            clear = gr.ClearButton([image_input], value='Clear')
        # Right column: chat panel that displays the recognized text.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label='Result (Use LaTeX renderer to display formulas)',
                show_copy_button=True,
                height='auto',
            )

    # The chatbot is created after the button, so it is registered with the
    # clear button afterwards.
    clear.add([chatbot])

    # Start recognition as soon as an image is uploaded; the handler receives
    # the image and the current chat history and streams updates back into
    # the chatbot.
    image_input.upload(
        fn=stream_chat_with_image,
        inputs=[image_input, chatbot],
        outputs=chatbot,
    )
| # --- 4. Launch app --- | |
if __name__ == '__main__':
    # Enable request queuing first, then start the server and ask for a
    # public share link.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)