# YOND / app.py
# hansen97's picture
# ZeroGPU optimize
# 47547b1
import sys
import os
import glob
import time
from datetime import datetime
import gradio as gr
from app_function import YOND_Backend
# Single backend instance created at import time and used by every callback in
# this file (all sessions go through the same object).
yond = YOND_Backend()
# ------------------- Session management and resource release -------------------
# Last-active time per session (key: session ID, value: datetime of last activity)
active_sessions = {}
# Inactivity timeout in seconds; currently 2 minutes (120 seconds), adjust as needed
INACTIVE_TIMEOUT = 120
def create_session():
    """Register a new session and return its identifier.

    Every session already present in ``active_sessions`` is first passed to
    ``check_inactive`` so that stale entries get a chance to be released.
    The new identifier is ``time.time_ns()`` rendered as a string, and its
    last-active time is initialised to the current time.

    Returns:
        The identifier of the newly registered session.
    """
    # Snapshot the keys: check_inactive may delete entries while we loop.
    existing_ids = list(active_sessions)
    if existing_ids:
        gr.Warning(f"检测到{len(existing_ids)}个其他用户正在使用,尝试释放超时资源")
        for existing_id in existing_ids:
            check_inactive(existing_id)

    new_id = str(time.time_ns())
    active_sessions[new_id] = datetime.now()
    print(f"新会话创建: {new_id}")
    return new_id
def update_heartbeat(session_id):
    """Refresh the last-active timestamp of a registered session.

    Args:
        session_id: Identifier of the session sending the heartbeat.

    Returns:
        ``"active"`` when the session is registered (its timestamp is reset
        to now), ``"invalid"`` when the id is empty or unknown.
    """
    # Guard clause: empty ids are rejected before the registry is consulted.
    if not session_id or session_id not in active_sessions:
        return "invalid"

    active_sessions[session_id] = datetime.now()
    print(f"会话 {session_id} 心跳更新")
    return "active"
def check_inactive(session_id):
    """Release backend resources if the session has been idle for too long.

    Args:
        session_id: Identifier of the session to check.

    Returns:
        ``"invalid"`` if the id is empty or not registered, ``"active"`` if
        the session was active within ``INACTIVE_TIMEOUT`` seconds, otherwise
        ``"session <id> released"`` after the session has been removed.
    """
    if not session_id or session_id not in active_sessions:
        return "invalid"

    idle_seconds = (datetime.now() - active_sessions[session_id]).total_seconds()
    if idle_seconds <= INACTIVE_TIMEOUT:
        return "active"

    print(f"会话 {session_id} 超时,释放资源")
    gr.Warning(f"会话 {session_id} 超时,释放资源")
    try:
        yond.unload_model()  # release model resources
        yond.clear_cache()  # clear cached data
    except Exception as e:
        print(f"释放资源时出错: {e}")
        # gr.Error is an exception class: building it without `raise` shows
        # nothing. Cleanup is best-effort (and also runs on behalf of other
        # sessions), so report the failure with the non-raising gr.Warning.
        gr.Warning(f"释放资源时出错: {e}")
    finally:
        # Always forget the session, even when releasing resources failed.
        active_sessions.pop(session_id, None)
    return f"session {session_id} released"
def close_session(session_id):
    """Close a session on request and release backend resources.

    Args:
        session_id: Identifier of the session to close.

    Returns:
        ``"session closed"`` when the session was registered (it is removed
        even if releasing resources fails), ``"invalid session"`` when the id
        is empty or unknown.
    """
    if not session_id or session_id not in active_sessions:
        return "invalid session"

    print(f"会话 {session_id} 关闭,释放资源")
    gr.Warning(f"会话 {session_id} 关闭,释放资源")
    try:
        yond.unload_model()
        yond.clear_cache()
    except Exception as e:
        print(f"关闭会话时出错: {e}")
        # gr.Error is an exception class: building it without `raise` shows
        # nothing. Report the failure with gr.Warning so the session is still
        # closed and the status string is still returned.
        gr.Warning(f"关闭会话时出错: {e}")
    finally:
        # Always forget the session, even when releasing resources failed.
        active_sessions.pop(session_id, None)
    return "session closed"
# --------------------------------------------------------------
# Build the Gradio Blocks app. The inline CSS fixes the #left_panel width at
# 400px, caps the container width at 1800px and styles a `.log-panel` class.
with gr.Blocks(title="YOND WebUI", css="""
#left_panel {
width: 400px !important;
min-width: 400px !important;
max-width: 400px !important;
}
.gradio-container {max-width: 1800px !important}
.log-panel {height: 200px !important; overflow-y: auto}
""") as app:
# ------------------- Session ID and heartbeat components -------------------
# Session ID state: create_session is passed as a callable, so the initial
# value is expected to come from calling it (see Gradio State docs — confirm).
session_id = gr.State(value=create_session)
# Hidden components used for heartbeat communication.
# NOTE(review): heartbeat_signal is not referenced by any binding visible in this file.
heartbeat_signal = gr.Textbox(visible=False)
session_status = gr.Textbox(visible=False)
# --------------------------------------------------------------
# Page header with a short step-by-step tutorial.
gr.Markdown("""
# 🌌 YOND ([You Only Need a Denoiser](https://arxiv.org/abs/2506.03645)) | Practical Blind Raw Image Denoising
### YOND WebUI Simple Tutorial: (See the [YOND WebUI Introduction](https://vmcl-isp.site/t/topic/201) for a complete usage guide):
1. **[Load Config]** → 2. **[Upload Raw Image]** (or **[Load Example]**) → 3. **[Load Image]** (modified metadata and **[Update Image]**) → 4. **[Noise Estimation]** → 5. **[Denoising]** → ...(Optional Operations)... → **[Release GPU]**
""")
# Top toolbar: preset config selector, example image selector and GPU release button.
with gr.Row():
with gr.Row():
# Preset configs are discovered on disk, one directory level below runfiles/.
yaml_files = glob.glob("runfiles/*/*.yml")
# NOTE(review): the default value is hard-coded; it is not checked against yaml_files.
config_selector = gr.Dropdown(
label="预设配置",
choices=yaml_files,
value="runfiles/Gaussian/gru32n_ft.yml",
scale=2,
container=False
)
load_config_btn = gr.Button("Load Config", variant="primary", scale=1)
# Despite its name, ckpt_files holds the example image paths found under images/.
ckpt_files = glob.glob("images/*.*")
# NOTE(review): the default value is hard-coded; it is not checked against ckpt_files.
example_selector = gr.Dropdown(
label="预设图片",
choices=ckpt_files,
value="images/LRID_outdoor_x5_004_iso6400.dng",
scale=2,
container=False
)
load_example_btn = gr.Button("Load Example", variant="primary", scale=1)
# Bound further down to close_session, which unloads the model and clears the cache.
unload_btn = gr.Button("Release GPU", variant="secondary", scale=1)
# Main work area: control panel on the left, image tabs and downloads on the right.
with gr.Row():
# Left control panel
with gr.Column(scale=1, elem_id="left_panel"):
# Raw file upload; the component hands the callback a file path (type="filepath").
raw_upload = gr.File(label="Uploaded Raw Image", file_types=[".npy", ".NPY", ".ARW", ".DNG", ".NEF", ".CR2", ".RAW", ".MAT",".arw", ".dng", ".nef", ".cr2", ".raw", ".mat"], type="filepath")
# Raw metadata fields; h, w, bl and wp are also outputs of the image-loading callbacks.
with gr.Accordion("Raw Metadata", open=True):
with gr.Row():
h = gr.Number(label="Height", value=2160, precision=0, scale=1)
w = gr.Number(label="Width", value=3840, precision=0, scale=1)
bl = gr.Number(label="Black Level", value=64.0, precision=1, scale=1)
wp = gr.Number(label="White Point", value=1023.0, precision=1, scale=1)
ratio = gr.Number(label="DGain (x Ratio)", value=1.0, precision=1, scale=1)
ispgain = gr.Number(label="ISPGain (Visual Only)", value=1.0, precision=1, scale=1)
with gr.Row():
image_btn = gr.Button("Load Image", variant="primary")
image_update_btn = gr.Button("Update Image", variant="secondary")
# Noise estimation and denoising controls.
with gr.Accordion("Noise Estimation & Denoising", open=True):
with gr.Row():
use_ransac = gr.Checkbox(label="RANSAC", value=False, scale=1)
double_est = gr.Checkbox(label="Refined Estimation", value=False, scale=1)
# gain and sigma are filled by the noise-estimation callback and their ranges
# are then adjusted by update_sliders.
gain = gr.Slider(value=0, step=0.1, label="System Gain (K)")
sigma = gr.Slider(value=0, step=0.1, label="Read Noise Level (σ)")
est_btn = gr.Button("Noise Estimation", variant="primary")
use_ddim = gr.Checkbox(label="DDIM Mode (Please use large model, e.g., gru64n)", value=False, scale=1)
with gr.Row():
patch_size = gr.Number(label="Patch Size", value=1024, precision=0)
sigsnr = gr.Number(label="SigSNR", precision=2, value=1.03)
enh_btn = gr.Button("Denoising", variant="primary")
# Right display area
with gr.Column(scale=2):
with gr.Tabs():
with gr.Tab("Input Image", id="input_tab"):
input_img = gr.Image(label="Noisy Image", type="pil")
with gr.Tab("Output Image", id="output_tab"):
output_img = gr.Image(label="Denoised Image", type="pil")
with gr.Tab("Threshold Mask", id="analysis_tab"):
mask_img = gr.Image(label="mask", type="pil")
# Download area: each button writes a result file that is exposed through the gr.File next to it.
with gr.Accordion("Download Manager", open=True):
with gr.Row():
with gr.Column(scale=1):
save_npy_btn = gr.Button("Save as NPY Files", variant="primary")
npy_file = gr.File(label="Denoised NPY Download", visible=True)
with gr.Column(scale=1):
save_png_btn = gr.Button("Save as PNG Files", variant="primary")
png_file = gr.File(label="Denoised PNG Download", visible=True)
# Load config: passes the selected YAML path to the backend; no UI component is updated.
load_config_btn.click(
fn=yond.load_config,
inputs=[config_selector],
# outputs=[model_selector],
)
# Load example image: same callback as "Load Image", but fed with the example path
# instead of the uploaded file.
load_example_btn.click(
fn=yond.process_image,
inputs=[example_selector, h, w, bl, wp, ratio, ispgain],
outputs=[input_img, h, w, bl, wp]
)
# Slider range binding
def update_sliders(gain_val, sigma_val):
    """Rescale the gain and sigma slider ranges around the current values.

    The gain slider is given the range [0.1 * gain, 2 * gain]; the sigma
    slider is given [0, max(2 * gain, 2 * sigma)].

    Args:
        gain_val: Current system gain (K) value; ``None`` is treated as 0.
        sigma_val: Current read-noise (sigma) value; ``None`` is treated as 0.

    Returns:
        A two-element list of ``gr.update`` objects for the gain slider and
        the sigma slider, in that order.
    """
    # Use float(), not int(): truncation turned every estimate below 1 into 0
    # and collapsed both slider ranges to [0, 0].
    gain = float(gain_val or 0)
    sigma = float(sigma_val or 0)

    gain_min = round(0.1 * gain, 2)
    gain_max = round(2.0 * gain, 2)
    sigma_min = 0
    sigma_max = round(max(2.0 * gain, 2.0 * sigma), 2)
    return [
        gr.update(minimum=gain_min, maximum=gain_max),
        gr.update(minimum=sigma_min, maximum=sigma_max),
    ]
# Load image: process the uploaded raw file with the metadata fields; the callback
# returns the preview image plus values for h, w, bl and wp.
image_btn.click(
fn=yond.process_image,
inputs=[raw_upload, h, w, bl, wp, ratio, ispgain],
outputs=[input_img, h, w, bl, wp]
)
# Update image: refresh the preview using the current level/gain fields only.
image_update_btn.click(
fn=yond.update_image,
inputs=[bl, wp, ratio, ispgain],
outputs=[input_img]
)
# Estimate noise: fills the mask image and the gain/sigma sliders, then
# update_sliders adjusts both slider ranges from those new values.
est_btn.click(
fn=yond.estimate_noise,
inputs=[double_est, use_ransac, patch_size],
outputs=[mask_img, gain, sigma]
).then(
fn=update_sliders,
inputs=[gain, sigma],
outputs=[gain, sigma]
)
# Run denoising with the current gain/sigma and options; shows the result in the output tab.
enh_btn.click(
fn=yond.enhance_image,
inputs=[gain, sigma, sigsnr, use_ddim, patch_size],
outputs=[output_img]
)
# Save buttons take no inputs; each callback's return value feeds its download component.
save_npy_btn.click(
fn=yond.save_result_npy,
outputs=[npy_file]
)
save_png_btn.click(
fn=yond.save_result_png,
outputs=[png_file]
)
# ------------------- Gradio 5.38.0 timer implementation -------------------
# Create the timers (the value parameter sets the interval in seconds)
heartbeat_timer = gr.Timer(
value=30, # fire a heartbeat every 30 seconds
active=True # active from the start
)
check_timer = gr.Timer(
value=60, # check for timeout every 60 seconds
active=True
)
# Bind the timer events (via the tick method).
# All three bindings below write their status string into the hidden session_status box.
heartbeat_timer.tick(
fn=update_heartbeat,
inputs=[session_id],
outputs=session_status
)
check_timer.tick(
fn=check_inactive,
inputs=[session_id],
outputs=session_status
)
# "Release GPU" button: explicitly close this session and release backend resources.
unload_btn.click(
fn=close_session,
inputs=[session_id],
outputs=session_status
)
# Launch the web UI only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        # Listen on all network interfaces.
        "server_name": "0.0.0.0",
        # "server_port": 7860,
        # "share": True,
    }
    app.launch(**launch_options)