Spaces:
Sleeping
Sleeping
"""
Main application for the SEA.AI RGB detection demo (Gradio app).

The detection model is obtained from ``seavision.load_model(path)``. The
returned object must:

- expose ``det_conf_thresh`` and ``hor_conf_thresh`` attributes (both are
  set right after loading), and
- be callable on an image, returning a result object whose
  ``draw(image)`` method returns the image annotated with the detections.

To swap in a different model, provide an object satisfying that contract;
the module-level ``inference(image)`` wrapper does not need to change.
"""
| import os | |
| import glob | |
| # import spaces | |
| import gradio as gr | |
| from huggingface_hub import get_token | |
| from utils import ( | |
| check_image, | |
| load_image_from_url, | |
| load_badges, | |
| FlaggedCounter, | |
| ) | |
| from flagging import HuggingFaceDatasetSaver | |
| from blob_utils import decode_blob_data, is_blob_data | |
| from logging_config import get_logger | |
| import install_private_repos # noqa: F401 | |
| from seavision import load_model | |
logger = get_logger(__name__)  # module-wide logger

# --- UI text & styling ------------------------------------------------------
# Center the page heading.
css = """
h1 {
text-align: center;
display: block;
}
"""

# HTML header rendered at the top of the page.
TITLE = """
<h1> 🌊 SEA.AI's Vision Demo ✨ </h1>
<p align="center">
Ahoy! Explore our object detection technology!
Upload a maritime scene image and click <code>Submit</code>
to see the results.
</p>
"""

# Label of the flag button; NOTICE refers to it by name.
FLAG_TXT = "Report Mis-detection"
NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""

# --- Detection model ---------------------------------------------------------
# Loaded once at import time; every request goes through this single instance.
logger.info("Loading detection model...")
model = load_model("ahoyv2n-MIX-b1.onnx")
# Same permissive 0.1 value for both thresholds (assigned det first, then hor).
model.det_conf_thresh = model.hor_conf_thresh = 0.1
# @spaces.GPU
def inference(image):
    """Run the detection model on ``image`` and return the annotated image.

    Args:
        image: The input image as delivered by the ``img_input`` component
            (presumably a numpy array, Gradio's default for ``gr.Image`` --
            TODO confirm).

    Returns:
        The result of ``results.draw(image)``: the input image with the
        model's detections drawn on it.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # In the UI, check_image runs before this function, but calling the
        # public "inference" API endpoint directly skips that check. Fail with
        # a readable message instead of an opaque error from inside the model.
        raise gr.Error("Please provide an image.")
    logger.debug("Running inference on image")
    results = model(image)
    logger.debug("Inference completed")
    return results.draw(image)
def flag_img_input(
    image,
    name: str = "anonymous",
    # NOTE(review): this default looks like an e-mail obfuscation artifact
    # from a web scrape; confirm the intended address.
    email: str = "[email protected]",
    subscribe: bool = False,
    rating: int = 0,
    description: str = "",
):
    """Save a user-reported image plus metadata via the flagging writer.

    This handler is registered with ``preprocess=False``, so ``image`` is
    the raw payload of the image component rather than a decoded image. It
    may be blob-encoded, in which case it is decoded first.

    Args:
        image: Raw image payload to flag.
        name: Reporter name, stored in the metadata.
        email: Reporter e-mail, stored in the metadata.
        subscribe: Subscription opt-in flag, stored in the metadata.
        rating: Rating value, stored in the metadata.
        description: Free-text description, stored in the metadata.

    Raises:
        gr.Error: If no image was supplied. This endpoint is exposed as a
            public API, and flagging ``None`` would record an empty sample.
    """
    if image is None:
        raise gr.Error("No image to flag.")
    # Decode blob data if necessary
    if is_blob_data(image):
        image = decode_blob_data(image)
    metadata = {
        "name": name,
        "email": email,
        "subscribe": subscribe,
        "rating": rating,
        "description": description,
    }
    hf_writer.flag([image], metadata=metadata)
    logger.info("Image flagged successfully")
# --- Flagging ----------------------------------------------------------------
# Reported images are written through hf_writer to this Hugging Face dataset.
# The same dataset name is handed to FlaggedCounter, whose count() feeds the
# badge HTML.
dataset_name = "SEA-AI/crowdsourced-sea-images-v2"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
flagged_counter = FlaggedCounter(dataset_name)

# Stock Gradio theme with an indigo primary colour.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.indigo,
)
# Gradio UI. All components are bound at module level; ``demo`` is launched
# by the ``__main__`` guard at the end of the file.
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)

    with gr.Row():
        # Left column: image input (upload / clipboard / URL) and actions.
        with gr.Column():
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        # Right column: annotated output plus the flagging controls, which
        # start hidden and are toggled by _show_hide_flagging below.
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)

    # Bundled example images. Their file names are collected once so that
    # results on examples never offer the flag button.
    example_files = glob.glob("examples/*.jpg")
    example_names = {os.path.basename(path) for path in example_files}
    examples = gr.Examples(
        examples=example_files,
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )

    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])

    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    # check_image runs first; inference only runs if it succeeds.
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )

    def _show_hide_flagging(_img_input, _img_output):
        """Show the flag button/notice only for results on user images.

        Registered with ``preprocess=False``, so both arguments are the raw
        component payloads. They are presumably dicts carrying an
        ``"orig_name"`` key, or ``None`` when the component is empty.
        """
        if not _img_input or not _img_output:
            # Nothing to report without both an input and a result.
            visible = False
        else:
            visible = _img_input.get("orig_name") not in example_names
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }

    # Re-evaluate flag visibility whenever a new result is shown, including
    # cached example results. Without this registration the flag button
    # would never become visible.
    img_output.change(
        _show_hide_flagging,
        [img_input, img_output],
        [flag, notice],
        preprocess=False,
        show_api=False,
    )

    # add hidden textbox for name and email (hacky but well...)
    name_ = gr.Textbox(label="name", visible=False, value="anonymous")
    email_ = gr.Textbox(label="email", visible=False, value="[email protected]")
    subscribe_ = gr.Checkbox(label="subscribe", value=False, visible=False)
    rating_ = gr.Slider(
        label="rating", value=0, minimum=0, maximum=5, step=1, visible=False
    )
    description_ = gr.Textbox(label="description", visible=False, value="")

    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")

    # Sequential logic when flag button is clicked: thank the user, disable
    # the button (prevents double submission), save the image, refresh badges.
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input, name_, email_, subscribe_, rating_, description_],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )

    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)
if __name__ == "__main__":
    # Turn on Gradio's request queue (queue() configures demo in place and
    # returns it), then start the web server.
    demo.queue()
    demo.launch()