import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
import base64
from datasets import load_dataset
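# Image token budget for the Qwen2-VL processor: each image is resized so its pixel
# count stays between min_pixels and max_pixels. Roughly one visual token corresponds
# to a 28x28 pixel patch (after patch merging), so max_pixels caps each slide at about
# max_token_budget visual tokens.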
max_token_budget = 512
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
ds = load_dataset("gigant/tib-bench")["train"]
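# Fields this app assumes each example provides: "title", "abstract", "slides"
# (a list of PIL images), "keyframes" with per-slide "timestamp" [start, end] pairs,
# and "transcript_segments" with Whisper-style "seek" offsets and "text" strings.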
def segments(example):
    # Build a single text with <image> tokens inserted at the timestamps of the
    # extracted keyframes, interleaved with the transcript segments.
    text = ""
    segment_i = 0
    for i, timestamp in enumerate(example["keyframes"]["timestamp"]):
        text += "<image>"  # could be f"<image {i}>" to keep the slide index
        start, end = timestamp[0], timestamp[1]
        # Append every transcript segment that starts before this keyframe ends
        # ("seek" appears to be in hundredths of a second, hence the * 0.01).
        while segment_i < len(example["transcript_segments"]["seek"]) and end > example["transcript_segments"]["seek"][segment_i] * 0.01:
            text += example["transcript_segments"]["text"][segment_i]
            segment_i += 1
    # Append any transcript text left over after the last keyframe.
    if segment_i < len(example["transcript_segments"]["text"]):
        text += "".join(example["transcript_segments"]["text"][segment_i:])
    return text
def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Creates an HTML string with interleaved images and text segments.
    The images are converted to base64 and embedded directly in the HTML.
    """
    html = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        # The first segment (j == 0) is the text before the leading <image>, so it has
        # no slide attached; every later segment is preceded by slides[j - 1].
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)
            # Convert the (resized) image to base64 so it can be embedded inline
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image
        html.append(f'<div style="white-space: pre-wrap;">{segment}</div>')
    return "".join(html)
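# doc_to_messages mirrors the interleaving above, but packs the slides and transcript
# into the Qwen2-VL chat format; in this app it is only used to count prompt tokens
# (see load_document below).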
def doc_to_messages(text, slides):
    content = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": segment})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
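# A minimal sketch (not run in this Space) of how the returned tensors could be fed to
# the generation model suggested by the import above; the model id and generation
# parameters here are illustrative assumptions, not part of this app:
#
#   model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
#   inputs = doc_to_messages(segments(ds[0]), ds[0]["slides"])
#   output_ids = model.generate(**inputs, max_new_tokens=128)
#   print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])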
# Global variables to keep track of current document
current_doc_index = 0
annotations = []
choices = [f"{i} | {ds['title'][i]}" for i in range(len(ds))]
def load_document(index):
    """Load a specific document from the dataset"""
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
            choices[index],
        )
    return ("", "", "", "", "")
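# The Previous/Next buttons only move the dropdown selection; the dropdown's .change
# listener (wired up below) then calls get_selected_document to reload the page content.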
def get_next_document():
    """Get the next document in the dataset"""
    global current_doc_index
    return choices[(current_doc_index + 1) % len(ds)]

def get_prev_document():
    """Get the previous document in the dataset"""
    global current_doc_index
    return choices[(current_doc_index - 1) % len(ds)]

def get_selected_document(arg):
    """Get the selected document from the dataset"""
    global current_doc_index
    index = int(arg.split(" | ")[0])
    current_doc_index = index
    return load_document(current_doc_index)
theme = gr.themes.Ocean()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")
    pres_selection_dd = gr.Dropdown(label="Presentation", value=choices[0], choices=choices)
    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)
        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)", interactive=False, max_lines=1)
    # Load first document
    title_val, abstract_val, body_val, token_count_val, choices_val = load_document(current_doc_index)
    title.value = title_val
    abstract.value = abstract_val
    body.value = body_val
    token_count.value = str(token_count_val)
    pres_selection_dd.value = choices_val
    pres_selection_dd.change(
        fn=get_selected_document,
        inputs=pres_selection_dd,
        outputs=[title, abstract, body, token_count, pres_selection_dd],
    )
    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(fn=get_prev_document, inputs=[], outputs=[pres_selection_dd])
        next_button = gr.Button("Next Document")
        next_button.click(fn=get_next_document, inputs=[], outputs=[pres_selection_dd])

demo.launch()