Spaces:
Sleeping
Sleeping
| import os.path | |
| import gradio as gr | |
| import pandas as pd | |
| from constants import * | |
# ------------ Download links ------------
def get_download_link_model(task, dataset, example):
    """Return the relative path of the zipped model weights for a selection.

    The path has the form ``data/<task>/<dataset>/weight/<example>.zip``, where
    each component is translated through TASK_PATH_MAPPING,
    DATASET_PATH_MAPPING and EXAMPLE_PATH_MAPPING respectively. Whether the
    file actually exists is not checked here.
    """
    parts = (
        TASK_PATH_MAPPING[task],
        DATASET_PATH_MAPPING[dataset],
        "weight",
        f"{EXAMPLE_PATH_MAPPING[example]}.zip",
    )
    return os.path.join("data", *parts)
def get_download_link_json(task, dataset, example):
    """Return the relative path of the JSON results file for a selection.

    The path has the form ``data/<task>/<dataset>/json/<example>.<ext>``, with
    each component translated through TASK_PATH_MAPPING, DATASET_PATH_MAPPING
    and EXAMPLE_PATH_MAPPING. Whether the file exists is not checked here.
    """
    task_dir = TASK_PATH_MAPPING[task]
    dataset_dir = DATASET_PATH_MAPPING[dataset]
    example_name = EXAMPLE_PATH_MAPPING[example]
    # The "common" task directory uses the .jsonl extension; all others use .json.
    extension = "jsonl" if task_dir == "common" else "json"
    return os.path.join("data", task_dir, dataset_dir, "json", f"{example_name}.{extension}")
# ------------ Data loading + average accuracy ------------
def get_data(task, dataset, example):
    """Load the per-prompt results table and the average score for a selection.

    The CSV is read from ``data/<task>/<dataset>/csv/<example>.csv``, each path
    component being looked up in TASK_PATH_MAPPING, DATASET_PATH_MAPPING and
    EXAMPLE_PATH_MAPPING.

    Args:
        task: Task display name (key of TASK_PATH_MAPPING).
        dataset: Dataset display name (key of DATASET_PATH_MAPPING).
        example: Example display name (key of EXAMPLE_PATH_MAPPING).

    Returns:
        A ``(data, average_acc)`` tuple:

        - ``(None, None)`` if the CSV file does not exist.
        - Task path "coding": one row per prompt with Pass@1/5/10 scaled to
          percent and rounded to 3 decimals, and Correctness "N/A".
          ``average_acc`` is a ``"p1 / p5 / p10"`` string of column means when
          the dataset name contains "HumanEval", otherwise None.
        - Task paths "common" / "math": one row per prompt with Pass@k set to
          None and Correctness "✅"/"❌" from the row's ``correctness`` value.
          ``average_acc`` is the mean of ``correctness`` in percent, rounded to
          3 decimals.
        - Any other task path: an empty table with COLUMN_NAMES, and None.

    Raises:
        KeyError: if ``task``/``dataset``/``example`` is not in its mapping.
    """
    _task_path = TASK_PATH_MAPPING[task]
    _dataset_path = DATASET_PATH_MAPPING[dataset]
    _example_path = EXAMPLE_PATH_MAPPING[example]
    csv_file = os.path.join("data", _task_path, _dataset_path, "csv", f"{_example_path}.csv")
    if not os.path.exists(csv_file):
        return None, None

    read_data = pd.read_csv(csv_file)
    records = []
    average_acc = None
    if _task_path == "coding":
        records = [
            {
                "Prompt": row["prompt"],
                "Pass@1": round(float(row["pass@1"]) * 100, 3),
                "Pass@5": round(float(row["pass@5"]) * 100, 3),
                "Pass@10": round(float(row["pass@10"]) * 100, 3),
                "Correctness": "N/A",
            }
            for row in read_data.to_dict("records")
        ]
        # Only HumanEval datasets get an average, reported for all three pass@k columns.
        if "HumanEval" in dataset:
            p1_mean = round(read_data["pass@1"].mean() * 100, 3)
            p5_mean = round(read_data["pass@5"].mean() * 100, 3)
            p10_mean = round(read_data["pass@10"].mean() * 100, 3)
            average_acc = f"{p1_mean} / {p5_mean} / {p10_mean}"
    elif _task_path in ("common", "math"):
        records = [
            {
                "Prompt": row["prompt"],
                "Pass@1": None,
                "Pass@5": None,
                "Pass@10": None,
                "Correctness": "✅" if row["correctness"] else "❌",
            }
            for row in read_data.to_dict("records")
        ]
        average_acc = round(read_data["correctness"].mean() * 100, 3)

    # Build the frame in one step. Appending row by row with pd.concat copies
    # the whole table on every row, which is quadratic in the number of prompts.
    # NOTE(review): assumes COLUMN_NAMES lists exactly the five keys above;
    # a key missing from COLUMN_NAMES would be dropped here.
    data = pd.DataFrame(records, columns=COLUMN_NAMES)
    return data, average_acc
# ------------ Gradio UI ------------
def _refresh_board(t, d, e):
    """Return ``(table, average_text)`` for one Task / Dataset / Example selection.

    get_data is called exactly once, so the backing CSV is read once and both
    outputs come from the same read. When get_data finds no CSV, an empty
    table with the COLUMN_NAMES headers is returned instead of None. A missing
    average is shown as "N/A" rather than the literal text "None".
    """
    data, average_acc = get_data(t, d, e)
    if data is None:
        data = pd.DataFrame(columns=COLUMN_NAMES)
    average_text = "N/A" if average_acc is None else str(average_acc)
    return data, average_text


def _dataset_radio_for(t):
    """Rebuild the Dataset radio with the datasets listed for task ``t``, selecting the first."""
    return gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[t],
        value=TASK_DATASET_LIST[t][0],
        interactive=True,
    )


with gr.Blocks() as demo_board:
    gr.HTML(DND_HEADER)
    gr.Markdown(DND_INTRODUCTION)

    task = gr.Radio(
        label="Task",
        choices=TASK_LIST,
        value=TASK_LIST[0],
        interactive=True,
    )
    dataset = gr.Radio(
        label="Dataset",
        choices=TASK_DATASET_LIST[task.value],
        value=TASK_DATASET_LIST[task.value][0],
        interactive=True,
    )
    example = gr.Radio(
        label="Example",
        choices=EXAMPLE_LIST,
        value=EXAMPLE_LIST[0],
        interactive=True,
    )

    # Average accuracy, shown above the prompt table. Its value is filled in by
    # the load / change handlers registered below.
    average_acc_display = gr.Textbox(
        label="Average Accuracy (%)",
        interactive=False,
        visible=True,
        scale=0,
        max_lines=1,
        min_width=160,
    )
    # Per-prompt results table, filled by the same handlers.
    board = gr.Dataframe(
        column_widths=["60%", "10%", "10%", "10%", "10%"],
        headers=COLUMN_NAMES,
        type="pandas",
        datatype=DATA_TITLE_TYPE,
        interactive=False,
        visible=True,
        max_height=500,
    )

    selection = [task, dataset, example]
    board_outputs = [board, average_acc_display]

    # Initial fill on every page load: one get_data call feeds both outputs.
    demo_board.load(_refresh_board, inputs=selection, outputs=board_outputs)

    # task -> dataset choices, then refresh. Chaining with .then() makes the
    # refresh read the updated Dataset value. A separate parallel task.change
    # handler would pair the new task with the previous task's dataset.
    task.change(_dataset_radio_for, inputs=[task], outputs=dataset).then(
        _refresh_board, inputs=selection, outputs=board_outputs
    )
    # dataset / example -> table + average accuracy.
    for component in (dataset, example):
        component.change(_refresh_board, inputs=selection, outputs=board_outputs)

    # Download buttons: each click resolves the file path for the current selection.
    with gr.Row():
        json_downloader = gr.DownloadButton("Download JSON", visible=True)
        model_downloader = gr.DownloadButton("Download Model", visible=True)
    json_downloader.click(
        fn=get_download_link_json,
        inputs=selection,
        outputs=json_downloader,
    )
    model_downloader.click(
        fn=get_download_link_model,
        inputs=selection,
        outputs=model_downloader,
    )

    # Citation text with a copy button.
    citation_button = gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        elem_id="citation-button",
        lines=6,
        show_copy_button=True,
    )

# Launch the app.
demo_board.launch()