Spaces:
Runtime error
Runtime error
| """ Handy utility functions.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import copy | |
| import inspect | |
| import json | |
| import json.decoder | |
| import os | |
| import pkgutil | |
| import random | |
| import re | |
| import sys | |
| import time | |
| import typing | |
| import warnings | |
| from contextlib import contextmanager | |
| from distutils.version import StrictVersion | |
| from enum import Enum | |
| from io import BytesIO | |
| from numbers import Number | |
| from pathlib import Path | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Dict, | |
| Generator, | |
| List, | |
| NewType, | |
| Tuple, | |
| Type, | |
| TypeVar, | |
| Union, | |
| ) | |
| import aiohttp | |
| import fsspec.asyn | |
| import httpx | |
| import matplotlib.pyplot as plt | |
| import requests | |
| from pydantic import BaseModel, Json, parse_obj_as | |
| import gradio | |
| from gradio.context import Context | |
| from gradio.strings import en | |
| if TYPE_CHECKING: # Only import for type checking (is False at runtime). | |
| from gradio.blocks import BlockContext | |
| from gradio.components import Component | |
| analytics_url = "https://api.gradio.app/" | |
| PKG_VERSION_URL = "https://api.gradio.app/pkg-version" | |
| JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json") | |
| T = TypeVar("T") | |
def version_check():
    """
    Print an upgrade notice when a newer gradio version is published.

    The installed version is read from the packaged "version.txt" and compared
    with the "version" field returned by PKG_VERSION_URL. The check is
    best-effort: a malformed response only produces a warning and any other
    failure (missing file, no network, unparsable version) is ignored.
    """
    try:
        version_data = pkgutil.get_data(__name__, "version.txt")
        if not version_data:
            raise FileNotFoundError
        current_pkg_version = version_data.decode("ascii").strip()
        latest_pkg_version = requests.get(url=PKG_VERSION_URL, timeout=3).json()[
            "version"
        ]
        if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
            print(
                f"IMPORTANT: You are using gradio version {current_pkg_version}, "
                f"however version {latest_pkg_version} "
                "is available, please upgrade."
            )
            print("--------")
    except json.decoder.JSONDecodeError:
        warnings.warn("unable to parse version details from package URL.")
    except KeyError:
        warnings.warn("package URL does not contain version info.")
    except Exception:
        # Best-effort check: never let it break the caller. `Exception` (rather
        # than a bare except) keeps KeyboardInterrupt/SystemExit propagating.
        pass
def get_local_ip_address() -> str:
    """
    Return the address reported by https://checkip.amazonaws.com/ (stripped of
    whitespace), or the string "No internet connection" if the request fails
    with a connection error or read timeout.
    """
    network_errors = (requests.ConnectionError, requests.exceptions.ReadTimeout)
    try:
        response = requests.get("https://checkip.amazonaws.com/", timeout=3)
        return response.text.strip()
    except network_errors:
        return "No internet connection"
def initiated_analytics(data: Dict[str, Any]) -> None:
    """
    POST `data` to the "gradio-initiated-analytics/" endpoint.
    Connection errors and read timeouts are ignored.
    """
    endpoint = analytics_url + "gradio-initiated-analytics/"
    try:
        requests.post(endpoint, data=data, timeout=3)
    except (requests.ConnectionError, requests.exceptions.ReadTimeout):
        return  # no network: drop the event
def launch_analytics(data: Dict[str, Any]) -> None:
    """
    POST `data` to the "gradio-launched-analytics/" endpoint.
    Connection errors and read timeouts are ignored.
    """
    endpoint = analytics_url + "gradio-launched-analytics/"
    try:
        requests.post(endpoint, data=data, timeout=3)
    except (requests.ConnectionError, requests.exceptions.ReadTimeout):
        return  # no network: drop the event
def integration_analytics(data: Dict[str, Any]) -> None:
    """
    POST `data` to the "gradio-integration-analytics/" endpoint.
    Connection errors and read timeouts are ignored.
    """
    endpoint = analytics_url + "gradio-integration-analytics/"
    try:
        requests.post(endpoint, data=data, timeout=3)
    except (requests.ConnectionError, requests.exceptions.ReadTimeout):
        return  # no network: drop the event
def error_analytics(ip_address: str, message: str) -> None:
    """
    Report an error to the "gradio-error-analytics/" endpoint if the network is reachable.
    :param ip_address: IP address where the error occurred
    :param message: Details about the error
    """
    endpoint = analytics_url + "gradio-error-analytics/"
    payload = {"ip_address": ip_address, "error": message}
    try:
        requests.post(endpoint, data=payload, timeout=3)
    except (requests.ConnectionError, requests.exceptions.ReadTimeout):
        return  # no network: drop the event
async def log_feature_analytics(ip_address: str, feature: str) -> None:
    """
    Asynchronously POST a feature-usage event to the "gradio-feature-analytics/"
    endpoint. aiohttp client errors raised by the request are ignored.
    """
    endpoint = analytics_url + "gradio-feature-analytics/"
    payload = {"ip_address": ip_address, "feature": feature}
    async with aiohttp.ClientSession() as session:
        try:
            # The response body is not needed; entering the context sends the request.
            async with session.post(endpoint, data=payload):
                pass
        except aiohttp.ClientError:
            pass  # no network: drop the event
def colab_check() -> bool:
    """
    Check if the interface is launching from Google Colab.
    :return (bool): True if the string form of the active IPython shell
        contains "google.colab"; False otherwise, or if IPython is unavailable
    """
    try:
        from IPython import get_ipython

        return "google.colab" in str(get_ipython())
    except (ImportError, NameError):
        return False
def ipython_check() -> bool:
    """
    Check if the interface is launching from an IPython session.
    :return (bool): True if IPython is importable and get_ipython() returns
        an active shell; False otherwise
    """
    try:
        from IPython import get_ipython

        return get_ipython() is not None
    except (ImportError, NameError):
        return False
def readme_to_html(article: str) -> str:
    """
    Treat `article` as a URL and return the body of the response if it can be
    fetched with an OK status; otherwise return `article` unchanged.
    """
    try:
        response = requests.get(article, timeout=3)
        fetched = response.status_code == requests.codes.ok  # pylint: disable=no-member
        return response.text if fetched else article
    except requests.exceptions.RequestException:
        # Not a reachable URL: the argument is assumed to already be the content.
        return article
def show_tip(interface: gradio.Blocks) -> None:
    """Print a randomly chosen tip from en["TIPS"] when `interface.show_tips` is truthy."""
    if not interface.show_tips:
        return
    # random.random() returns a value in [0, 1), so the 1.5 threshold is always met.
    if random.random() < 1.5:
        chosen_tip: str = random.choice(en["TIPS"])
        print(f"Tip: {chosen_tip}")
def launch_counter() -> None:
    """
    Increment the launch count stored in the JSON file at JSON_PATH.

    The file is created with a count of 1 if it does not exist. When the new
    count reaches one of a fixed set of milestones, en["BETA_INVITE"] is
    printed. The counter is best-effort: any failure (unwritable file, corrupt
    JSON, missing key) is ignored.
    """
    try:
        if not os.path.exists(JSON_PATH):
            launches = {"launches": 1}
            with open(JSON_PATH, "w+") as j:
                json.dump(launches, j)
        else:
            with open(JSON_PATH) as j:
                launches = json.load(j)
            launches["launches"] += 1
            if launches["launches"] in [25, 50, 150, 500, 1000]:
                print(en["BETA_INVITE"])
            with open(JSON_PATH, "w") as j:
                j.write(json.dumps(launches))
    except Exception:
        # Deliberately silent, but `Exception` (rather than a bare except)
        # keeps KeyboardInterrupt/SystemExit propagating.
        pass
def get_default_args(func: Callable) -> List[Any]:
    """
    Return one entry per parameter of `func`, in signature order: the
    parameter's default value, or None when it has no default.
    """
    defaults: List[Any] = []
    for param in inspect.signature(func).parameters.values():
        if param.default is inspect.Parameter.empty:
            defaults.append(None)
        else:
            defaults.append(param.default)
    return defaults
def assert_configs_are_equivalent_besides_ids(
    config1: Dict, config2: Dict, root_keys: Tuple = ("mode", "theme")
):
    """Allows you to test if two different Blocks configs produce the same demo.

    Components are compared pairwise with their "id" removed, following the
    layout tree and the targets/inputs/outputs of each dependency. The inputs
    are deep-copied first, so neither config is modified.

    Parameters:
    config1 (dict): nested dict with config from the first Blocks instance
    config2 (dict): nested dict with config from the second Blocks instance
    root_keys (Tuple): an iterable consisting of which keys to test for equivalence at
        the root level of the config. By default, only "mode" and "theme" are tested,
        so keys like "version" are ignored.
    Returns:
    True if the configs are equivalent.
    Raises:
    AssertionError: if any compared part of the configs differs.
    """
    config1 = copy.deepcopy(config1)
    config2 = copy.deepcopy(config2)

    for key in root_keys:
        assert config1[key] == config2[key], f"Configs have different: {key}"

    assert len(config1["components"]) == len(
        config2["components"]
    ), "# of components are different"

    def find_component(config, component_id):
        for component in config["components"]:
            if component["id"] == component_id:
                return component
        raise AssertionError(f"No component with id {component_id}")

    def assert_same_components(config1_id, config2_id):
        # Shallow copies are enough: only the top-level "id" key is removed.
        c1 = dict(find_component(config1, config1_id))
        c2 = dict(find_component(config2, config2_id))
        c1.pop("id")
        c2.pop("id")
        assert c1 == c2, f"{c1} does not match {c2}"

    def same_children_recursive(children1, children2):
        # zip() stops at the shorter list, so the lengths must be checked explicitly.
        assert len(children1) == len(
            children2
        ), "# of layout children are different"
        for child1, child2 in zip(children1, children2):
            assert_same_components(child1["id"], child2["id"])
            has_children1 = "children" in child1
            has_children2 = "children" in child2
            assert (
                has_children1 == has_children2
            ), f"{child1} and {child2} do not both have children"
            if has_children1:
                same_children_recursive(child1["children"], child2["children"])

    same_children_recursive(
        config1["layout"]["children"], config2["layout"]["children"]
    )

    assert len(config1["dependencies"]) == len(
        config2["dependencies"]
    ), "# of dependencies are different"
    for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
        # The id lists are popped so that the final comparison of the two
        # dependency dicts ignores component ids.
        for key in ("targets", "inputs", "outputs"):
            ids1, ids2 = d1.pop(key), d2.pop(key)
            assert len(ids1) == len(ids2), f"# of {key} are different"
            for id1, id2 in zip(ids1, ids2):
                assert_same_components(id1, id2)
        assert d1 == d2, f"{d1} does not match {d2}"

    return True
def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
    """
    Split `input_string` into (text, label) tuples using `ner_groups`.

    Each group supplies "entity_group", "start" and "end". The text between
    consecutive groups, and the text after the last one, is emitted with a
    label of None. With no groups, the whole string is returned unlabelled.
    """
    if len(ner_groups) == 0:
        return [(input_string, None)]

    segments = []
    cursor = 0  # end offset of the previous entity
    for group in ner_groups:
        label = group["entity_group"]
        start = group["start"]
        stop = group["end"]
        segments.extend(
            [(input_string[cursor:start], None), (input_string[start:stop], label)]
        )
        cursor = stop
    segments.append((input_string[cursor:], None))
    return segments
def delete_none(_dict: T, skip_value: bool = False) -> T:
    """
    Recursively remove None entries from dictionaries, lists, tuples and sets.
    Dictionaries are modified in place; other containers are rebuilt.
    When `skip_value` is True, a top-level "value" key is left untouched
    (the flag is not passed down to nested containers).
    Credit: https://stackoverflow.com/a/66127889/5209347
    """
    if isinstance(_dict, dict):
        # Iterate over a snapshot of the keys because entries may be deleted.
        for key in list(_dict):
            value = _dict[key]
            if skip_value and key == "value":
                continue
            if isinstance(value, (list, dict, tuple, set)):
                _dict[key] = delete_none(value)
            elif value is None or key is None:
                del _dict[key]
        return _dict
    if isinstance(_dict, (list, set, tuple)):
        return type(_dict)(delete_none(item) for item in _dict if item is not None)
    return _dict
def resolve_singleton(_list: List[Any] | Any) -> Any:
    """Return the sole element if `_list` has exactly one item, else `_list` itself."""
    return _list[0] if len(_list) == 1 else _list
def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]:
    """
    Looks up a component, template, or layout class by name.

    The classes defined in gradio.components, gradio.templates and
    gradio.layouts are searched in that order. A class matches when its
    lower-cased name equals `cls_name` with underscores removed and it is a
    subclass of Component or BlockContext.
    Parameters:
    cls_name (str): lower-case string class name of a component
    Returns:
    cls: the matching class
    Raises:
    ValueError: if no class matches
    """
    import gradio.blocks
    import gradio.components
    import gradio.layouts
    import gradio.templates

    wanted = cls_name.replace("_", "")
    valid_bases = (gradio.components.Component, gradio.blocks.BlockContext)
    for module in (gradio.components, gradio.templates, gradio.layouts):
        for name, candidate in list(vars(module).items()):
            if not isinstance(candidate, type):
                continue
            if name.lower() == wanted and issubclass(candidate, valid_bases):
                return candidate
    raise ValueError(f"No such component or layout: {cls_name}")
def synchronize_async(func: Callable, *args, **kwargs) -> Any:
    """
    Runs an async function from a synchronous scope and returns its result.
    The call is delegated to fsspec.asyn.sync using the loop from fsspec.asyn.get_loop().
    Example:
    if inspect.iscoroutinefunction(block_fn.fn):
        predictions = utils.synchronize_async(block_fn.fn, *processed_input)
    Args:
    func: the async function to run
    *args: positional arguments forwarded to `func`
    **kwargs: keyword arguments forwarded to `func`
    """
    loop = fsspec.asyn.get_loop()
    return fsspec.asyn.sync(loop, func, *args, **kwargs)
def run_coro_in_background(func: Callable, *args, **kwargs):
    """
    Schedules the coroutine `func(*args, **kwargs)` as a task on the event loop
    returned by asyncio.get_event_loop().
    Warning: only use this in a scope reached by the FastAPI app, where the event loop has already started.
    correct scope examples: endpoints in routes, Blocks.process_api
    incorrect scope examples: Blocks.launch
    Use startup_events in routes.py if you need to run a coro in background in Blocks.launch().
    Example:
    utils.run_coro_in_background(fn, *args, **kwargs)
    Args:
    func: the coroutine function to schedule
    *args: positional arguments forwarded to `func`
    **kwargs: keyword arguments forwarded to `func`
    Returns:
    the created task
    """
    loop = asyncio.get_event_loop()
    coroutine = func(*args, **kwargs)
    return loop.create_task(coroutine)
def async_iteration(iterator):
    """Return the next item of `iterator`, raising StopAsyncIteration when it is exhausted."""
    try:
        item = next(iterator)
    except StopIteration:
        # StopIteration is translated to StopAsyncIteration because
        # coroutines can't raise StopIteration themselves.
        raise StopAsyncIteration()
    else:
        return item
class AsyncRequest:
    """
    The AsyncRequest class is a low-level API that allows you to create asynchronous HTTP requests without a context manager.
    Compared to making calls by using httpx directly, AsyncRequest offers more flexibility and control over:
    (1) Includes response validation functionality both using validation models and functions.
    (2) Since we're still using the httpx.Request class by wrapping it, we have all its functionalities.
    (3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
    individually in the case of multiple asynchronous request calls and some of them failing.
    (4) Provides HTTP request types with the AsyncRequest.Method Enum class for ease of usage
    AsyncRequest also offers some util methods such as has_exception(), is_valid() and status() to get detailed
    information about the executed request call.
    The basic usage of AsyncRequest is as follows: create an AsyncRequest object with inputs (method, url etc.). Then use it
    with the "await" statement, and then you can use the util methods to do some post request checks depending on your use-case.
    Finally, call the get_validated_data function to get the response data.
    You can see example usages in test_utils.py.
    """

    ResponseJson = NewType("ResponseJson", Json)
    # A single client shared by every AsyncRequest instance.
    client = httpx.AsyncClient()

    class Method(str, Enum):
        """
        Method is an enumeration class that contains possible types of HTTP request methods.
        """

        ANY = "*"
        CONNECT = "CONNECT"
        HEAD = "HEAD"
        GET = "GET"
        DELETE = "DELETE"
        OPTIONS = "OPTIONS"
        PATCH = "PATCH"
        POST = "POST"
        PUT = "PUT"
        TRACE = "TRACE"

    def __init__(
        self,
        method: Method,
        url: str,
        *,
        validation_model: Type[BaseModel] | None = None,
        validation_function: Union[Callable, None] = None,
        exception_type: Type[Exception] = Exception,
        raise_for_status: bool = False,
        **kwargs,
    ):
        """
        Initialize the Request instance.
        Args:
        method(Request.Method) : method of the request
        url(str): url of the request
        *
        validation_model(Type[BaseModel]): a pydantic validation class type to use in validation of the response
        validation_function(Callable): a callable instance to use in validation of the response
        exception_type(Type[Exception]): the exception type used to wrap any exception raised during the request call
        raise_for_status(bool): a flag that determines whether to raise httpx.Response.raise_for_status() exceptions.
        **kwargs: extra keyword arguments forwarded to httpx.Request
        """
        self._exception: Union[Exception, None] = None
        self._status = None
        self._raise_for_status = raise_for_status
        self._validation_model = validation_model
        self._validation_function = validation_function
        self._exception_type = exception_type
        self._validated_data = None
        # Initialised here so json() returns None instead of raising
        # AttributeError when the request fails before the body is parsed.
        self._response: Union[httpx.Response, None] = None
        self._json_response_data = None
        # Create request
        self._request = self._create_request(method, url, **kwargs)

    def __await__(self) -> Generator[None, Any, "AsyncRequest"]:
        """
        Wrap Request's __await__ magic function to create request calls which are executed in one line.
        """
        return self.__run().__await__()

    async def __run(self) -> AsyncRequest:
        """
        Manage the request call lifecycle.
        Execute the request by sending it through the client, then check its status.
        Then parse the response into Json format. And then validate it using the provided validation methods.
        If a problem occurs in this sequential process, the exception is wrapped in the configured
        exception type and stored on the instance so that it can be examined afterwards.
        Returns:
        AsyncRequest: this instance
        """
        try:
            # Send the request and get the response.
            self._response = await AsyncRequest.client.send(self._request)
            # Record the status; optionally raise for error statuses.
            self._status = self._response.status_code
            if self._raise_for_status:
                self._response.raise_for_status()
            # Parse client response data to JSON
            self._json_response_data = self._response.json()
            # Validate response data
            self._validated_data = self._validate_response_data(
                self._json_response_data
            )
        except Exception as exception:
            # If there is an exception, store it to do further inspections.
            self._exception = self._exception_type(exception)
        return self

    @staticmethod
    def _create_request(method: Method, url: str, **kwargs) -> httpx.Request:
        """
        Create a request. This is a httpx request wrapper function.
        It takes no `self`, so it is a staticmethod in order to be callable
        through an instance as `self._create_request(method, url, ...)`.
        Args:
        method(Request.Method): request method type
        url(str): target url of the request
        **kwargs
        Returns:
        Request
        """
        return httpx.Request(method, url, **kwargs)

    def _validate_response_data(
        self, response: ResponseJson
    ) -> Union[BaseModel, ResponseJson | None]:
        """
        Validate response using given validation methods. If there is a validation method and response is not valid,
        validation functions will raise an exception for them.
        Args:
        response(ResponseJson): response object
        Returns:
        ResponseJson: Validated Json object.
        """
        # We use raw response as a default value if there is no validation method or response is not valid.
        validated_response = response
        try:
            # If a validation model is provided, validate response using the validation model.
            if self._validation_model:
                validated_response = self._validate_response_by_model(response)
            # Then, if a validation function is provided, validate response using the validation function.
            # Its result replaces the result of the model validation.
            if self._validation_function:
                validated_response = self._validate_response_by_validation_function(
                    response
                )
        except Exception as exception:
            # If one of the validation methods does not confirm, the raised exception is silently handled.
            # We assign this exception to the instance to do further inspections via the is_valid function.
            self._exception = exception
        return validated_response

    def _validate_response_by_model(self, response: ResponseJson) -> BaseModel:
        """
        Validate response json using the validation model.
        Args:
        response(ResponseJson): response object
        Returns:
        BaseModel: the parsed model, or an empty BaseModel if no validation model is set.
        """
        validated_data = BaseModel()
        if self._validation_model:
            validated_data = parse_obj_as(self._validation_model, response)
        return validated_data

    def _validate_response_by_validation_function(
        self, response: ResponseJson
    ) -> ResponseJson | None:
        """
        Validate response json using the validation function.
        Args:
        response(ResponseJson): response object
        Returns:
        ResponseJson: the value returned by the validation function, or None if no function is set.
        """
        validated_data = None
        if self._validation_function:
            validated_data = self._validation_function(response)
        return validated_data

    def is_valid(self, raise_exceptions: bool = False) -> bool:
        """
        Check response object's validity. Raise exceptions if raise_exceptions flag is True.
        Args:
        raise_exceptions(bool) : a flag to raise exceptions in this check
        Returns:
        bool: validity of the data
        """
        if self._exception is not None:
            if raise_exceptions:
                raise self._exception
            return False
        # If there is no exception, that means there is no validation error.
        return True

    def get_validated_data(self):
        """Return the validated response data (None until a request has succeeded)."""
        return self._validated_data

    def json(self):
        """Return the parsed JSON body of the response (None if it was never parsed)."""
        return self._json_response_data

    def exception(self):
        """Return the exception stored during the request call, or None."""
        return self._exception

    def has_exception(self):
        """Return True if an exception was stored during the request call."""
        # Compare the stored exception itself: the bound method `self.exception`
        # is never None, so testing it would always report an exception.
        return self._exception is not None

    def raise_exceptions(self):
        """Re-raise the stored exception, if there is one."""
        if self._exception is not None:
            raise self._exception

    def status(self):
        """Return the HTTP status code of the response (None if no response was received)."""
        return self._status
@contextmanager
def set_directory(path: Path | str):
    """
    Context manager that sets the working directory to the given path.

    The previous working directory is restored on exit, including when the
    body of the `with` block raises.
    """
    origin = Path().absolute()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(origin)
def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> str:
    """
    Keep only alphanumeric characters and ".", "_", "-", " " in `filename`, then
    drop trailing characters until its encoded length is at most `max_bytes` bytes.
    """
    allowed_punctuation = "._- "
    cleaned = "".join(
        char for char in filename if char.isalnum() or char in allowed_punctuation
    )
    # Trim one character at a time so a multi-byte character is never split.
    while cleaned and len(cleaned.encode()) > max_bytes:
        cleaned = cleaned[:-1]
    return cleaned
def sanitize_value_for_csv(value: str | Number) -> str | Number:
    """
    Sanitizes a value that is being written to a CSV file to prevent CSV injection attacks.
    Numbers are returned unchanged. A string is prefixed with a single quote if
    it starts with a character that spreadsheet software may interpret as the
    start of a formula, or contains such a character directly after a comma.
    Reference: https://owasp.org/www-community/attacks/CSV_Injection
    """
    if isinstance(value, Number):
        return value
    # "\r" is included because the OWASP reference lists carriage return
    # alongside tab as a dangerous leading character.
    unsafe_prefixes = ("=", "+", "-", "@", "\t", "\n", "\r")
    unsafe_sequences = tuple("," + prefix for prefix in unsafe_prefixes)
    if value.startswith(unsafe_prefixes) or any(
        sequence in value for sequence in unsafe_sequences
    ):
        value = "'" + value
    return value
def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
    """
    Apply sanitize_value_for_csv to every value of a list (or of each inner
    list, for a list of lists) that is being written to a CSV file.
    """
    return [
        [sanitize_value_for_csv(cell) for cell in entry]
        if isinstance(entry, list)
        else sanitize_value_for_csv(entry)
        for entry in values
    ]
def append_unique_suffix(name: str, list_of_names: List[str]):
    """
    Return `name` if it is not in `list_of_names`; otherwise return `name`
    followed by "_<n>" for the smallest n >= 1 that is not already taken.
    """
    taken = set(list_of_names)  # for O(1) lookup
    if name not in taken:
        return name
    counter = 1
    while f"{name}_{counter}" in taken:
        counter += 1
    return f"{name}_{counter}"
def validate_url(possible_url: str) -> bool:
    """
    Return True if a GET request to `possible_url` succeeds with an OK status.
    Any exception (invalid URL, no network, timeout) results in False.
    """
    headers = {"User-Agent": "gradio (https://gradio.app/; [email protected])"}
    try:
        # A timeout keeps an unresponsive host from blocking the caller
        # forever; 3 seconds matches the other requests in this module.
        return requests.get(possible_url, headers=headers, timeout=3).ok
    except Exception:
        return False
def is_update(val):
    """Return True if `val` is a dict whose "__type__" entry contains "update"."""
    if not isinstance(val, dict):
        return False
    return "update" in val.get("__type__", "")
def get_continuous_fn(fn: Callable, every: float) -> Callable:
    """
    Wrap `fn` in a generator function that endlessly yields `fn(*args)`,
    sleeping `every` seconds after each yielded value is consumed.
    """

    def continuous_fn(*args):
        while True:
            yield fn(*args)
            time.sleep(every)

    return continuous_fn
async def cancel_tasks(task_ids: set[str]):
    """
    Cancel every running asyncio task whose name is in `task_ids` and wait for
    the cancelled tasks to finish. Does nothing on Python versions before 3.8.
    """
    if sys.version_info < (3, 8):
        return None

    to_cancel = []
    for task in asyncio.all_tasks():
        if task.get_name() in task_ids:
            to_cancel.append(task)
    for task in to_cancel:
        task.cancel()
    # return_exceptions=True so the resulting CancelledErrors are collected, not raised.
    await asyncio.gather(*to_cancel, return_exceptions=True)
def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
    """
    Name `task` "<session_hash>_<fn_index>". The task is left unnamed on
    Python versions before 3.8 and when it is part of a batch.
    """
    # You shouldn't be able to cancel a task if it's part of a batch
    if sys.version_info < (3, 8) or batch:
        return
    task.set_name(f"{session_hash}_{fn_index}")
def get_cancel_function(
    dependencies: List[Dict[str, Any]]
) -> Tuple[Callable, List[int]]:
    """
    Build a coroutine function that cancels the tasks of the given dependencies.

    Each dependency is looked up in Context.root_block.dependencies to find its
    index. Returns a tuple of the `cancel(session_hash)` coroutine function and
    the list of indices found. Dependencies are skipped while there is no root block.
    """
    fn_to_comp = {}
    for dep in dependencies:
        if not Context.root_block:
            continue
        root = Context.root_block
        fn_index = next(i for i, d in enumerate(root.dependencies) if d == dep)
        fn_to_comp[fn_index] = [root.blocks[o] for o in dep["outputs"]]

    async def cancel(session_hash: str) -> None:
        # Task names follow the "<session_hash>_<fn_index>" scheme.
        task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
        await cancel_tasks(task_ids)

    return cancel, list(fn_to_comp)
def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool):
    """
    Checks if the number of input components matches the function signature.

    Positional parameters type-hinted as gr.Request are not counted. A
    mismatch is reported through warnings.warn rather than raised.
    Parameters:
    fn: the function whose signature is inspected
    inputs: the input components that will be passed to `fn`
    inputs_as_dict: if True, the inputs count as a single argument
    Returns: None in the normal case; a string error message if `fn` has a
    keyword-only parameter without a default value
    """

    def is_special_typed_parameter(name):
        """Checks if parameter has a type hint designating it as a gr.Request"""
        from gradio.routes import Request

        return parameter_types.get(name, "") == Request

    signature = inspect.signature(fn)
    parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {}
    min_args = 0
    max_args = 0
    infinity = -1  # marks a *args parameter: no upper bound
    for name, param in signature.parameters.items():
        # Identity check: `!=` would call the default value's own __ne__,
        # which may not return a plain bool (e.g. array-like defaults).
        has_default = param.default is not param.empty
        if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
            if not is_special_typed_parameter(name):
                if not has_default:
                    min_args += 1
                max_args += 1
        elif param.kind == param.VAR_POSITIONAL:
            max_args = infinity
        elif param.kind == param.KEYWORD_ONLY:
            if not has_default:
                return f"Keyword-only args must have default values for function {fn}"
    arg_count = 1 if inputs_as_dict else len(inputs)
    if min_args == max_args and max_args != arg_count:
        warnings.warn(
            f"Expected {max_args} arguments for function {fn}, received {arg_count}."
        )
    if arg_count < min_args:
        warnings.warn(
            f"Expected at least {min_args} arguments for function {fn}, received {arg_count}."
        )
    if max_args != infinity and arg_count > max_args:
        warnings.warn(
            f"Expected maximum {max_args} arguments for function {fn}, received {arg_count}."
        )
class TupleNoPrint(tuple):
    """A tuple whose repr() and str() are empty, so a notebook does not print the function's return value."""

    def __repr__(self):
        return ""

    # str() yields the same empty string as repr().
    __str__ = __repr__
def tex2svg(formula, *args):
    """
    Render a LaTeX math `formula` to SVG markup with matplotlib's mathtext.

    Returns a string made of a zero-font-size <span> holding the raw formula
    followed by the SVG code (from the "<svg " tag onward, with the
    <metadata> element removed). Extra positional arguments are ignored.
    """
    font_size = 20
    dpi = 300
    plt.rc("mathtext", fontset="cm")
    figure = plt.figure(figsize=(0.01, 0.01))
    figure.text(0, 0, f"${formula}$", fontsize=font_size)

    buffer = BytesIO()
    figure.savefig(
        buffer,
        dpi=dpi,
        transparent=True,
        format="svg",
        bbox_inches="tight",
        pad_inches=0.0,
    )
    plt.close(figure)

    xml_code = buffer.getvalue().decode("utf-8")
    # Drop everything before the opening <svg> tag.
    svg_code = xml_code[xml_code.index("<svg ") :]
    svg_code = re.sub(r"<metadata>.*<\/metadata>", "", svg_code, flags=re.DOTALL)
    return f"<span style='font-size: 0px'>{formula}</span>{svg_code}"