# NOTE: scrape artifact — the lines "Spaces:" / "Runtime error" were a Hugging Face
# Spaces status banner captured alongside this source, not part of the module.
| """Utility function for gradio/external.py""" | |
| import base64 | |
| import json | |
| import math | |
| import operator | |
| import re | |
| import warnings | |
| from typing import Any, Dict, List, Tuple | |
| import requests | |
| import websockets | |
| import yaml | |
| from packaging import version | |
| from websockets.legacy.protocol import WebSocketCommonProtocol | |
| from gradio import components, exceptions | |
| ################## | |
| # Helper functions for processing tabular data | |
| ################## | |
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
    """Fetch example tabular data from a model card on the Hugging Face Hub.

    Downloads the model's README.md, parses the leading ``--- ... ---`` YAML
    front matter, and extracts the ``widget.structuredData`` mapping of
    column name -> list of example values.

    Parameters:
        model_name: Repo id of the model, e.g. "scikit-learn/tabular-playground".
    Returns:
        Dict mapping column headers to lists of example values. Float NaNs are
        replaced with the string "NaN" so the payload survives JSON
        serialization for the inference API.
    Raises:
        ValueError: if no example data can be found in the README.
    """
    readme = requests.get(
        f"https://huggingface.co/{model_name}/resolve/main/README.md",
        timeout=10,  # avoid hanging forever on an unresponsive Hub
    )
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
        example_data = {}
    else:
        # Model-card metadata lives in a `--- ... ---` YAML block at the top.
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is None:
            example_data = {}
        else:
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
            )
            # An empty front-matter block makes safe_load_all yield None.
            widget = (example_yaml or {}).get("widget", {})
            # A malformed card may store `widget` as a list; treat as "no examples".
            example_data = (
                widget.get("structuredData", {}) if isinstance(widget, dict) else {}
            )
    if not isinstance(example_data, dict):
        example_data = {}
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference API
    for data in example_data.values():
        # Columns may be None (see cols_to_rows' `or []` guard); skip safely.
        for i, val in enumerate(data or []):
            if isinstance(val, float) and math.isnan(val):
                data[i] = "NaN"
    return example_data
def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    """Convert column-wise example data into headers plus row-wise data.

    Parameters:
        example_data: Mapping of column header -> list of values (a value list
            may be None, which is treated as empty).
    Returns:
        A tuple ``(headers, rows)`` where rows is a list of rows, each row a
        list with one entry per header. Short columns are padded with the
        string "NaN" so every row has the same length.
    """
    headers = list(example_data.keys())
    # default=0 makes an empty mapping yield no rows instead of raising
    # ValueError from max() on an empty sequence.
    n_rows = max((len(example_data[header] or []) for header in headers), default=0)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                # Pad ragged columns so the table stays rectangular.
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    """Transpose a row-wise table into the column-wise inference-API payload.

    Parameters:
        incoming_data: Dict with a "headers" list and a "data" list of rows.
    Returns:
        ``{"inputs": {"data": {header: [stringified column values]}}}``.
    """
    column_wise = {
        header: [str(row[col_idx]) for row in incoming_data["data"]]
        for col_idx, header in enumerate(incoming_data["headers"])
    }
    return {"inputs": {"data": column_wise}}
| ################## | |
| # Helper functions for processing other kinds of data | |
| ################## | |
def postprocess_label(scores: Dict) -> Dict:
    """Convert a {label: score} mapping into gradio Label component format.

    Parameters:
        scores: Mapping of class label -> confidence score.
    Returns:
        Dict with the top-scoring "label" and a "confidences" list of
        {"label", "confidence"} dicts sorted by descending confidence.
    """
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    top_label, _ = ranked[0]
    confidences = [{"label": lbl, "confidence": score} for lbl, score in ranked]
    return {"label": top_label, "confidences": confidences}
def encode_to_base64(r: requests.Response) -> str:
    """Encode an inference-API response body as a base64 data-URI string.

    Handles the different ways the HF API returns a prediction:
    1. The body is already a full data URI -> returned unchanged.
    2. A JSON body carries "content-type" and "blob" keys in its first element.
    3. Otherwise the content type comes from the response headers.

    Parameters:
        r: The HTTP response from the inference API.
    Returns:
        A ``data:<content-type>;base64,<payload>`` string.
    Raises:
        ValueError: if a JSON body does not expose the expected keys.
    """
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    data_prefix = ";base64,"
    # Case 1: base64 representation already includes data prefix
    if data_prefix in base64_repr:
        return base64_repr
    content_type = r.headers.get("content-type")
    # Case 2: the data prefix is a key in the response
    if content_type == "application/json":
        try:
            payload = r.json()[0]  # parse the body once, not twice
            content_type = payload["content-type"]
            base64_repr = payload["blob"]
        # Also catch a non-list body (TypeError) or an empty list (IndexError),
        # not just a missing key.
        except (KeyError, IndexError, TypeError):
            raise ValueError(
                # Fixed: implicit concatenation previously produced "returnedby".
                "Cannot determine content type returned by external API."
            )
    # Case 3: the data prefix is included in the response headers
    return "data:{};base64,".format(content_type) + base64_repr
| ################## | |
| # Helper functions for connecting to websockets | |
| ################## | |
async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
    """Drive the gradio queue websocket protocol until a prediction completes.

    Parameters:
        websocket: An open websocket connection to a gradio queue endpoint.
        data: JSON-encoded prediction input, sent on "send_data".
        hash_data: JSON-encoded session hash, sent on "send_hash".
    Returns:
        The "output" field of the final "process_completed" message.
    Raises:
        exceptions.Error: if the server reports the queue is full.
    """
    resp: Dict[str, Any] = {}
    # Keep exchanging messages until the server signals completion.
    while resp.get("msg") != "process_completed":
        resp = json.loads(await websocket.recv())
        if resp["msg"] == "queue_full":
            raise exceptions.Error("Queue is full! Please try again.")
        if resp["msg"] == "send_hash":
            await websocket.send(hash_data)
        elif resp["msg"] == "send_data":
            await websocket.send(data)
    return resp["output"]
def get_ws_fn(ws_url, headers):
    """Build an async prediction function bound to a websocket endpoint.

    Parameters:
        ws_url: URL of the gradio queue websocket endpoint.
        headers: Extra HTTP headers to send with the connection handshake.
    Returns:
        A coroutine function ``(data, hash_data) -> prediction`` that opens a
        fresh connection per call and runs the queue protocol.
    """

    async def ws_fn(data, hash_data):
        connection = websockets.connect(  # type: ignore
            ws_url, open_timeout=10, extra_headers=headers
        )
        async with connection as websocket:
            return await get_pred_from_ws(websocket, data, hash_data)

    return ws_fn
def use_websocket(config, dependency):
    """Decide whether a dependency's predictions should use the websocket queue.

    True only when the app has its queue enabled, runs a gradio version whose
    queue speaks websockets (>= 3.2), and the dependency itself is queued.

    Parameters:
        config: The loaded app config dict.
        dependency: A single dependency entry from the config.
    Returns:
        Truthy iff all three conditions hold.
    """
    queue_on = config.get("enable_queue", False)
    # Queueing moved to websockets in gradio 3.2; older apps poll over HTTP.
    ws_capable_version = version.parse(config.get("version", "2.0")) >= version.Version(
        "3.2"
    )
    dependency_queued = dependency.get("queue", False) is not False
    return queue_on and ws_capable_version and dependency_queued
| ################## | |
| # Helper function for cleaning up an Interface loaded from HF Spaces | |
| ################## | |
def streamline_spaces_interface(config: Dict) -> Dict:
    """Streamlines the interface config dictionary to remove unnecessary keys.

    Instantiates the input/output component specs into component objects and
    keeps only the keys the Interface constructor understands.

    Parameters:
        config: Raw config dict loaded from a HF Space.
    Returns:
        A new dict containing only the whitelisted Interface parameters.
    """
    # Turn serialized component specs into live component instances.
    for direction in ("input", "output"):
        config[f"{direction}s"] = [
            components.get_component_instance(component)
            for component in config[f"{direction}_components"]
        ]
    kept_parameters = (
        "article",
        "description",
        "flagging_options",
        "inputs",
        "outputs",
        "theme",
        "title",
    )
    return {key: config[key] for key in kept_parameters}