import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer
from threading import Thread
import json
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import itertools
import transformers
import time

transformers.logging.set_verbosity_error()

# Number of model instances to load (requests are distributed across them round-robin)
n_instances = 1

# Name of the accelerator shown in the info line ("CPU" if no GPU is visible)
gpu_name = "CPU"
for i in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_properties(i).name

# Reusing the original MLP class and other functions (unchanged) except those specific to Streamlit
class MLP(torch.nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
        self.layer_id = layer_id
        if cuda:
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.to(self.device)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc3(x)
        return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()
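
# Map a value in [0, 1] to a CSS hex color using a matplotlib colormap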
def map_value_to_color(value, colormap_name='tab20c'):
    value = np.clip(value, 0.0, 1.0)
    colormap = plt.get_cmap(colormap_name)
    rgba_color = colormap(value)
    css_color = to_hex(rgba_color)
    return css_color

# Caching functions for model and classifier
model_cache = {}

def get_multiple_model_and_tokenizer(name, n_instances):
    model_instances = []
    for _ in range(n_instances):
        tok = AutoTokenizer.from_pretrained(name, token=os.getenv('HF_TOKEN'), pad_token_id=128001)
        model = AutoModelForCausalLM.from_pretrained(name, token=os.getenv('HF_TOKEN'), torch_dtype="bfloat16", pad_token_id=128001, device_map="auto")
        if torch.cuda.is_available():
            model.cuda()
        model_instances.append((model, tok))
    return model_instances
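
# Load the span-detection and token-classification probe heads (MLPs) and their label map from the given config/checkpoint directories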
def get_classifiers_for_model(att_size, emb_size, device, config_paths):
    config = {
        "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
        "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
    }
    layer_id = config["classifier_token"]["layer"]
    classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
    classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device, weights_only=True))
    classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
    classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device, weights_only=True))
    return classifier_span, classifier_token, config["classifier_token"]["label_map"]
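
# Walk root_dir for folders containing both config.json and stoke_config.json and index the configs by model id and dataset name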
def find_datasets_and_model_ids(root_dir):
    datasets = {}
    for root, dirs, files in os.walk(root_dir):
        if 'config.json' in files and 'stoke_config.json' in files:
            config_path = os.path.join(root, 'config.json')
            stoke_config_path = os.path.join(root, 'stoke_config.json')
            with open(config_path, 'r') as f:
                config_data = json.load(f)
                model_id = config_data.get('model_id')
                if model_id:
                    dataset_name = os.path.basename(os.path.dirname(config_path))
            with open(stoke_config_path, 'r') as f:
                stoke_config_data = json.load(f)
                if model_id:
                    dataset_name = os.path.basename(os.path.dirname(stoke_config_path))
                    datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data
    return datasets
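
# For each span end position, keep only the highest-scoring candidate span; inverted, overly long, or position-0 spans are dropped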
def filter_spans(spans_and_values):
    if spans_and_values == []:
        return [], []
    # Create a dictionary to store spans based on their second index values
    span_dict = {}
    spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]
    # Iterate through the spans and update the dictionary with the highest value
    for span, value in zip(spans, values):
        start, end = span
        if start > end or end - start > 15 or start == 0:
            continue
        current_value = span_dict.get(end, None)
        if current_value is None or current_value[1] < value:
            span_dict[end] = (span, value)
    if span_dict == {}:
        return [], []
    # Extract the filtered spans and values
    filtered_spans, filtered_values = zip(*span_dict.values())
    return list(filtered_spans), list(filtered_values)
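
# Resolve overlapping spans, preferring the higher-scoring span where two conflict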
def remove_overlapping_spans(spans):
    # Sort the spans based on their end points
    sorted_spans = sorted(spans, key=lambda x: x[0][1])
    non_overlapping_spans = []
    last_end = float('-inf')
    # Iterate through the sorted spans
    for span in sorted_spans:
        start, end = span[0]
        value = span[1]
        # If the current span does not overlap with the previous one
        if start >= last_end:
            non_overlapping_spans.append(span)
            last_end = end
        else:
            # If it overlaps, choose the one with the highest value
            existing_span_index = -1
            for i, existing_span in enumerate(non_overlapping_spans):
                if existing_span[0][1] <= start:
                    existing_span_index = i
                    break
            if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
                non_overlapping_spans[existing_span_index] = span
    return non_overlapping_spans
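
# Render tokenized text as HTML, bold-underlining each detected span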
def generate_html_no_overlap(tokenized_text, spans):
    current_index = 0
    html_content = ""
    for (span_start, span_end), value in spans:
        # Add text before the span
        html_content += "".join(tokenized_text[current_index:span_start])
        # Add the span with underlining
        html_content += "<b><u>"
        html_content += "".join(tokenized_text[span_start:span_end])
        html_content += "</u></b> "
        current_index = span_end
    # Add any remaining text after the last span
    html_content += "".join(tokenized_text[current_index:])
    return html_content
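
# Inline CSS prepended to every streamed HTML update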
| css = """ | |
| <style> | |
| .prose { | |
| line-height: 200%; | |
| } | |
| .highlight { | |
| display: inline; | |
| } | |
| .highlight::after { | |
| background-color: var(data-color); | |
| } | |
| .spanhighlight { | |
| padding: 2px 5px; | |
| border-radius: 5px; | |
| } | |
| .tooltip { | |
| position: relative; | |
| display: inline-block; | |
| } | |
| .generated-content { | |
| overflow: scroll; | |
| height: 100%; | |
| } | |
| .tooltip::after { | |
| content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */ | |
| display: none; | |
| position: absolute; | |
| background-color: #333; | |
| color: #fff; | |
| padding: 5px; | |
| border-radius: 5px; | |
| bottom: 100%; /* Position it above the element */ | |
| left: 50%; | |
| transform: translateX(-50%); | |
| width: auto; | |
| min-width: 120px; | |
| margin: 0 auto; | |
| text-align: center; | |
| } | |
| .tooltip:hover::after { | |
| display: block; /* Show the tooltip on hover */ | |
| } | |
| .small-text { | |
| padding: 2px 5px; | |
| background-color: white; | |
| border-radius: 5px; | |
| font-size: xx-small; | |
| margin-left: 0.5em; | |
| vertical-align: 0.2em; | |
| font-weight: bold; | |
| color: grey!important; | |
| } | |
| .square { | |
| width: 20px; /* Width of the square */ | |
| height: 20px; /* Height of the square */ | |
| border: 1px solid black; /* Black outline */ | |
| margin: auto; | |
| background-color: white; /* Optional: set the background color */ | |
| position: relative; | |
| z-index: 1; /* Higher stacking order for the square */ | |
| } | |
| .circle { | |
| width: 16px; /* Width of the square */ | |
| height: 16px; /* Height of the square */ | |
| border: 1px solid red; /* Black outline */ | |
| border-radius: 8px; | |
| margin: auto; | |
| background-color: white; /* Optional: set the background color */ | |
| position: relative; | |
| z-index: 1; /* Higher stacking order for the square */ | |
| display: block!important; | |
| } | |
| table { | |
| border: 0px!important; /* Black outline */ | |
| table-layout: fixed; | |
| width:100%; | |
| } | |
| th, td { | |
| font-weight: normal; | |
| width: 7em!important; | |
| text-align: center!important; | |
| border: 0px!important; | |
| } | |
| tr { | |
| border: 0px!important; | |
| } | |
| .dashed-cell { | |
| position: relative; | |
| width: 50px; /* Adjust width of the table cell */ | |
| } | |
| .dashed-cell::before { | |
| content: ""; | |
| position: absolute; | |
| top: 0; | |
| bottom: 0; | |
| left: 50%; /* Center the dashed line horizontally */ | |
| width: 0; /* No width, just a vertical line */ | |
| border-left: 1px dashed black; /* Dashed vertical line */ | |
| transform: translateX(-50%); /* Center the line exactly in the middle */ | |
| } | |
| .dashed-cell-horizontal::after { | |
| content: ""; | |
| position: absolute; | |
| left: 0; | |
| right: 0; | |
| top: 50%; /* Center the dashed horizontal line vertically */ | |
| height: 0; /* No height, just a horizontal line */ | |
| border-top: 1px dashed black; /* Dashed horizontal line */ | |
| transform: translateY(-50%); /* Center the line exactly in the middle */ | |
| } | |
| .arrowtip { | |
| width: 0; | |
| height: 0; | |
| border-left: 4px solid transparent; | |
| border-right: 4px solid transparent; | |
| border-bottom: 8px solid black; /* The triangle color */ | |
| bottom: 8px; /* The triangle color */ | |
| position: relative; | |
| } | |
| .span-cell::after { | |
| content: ''; | |
| position: absolute; | |
| top: 50%; | |
| left: -1px; | |
| width: 1px; | |
| height: calc(100% * 6.5); /* Adjust the height as needed to reach the yellow circle */ | |
| background-color: red; | |
| } | |
| </style>""" | |
def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer, new_tags):
    # spanwise annotated text
    annotated = []
    span_ends = -1
    in_span = False
    out_of_span_tokens = []
    for i in reversed(range(len(tokenwise_preds))):
        if in_span:
            if i >= span_ends:
                continue
            else:
                in_span = False
        predicted_class = ""
        style = ""
        span = None
        for s in spans:
            if s[1] == i+1:
                span = s
        if tokenwise_preds[i] != 0 and span is not None:
            predicted_class = f"highlight spanhighlight"
            style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}"
            if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
                annotated.append("Ġ")
            span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
            span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
            annotated.extend(out_of_span_tokens)
            out_of_span_tokens = []
            span_ends = span[0]
            in_span = True
            annotated.append(span_end)
            annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
            annotated.append(span_opener)
        else:
            out_of_span_tokens.append(token_strings[i])
    annotated.extend(out_of_span_tokens)
    return [x for x in reversed(annotated)]
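
# Structured variant of generate_text(): streams generation and yields dicts with the plain text and the extracted entities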
def gen_json(input_text, max_new_tokens):
    # Select the next model instance in a round-robin manner
    model, tok = next(model_round_robin)
    streamer = STOKEStreamer(tok, classifier_token, classifier_span)
    new_tags = label_map
    inputs = tok([f" {input_text}"], return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        repetition_penalty=1.2, do_sample=False
    )

    def generate_async():
        model.generate(**generation_kwargs)

    thread = Thread(target=generate_async)
    thread.start()

    # Display generated text as it becomes available
    output_text = ""
    text_tokenwise = ""
    text_spans = ""
    removed_spans = ""
    tags = []
    spans = []
    for new_text in streamer:
        if new_text[1] is not None and new_text[2] != ['']:
            text_tokenwise = ""
            output_text = ""
            tags.extend(new_text[1])
            spans.extend(new_text[-1])
            # Tokenwise Classification
            for tk, pred in zip(new_text[2],tags):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    if tk.startswith(" "):
                        text_tokenwise += " "
                    text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
                    output_text += tk
                else:
                    text_tokenwise += tk
                    output_text += tk
            # Span Classification
            text_spans = ""
            if len(spans) > 0:
                filtered_spans = remove_overlapping_spans(spans)
                text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
                if len(spans) - len(filtered_spans) > 0:
                    removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
            else:
                for tk in new_text[2]:
                    text_spans += f"{tk}"
            # Spanwise Classification
            annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
            generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
            output = f"{css}<br>"
            output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
            #output += "<h5>Show tokenwise classification</h5>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
            #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
            #if removed_spans != "":
            #    output += f"<br><br><i>({removed_spans})</i>"
            list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
            out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}
            yield out_dict
    return

# Gradio app function to generate text using the assigned model instance
def generate_text(input_text, max_new_tokens=2):
    if input_text == "":
        yield "Please enter some text first."
        return
    # Select the next model instance in a round-robin manner
    model, tok = next(model_round_robin)
    streamer = STOKEStreamer(tok, classifier_token, classifier_span)
    new_tags = label_map
    inputs = tok([f"{input_text[:200]}"], return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        repetition_penalty=1.2, do_sample=False, temperature=None, top_p=None
    )

    def generate_async():
        model.generate(**generation_kwargs)

    thread = Thread(target=generate_async)
    thread.start()

    # Display generated text as it becomes available
    output_text = ""
    text_tokenwise = ""
    text_spans = ""
    removed_spans = ""
    tags = []
    spans = []
    for new_text in streamer:
        if new_text[1] is not None and new_text[2] != ['']:
            text_tokenwise = ""
            output_text = ""
            tags.extend(new_text[1])
            spans.extend(new_text[-1])
            # Tokenwise Classification
            for tk, pred in zip(new_text[2],tags):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    if tk.startswith(" "):
                        text_tokenwise += " "
                    text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
                    output_text += tk
                else:
                    text_tokenwise += tk
                    output_text += tk
            # Span Classification
            text_spans = ""
            if len(spans) > 0:
                filtered_spans = remove_overlapping_spans(spans)
                text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
                if len(spans) - len(filtered_spans) > 0:
                    removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
            else:
                for tk in new_text[2]:
                    text_spans += f"{tk}"
            # Spanwise Classification
            annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
            generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
            output = f"{css}<div class=\"generated-content\"><br>"
            output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
            list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
            out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}
            # Build the step-by-step visualization table, one row per processing stage
            output_tokenwise = f"""{css}<div class=\"generated-content\">
<table>"""
            # Row: detected spans with their propagated entity labels
            output_tokenwise += """<tr><th style="width: 10em!important; background-color: rgba(210, 210, 210, 0.24);">Span detection + label propagation</th>"""
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                span = ""
                if i in [x[0][1]-2 for x in spans] and pred != 0:
                    top_span = [x for x in spans if x[0][1]-2 == i][0]
                    spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]])
                    color = map_value_to_color((pred-1)/(len(new_tags)-1)) + "88"
                    span = f"<span class='highlight spanhighlight spantext' style='background-color: {color}; position: absolute; transform: translateX(-50%); white-space: nowrap; top: 1.4em;'>{spanstring}<span class='small-text'>{new_tags[pred]}</span></span>"
                    output_tokenwise += f"<td class='span-cell-2' style='position:relative; background-color: rgba(210, 210, 210, 0.24);'>{span}</td>"
                else:
                    output_tokenwise += f"<td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td>"
            output_tokenwise += "</tr><tr><td></td>"
            # Row: raw span detections
            output_tokenwise += """<tr><td>Span detection</td>"""
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[:])):
                span = ""
                if i in [x[0][1]-1 for x in spans]:
                    top_span = [x for x in spans if x[0][1]-1 == i][0]
                    spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]])
                    span = f"<span class='highlight spanhighlight spantext' style='border: 1px solid red; background-color: lightgrey; position: absolute; left: 0; transform: translateX(-100%); white-space: nowrap; top: 0.6em;'>{spanstring}</span>"
                    output_tokenwise += f"<td class='span-cell' style='position:relative;'>{span}</td>"
                else:
                    output_tokenwise += f"<td style='position:relative;'></td>"
            output_tokenwise += "</tr><tr><td></td>"
            # Row: tokenwise entity-type predictions
            output_tokenwise += """<tr><td style='width: 10em; background-color: rgba(210, 210, 210, 0.24);'>Tokenwise<br>entity typing</td>"""
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                style = "background-color: lightgrey;"
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))};"
                    output_tokenwise += f"<td style='background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight' style='{style} font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>{new_tags[pred]}</span></td>"
                else:
                    output_tokenwise += "<td style='background-color: rgba(210, 210, 210, 0.24);'></td>"
            #output_tokenwise += f"<th><span class='arrowtip'></span></th>"
            output_tokenwise += "<td></td></tr><tr style='line-height: 0px!important;'><td></td>"
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td><span class='arrowtip'></span></td>"
            output_tokenwise += "</tr><tr><td></td>"
            # Rows: connector graphics; yellow circles mark positions where a detected span ends
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    output_tokenwise += f"<td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='{new_tags[pred]}' style='{style}'></div></td>"
                else:
                    output_tokenwise += f"<td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr style='height: 36px;'><td></td>"
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td class='dashed-cell'></td>"
            output_tokenwise += "</tr><tr><td></td>"
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"
            # Row: the generated tokens themselves
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td><span class='highlight spanhighlight' style='background-color: lightgrey;'>{tk}</span></td>"
            output_tokenwise += "</tr>"
            #yield output + "</div>"
            yield output_tokenwise + "</table></div>"
            #time.sleep(0.5)
    return

# Load datasets and models for the Gradio app
datasets = find_datasets_and_model_ids("data/")
available_models = list(datasets.keys())
available_datasets = {model: list(datasets[model].keys()) for model in available_models}
available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models}

def update_datasets(model_name):
    return available_datasets[model_name]

def update_configs(model_name, dataset_name):
    return available_configs[model_name][dataset_name]

# Set the model ID and data configurations
model_id = "meta-llama/Llama-3.2-1B"
data_id = "STOKE_100"
config_id = "default"

# Load n_instances separate instances of the model and tokenizer
model_instances = get_multiple_model_and_tokenizer(model_id, n_instances)

# Set up the round-robin iterator to distribute the requests across model instances
model_round_robin = itertools.cycle(model_instances)

# Load model classifiers
try:
    # GPT-2-style config attribute names
    classifier_span, classifier_token, label_map = get_classifiers_for_model(
        model_instances[0][0].config.n_head * model_instances[0][0].config.n_layer, model_instances[0][0].config.n_embd, model_instances[0][0].device,
        datasets[model_id][data_id][config_id]
    )
except AttributeError:
    # Llama-style config attribute names
    classifier_span, classifier_token, label_map = get_classifiers_for_model(
        model_instances[0][0].config.num_attention_heads * model_instances[0][0].config.num_hidden_layers, model_instances[0][0].config.hidden_size, model_instances[0][0].device,
        datasets[model_id][data_id][config_id]
    )
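
# Pre-rendered example visualization (and matching entity dict) shown before the first request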
initial_output = (css+"""<div class="generated-content">
| <table><tr><th style="width: 10em!important; background-color: rgba(210, 210, 210, 0.24);">Span detection + label propagation</th><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td class='span-cell-2' style='position:relative; background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight spantext' style='background-color: #9ecae188; position: absolute; transform: translateX(-50%); white-space: nowrap; top: 1.4em;'>The New York Film Festival<span class='small-text'>EVENT</span></span></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td><td style='position:relative; background-color: rgba(210, 210, 210, 0.24);'></td></tr><tr><td></td><tr><td>Span detection</td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td class='span-cell' style='position:relative;'><span class='highlight spanhighlight spantext' style='border: 1px solid red; background-color: lightgrey; position: absolute; left: 0; transform: translateX(-100%); white-space: nowrap; top: 0.6em;'>The New York Film Festival</span></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td></tr><tr><td></td><tr><td style='width: 10em; background-color: rgba(210, 210, 210, 0.24);'>Tokenwise<br>entity typing</td><td style='background-color: rgba(210, 210, 210, 0.24);'></td><td style='background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight' style='background-color: #e6550d; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>GPE</span></td><td style='background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight' style='background-color: #756bb1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>ORG</span></td><td style='background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight' style='background-color: #756bb1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>ORG</span></td><td style='background-color: rgba(210, 210, 210, 0.24);'><span class='highlight spanhighlight' style='background-color: #9ecae1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>EVENT</span></td><td style='background-color: rgba(210, 210, 210, 0.24);'></td><td style='background-color: rgba(210, 210, 210, 0.24);'></td><td style='background-color: rgba(210, 210, 210, 0.24);'></td><td></td></tr><tr style='line-height: 0px!important;'><td></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; 
margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr><td></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='GPE' style='background-color: #e6550d'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='ORG' style='background-color: #756bb1'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='ORG' style='background-color: #756bb1'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='EVENT' style='background-color: #9ecae1'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: 
lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr style='height: 36px;'><td></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: 
-19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr><td></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'>The</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> New</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> York</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> Film</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> Festival</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> is</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> an</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> annual</span></td></tr></table></div>""", {'text': 'Miami is a city in the U.S. state of Florida, and it\'s also known as "The Magic City." It was founded by Henry Flagler on October 28th, 1896.', 'entites': [{'name': 'Miami', 'type': 'GPE'}, {'name': 'U.S.', 'type': 'GPE'}, {'name': 'Florida', 'type': 'GPE'}, {'name': 'The Magic City', 'type': 'WORK_OF_ART'}, {'name': 'Henry Flagler', 'type': 'PERSON'}, {'name': 'October 28th, 1896', 'type': 'DATE'}]})
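
# Gradio UI: streamed HTML visualization, input textbox, and model/GPU info line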
with gr.Blocks(css="footer{display:none !important} .gradio-container {padding: 0!important; height:400px;}", fill_width=True, fill_height=True) as demo:
    with gr.Tab("EMBER Demo"):
        with gr.Row():
            output_text = gr.HTML(label="Generated Text", value=initial_output[0])
        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(label="Try with your own text!", value="The New York Film Festival is an", max_length=40, submit_btn=True)
            # New HTML output for model info
            model_info_html = gr.HTML(
                label="Model Info",
                value=f'<div style="font-weight: lighter; text-align: center; font-size: x-small;">{model_id} running on {gpu_name}</div>'
            )
        input_text.submit(
            fn=generate_text,
            inputs=[input_text],
            outputs=[output_text],
            concurrency_limit=n_instances,
            concurrency_id="queue"
        )

        # Function to refresh the model info HTML
        def refresh_model_info():
            return f'<div style="overflow: visible; font-weight: lighter; text-align: center; font-size: x-small;">{model_id} running on {gpu_name}</div>'

        # Update the model info HTML whenever new text is submitted
        input_text.submit(
            fn=refresh_model_info,
            inputs=[],
            outputs=[model_info_html],
            queue=False
        )

demo.queue()
demo.launch()