# YOND / app_function.py
import numpy as np
import rawpy
from PIL import Image
import torch
import yaml
import gradio as gr
import tempfile
import spaces
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import os
import gc
import warnings
import h5py
import scipy.io
import cv2
from torch.optim import Adam, lr_scheduler
from data_process import *
from utils import *
from archs import *
import sys
# Add the dist directory to the Python search path
sys.path.append("./dist")
# from dist.isp_algos import *
from isp_algos import VST, inverse_VST, ddim, BiasLUT, SimpleNLF
from bm3d import bm3d
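# YOND_Backend wraps the end-to-end YOND pipeline behind the Gradio UI:
# load config/model, read RAW inputs, estimate Poisson-Gaussian noise
# parameters (gain, sigma), denoise patch-wise in the VST domain on the GPU,
# and render sRGB previews and results.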
class YOND_Backend:
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# self.device = torch.device('cpu')  # Force CPU to avoid CUDA-related issues
# Initialize processing parameters
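# bl/wp: black level and white point; scale = wp - bl.
# gain/sigma: Poisson-Gaussian noise parameters (estimated later).
# ratio/ispgain: digital gain applied to the raw data / extra gain in the ISP preview.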
self.p = {
"ratio": 1.0,
"ispgain": 1.0,
"h": 2160,
"w": 3840,
"bl": 64.0,
"wp": 1023.0,
"gain": 0.0,
"sigma": 0.0,
"wb": [2.0, 1.0, 2.0],
"ccm": np.eye(3),
"scale": 959.0, # 1023-64
"ransac": False,
"ddim_mode": False,
}
# State variables
self.raw_data = None
self.denoised_data = None
self.denoised_npy = None
self.denoised_rgb = None
self.mask_data = None
self.yond = None
self.bias_lut = None
# New: method to release model resources
def unload_model(self):
if self.yond is not None:
del self.yond
self.yond = None
self.bias_lut = None
torch.cuda.empty_cache() # Clear the GPU cache
print("Model has been unloaded, please reload the config")
gr.Success("Model has been unloaded, please reload the config")
# New: method to clear cached data
def clear_cache(self):
# Drop temporary buffers and intermediate variables from processing
self.raw_data = None
self.denoised_data = None
self.denoised_npy = None
self.denoised_rgb = None
self.mask_data = None
gc.collect()
print("Images has clear, please reload images")
gr.Success("Images has clear, please reload images")
def update_param(self, param, value):
"""更新处理参数"""
try:
if param in ['h', 'w']:
self.p[param] = int(value)
else:
self.p[param] = float(value)
# Automatically update derived parameters
if param in ['wp', 'bl']:
self.p['scale'] = self.p['wp'] - self.p['bl']
except (ValueError, TypeError) as e:
gr.Error(f"参数更新失败: {str(e)}")
raise ValueError(f"无效的参数值: {value}") from e
def load_config(self, config_path):
"""加载配置文件"""
try:
self.yond = YOND_anytest(config_path, self.device)
gr.Success(f"配置加载成功: {config_path}", duration=2)
gr.Success(f"当前设备: {self.device}", duration=2)
except Exception as e:
gr.Error(f"配置加载失败: {str(e)}")
raise RuntimeError(f"配置加载失败: {str(e)}")
args = self.yond.args
if 'pipeline' in args:
self.p.update(args['pipeline'])
else:
self.p.update({'epoch':10, 'sigma_t':0.8, 'eta_t':0.85})
model_path = f"{self.yond.fast_ckpt}/{self.yond.yond_name}_last_model.pth"
self.load_model(model_path)
# return model_path
def load_model(self, model_path):
"""加载预训练模型"""
try:
# Load model weights
self.yond.load_model(model_path)
self.bias_lut = BiasLUT(lut_path='checkpoints/bias_lut_2d.npy')
if self.bias_lut is None:
gr.Error(f"BiasLUT加载失败: {os.path.exists('checkpoints/bias_lut_2d.npy')}")
gr.Success(f"模型加载成功: {model_path}", duration=2)
except Exception as e:
gr.Error(f"模型加载失败: {str(e)}")
raise RuntimeError(f"模型加载失败: {str(e)}") from e
def process_image(self, file_path, h, w, bl, wp, ratio, ispgain):
"""处理原始图像文件"""
try:
gr.Warning("正在可视化图像")
# 更新处理参数
self.update_param('h', h)
self.update_param('w', w)
self.update_param('bl', bl)
self.update_param('wp', wp)
self.update_param('ratio', ratio)
self.update_param('ispgain', ispgain)
# Reinitialize state
self.raw_data = None
self.denoised_data = None
self.mask_data = None
self.p.update({'wb':[2,1,2], 'ccm':np.eye(3)})
if file_path.lower().endswith(('.arw','.dng','.nef','.cr2')):
with rawpy.imread(str(file_path)) as raw:
self.raw_data = raw.raw_image_visible.astype(np.uint16)
wb, ccm = self._extract_color_params(raw)
h, w = self.raw_data.shape
bl, wp = raw.black_level_per_channel[0], raw.white_level
scale = wp - bl
self.p.update({'wb':wb,'ccm':ccm,'h':h,'w':w,'bl':bl,'wp':wp,'scale':scale})
elif file_path.lower().endswith(('.raw', '.npy')):
try:
self.raw_data = np.fromfile(file_path, dtype=np.uint16)
self.raw_data = self.raw_data.reshape(
self.p['h'], self.p['w']
)
except Exception as e:
gr.Warning(f"默认参数读取失败: {e}, 尝试使用魔↑术↓技↑巧↓")
info = rawread(file_path)
self.raw_data = info['raw']
self.p.update({
'h': info['h'], 'w': info['w'],
'bl': info['bl'], 'wp': info['wp'],
'scale': info['wp'] - info['bl']
})
gr.Success('Parameters updated via magic tricks...', duration=2)
# MATLAB (.mat) format handling
elif file_path.lower().endswith('.mat'):
with h5py.File(file_path, 'r') as f:
self.raw_data = np.array(f['x']).astype(np.float32) * 959 + 64
# Try to read metadata
meta_path = file_path.replace('NOISY', 'METADATA')
if os.path.exists(meta_path):
self.meta = read_metadata(scipy.io.loadmat(meta_path))
self.p.update({
'h': self.raw_data.shape[0], 'w': self.raw_data.shape[1],
'bl': 64, 'wp': 1023, 'scale': 959
})
else:
gr.Error("不支持的格式")
raise ValueError("不支持的格式")
# 生成预览图
self.raw_data = self.raw_data.astype(np.float32)
if self.p.get('clip', False): self.raw_data = self.raw_data.clip(self.p['bl'],self.p['wp'])  # 'clip' may be absent from the default params
preview = self._generate_preview()
return preview, self.p['h'], self.p['w'], self.p['bl'], self.p['wp']
except Exception as e:
gr.Error(f"图像可视化失败: {str(e)}")
raise RuntimeError(f"图像处理失败: {str(e)}") from e
def update_image(self, bl, wp, ratio, ispgain):
"""更新图像文件"""
try:
log("更新图像参数...")
gr.Success("更新图像参数...", duration=2)
# 更新处理参数
if ispgain != self.p['ispgain'] and (bl != self.p['bl'] and wp != self.p['wp'] and ratio != self.p['ratio']):
update_image_flag = True
self.update_param('bl', bl)
self.update_param('wp', wp)
self.update_param('ratio', ratio)
self.update_param('ispgain', ispgain)
# Reinitialize state
self.denoised_data = None
self.mask_data = None
if self.raw_data is not None:
gr.Success("图像可视化中...", duration=2)
preview = self._generate_preview()
return preview
else:
gr.Error("请先加载图像")
raise RuntimeError("请先加载图像")
except Exception as e:
gr.Error(f"图像更新失败: {str(e)}")
raise RuntimeError(f"图像更新失败: {str(e)}") from e
@spaces.GPU
def denoise(self, raw_vst, patch_size, nsr):
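# Patch-wise network inference on VST-normalized data.
# raw_vst: HxWx4 RGGB array in the normalized VST domain.
# patch_size: tile size passed to big_image_split.
# nsr: noise level in the normalized domain; for 'guided' architectures it is
# scaled by sigsnr and fed to the network as the conditioning map t.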
################# Prepare for denoising #################
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.yond.net = self.yond.net.to(self.device)
raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
if 'guided' in self.yond.arch:
t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device).view(-1,1,1,1)
# Denoise & pad
target_size = patch_size # Target patch size; adjust as needed (GPU: 1024)
overlap_ratio = 1/8 # Overlap ratio between patches; adjust as needed
# Split into overlapping patches with the improved big_image_split
raw_inp, metadata = big_image_split(raw_vst, target_size, overlap_ratio)
raw_dn = torch.zeros_like(raw_inp[:,:4])
with torch.no_grad():
if self.p['ddim_mode']:
for i in range(raw_inp.shape[0]): # Process every patch
print(f'Patch: {i+1}/{len(raw_dn)}')
raw_dn[i] = ddim(raw_inp[i][None,].clip(None, 2), self.yond.net, t, epoch=self.p['epoch'],
sigma_t=self.p['sigma_t'], eta=self.p['eta_t'], sigma_corr=1.00)
else:
for i in range(raw_inp.shape[0]): # Process every patch
input_tensor = raw_inp[i][None,].clip(None, 2)
raw_dn[i] = self.yond.net(input_tensor, t).clamp(0,None)
# Merge the patches back with the improved big_image_merge
raw_dn = big_image_merge(raw_dn, metadata, blend_mode='avg')
################# Back to numpy (inverse VST is applied by the caller) #################
raw_dn = raw_dn[0].permute(1,2,0).detach().cpu().numpy()
return raw_dn
def estimate_noise(self, double_est, ransac, patch_size):
"""执行噪声估计"""
if not self.yond:
gr.Error("请先加载模型")
raise RuntimeError("请先加载模型")
try:
gr.Warning("正在估计噪声...")
log('开始估计噪声')
self.p['ransac'] = ransac
# Preprocess data
processed = (self.raw_data - self.p['bl']) / self.p['scale']
lr_raw = bayer2rggb(processed) * self.p['ratio']
# Coarse estimate
reg, self.mask_data = SimpleNLF(
rggb2bayer(lr_raw),
k=19,
eps=1e-3,
setting={'mode': 'self', 'thr_mode':'score2', 'ransac': self.p['ransac']}
)
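# reg is the fitted noise level function (slope, intercept) in the normalized
# [0,1] domain, i.e. var(x) = K*x + sigma^2; it is rescaled by 'scale' below
# to give gain (K) and sigma in sensor digital numbers.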
self.p['gain'] = reg[0] * self.p['scale']
self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
if double_est:
log(" 使用精估计")
if self.denoised_npy is None:
log(" 之前没去噪,先去噪再估计")
lr_raw_np = lr_raw * self.p['scale']
######## EM-VST: correct the expectation bias of the VST-domain noisy image ########
bias_base = np.maximum(lr_raw_np, 0)
bias = self.bias_lut.get_lut(bias_base, K=self.p['gain'], sigGs=self.p['sigma'])
raw_vst = VST(lr_raw_np, self.p['sigma'], gain=self.p['gain'])
raw_vst = raw_vst - bias
################# Forward VST #################
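# After the VST the noise is approximately unit-variance Gaussian; mapping the
# signal range [lower, upper] to [0, 1] makes the effective noise level
# nsr = 1 / (upper - lower) in the normalized domain.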
lower = VST(0, self.p['sigma'], gain=self.p['gain'])
upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
nsr = 1 / (upper - lower)
raw_vst = (raw_vst - lower) / (upper - lower)
################# Denoise #################
raw_dn = self.denoise(raw_vst, patch_size, nsr)
################# Inverse VST #################
raw_dn = raw_dn * (upper - lower) + lower
self.denoised_data = inverse_VST(raw_dn, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
self.denoised_npy = rggb2bayer(self.denoised_data)
reg, self.mask_data = SimpleNLF(rggb2bayer(lr_raw), self.denoised_npy, k=13,
setting={'mode':'collab', 'thr_mode':'score3', 'ransac': self.p['ransac']})
self.p['gain'] = reg[0] * self.p['scale']
self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
# Generate visualization
mask_img = self._visualize_mask()
log(f"噪声估计完成: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}")
gr.Success(f"噪声估计完成: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}", duration=2)
return mask_img, float(f"{self.p['gain']:.2f}"), float(f"{self.p['sigma']:.2f}")
except Exception as e:
gr.Error(f"噪声估计失败: {str(e)}")
raise RuntimeError(f"噪声估计失败: {str(e)}") from e
def enhance_image(self, gain, sigma, sigsnr, ddim_mode, patch_size):
"""执行图像增强"""
if not self.yond:
log('Please load the model first')
raise RuntimeError("Please load the model first")
try:
gr.Warning("正在增强图像...")
log('正在增强图像...')
# 更新处理参数
self.p['ddim_mode'] = ddim_mode
self.update_param('gain', gain)
self.update_param('sigma', sigma)
self.update_param('sigsnr', sigsnr)
# Preprocess data
processed = ((self.raw_data - self.p['bl']) / self.p['scale'])
lr_raw = bayer2rggb(processed) * self.p['ratio']
lr_raw_np = lr_raw * self.p['scale']
bias_base = np.maximum(lr_raw_np, 0)
bias = self.bias_lut.get_lut(bias_base, K=self.p['gain'], sigGs=self.p['sigma'])
raw_vst = VST(lr_raw_np, self.p['sigma'], gain=self.p['gain'])
raw_vst = raw_vst - bias
################# Forward VST #################
lower = VST(0, self.p['sigma'], gain=self.p['gain'])
upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
nsr = 1 / (upper - lower)
raw_vst = (raw_vst - lower) / (upper - lower)
################# Prepare for denoising #################
raw_dn = self.denoise(raw_vst, patch_size, nsr)
################# Inverse VST #################
raw_dn = raw_dn * (upper - lower) + lower
self.denoised_data = inverse_VST(raw_dn, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
self.denoised_npy = rggb2bayer(self.denoised_data)
# Render result
result = self._generate_result()
log("图像增强完成,请查看结果")
gr.Success("图像增强完成,请查看结果")
return result
except Exception as e:
gr.Error(f"图像增强失败: {str(e)}")
raise RuntimeError(f"增强失败: {str(e)}") from e
# Private helpers ------------------------------------------------------------
def _extract_color_params(self, raw):
"""从RAW文件中提取颜色参数"""
wb = np.array(raw.camera_whitebalance) / raw.camera_whitebalance[1]
ccm = raw.color_matrix[:3, :3].astype(np.float32)
return wb, ccm if ccm[0,0] != 0 else np.eye(3)
def _generate_preview(self):
"""生成预览图像"""
processed = (self.raw_data - self.p['bl']) / self.p['scale']
rgb = FastISP(bayer2rggb(processed)*self.p['ratio']*self.p['ispgain'],
self.p['wb'], self.p['ccm'])
rgb = (rgb.clip(0, 1) * 255).astype(np.uint8)
preview_img = Image.fromarray(rgb)
return preview_img
def _visualize_mask(self):
"""可视化噪声掩模"""
from matplotlib import pyplot as plt
# Check that the mask is a single-channel 2D array
if self.mask_data.ndim != 2:
gr.Error("Input mask must be a 2D array")
raise ValueError("Input mask must be a 2D array")
# Build a lookup table for the viridis colormap
cmap = plt.cm.viridis
x = np.linspace(0, 1, 256)
lut = (cmap(x)[:, :3] * 255).astype(np.uint8)
# Scale mask values to the 0-255 range and convert to integer indices
mask_indices = (np.clip(self.mask_data, 0, 1) * 255).astype(np.uint8)
# Vectorized mapping via fancy indexing
rgb_img = lut[mask_indices]
# Resize and convert to a PIL image
rgb_img = cv2.resize(rgb_img, (self.p['w'], self.p['h']), interpolation=cv2.INTER_LINEAR)
mask_img = Image.fromarray(rgb_img)
return mask_img
def _generate_result(self):
"""保存最终结果"""
rgb = FastISP(self.denoised_data*self.p['ispgain'],
self.p['wb'],
self.p['ccm'])
self.denoised_rgb = Image.fromarray((rgb.clip(0, 1) * 255).astype(np.uint8))
return self.denoised_rgb
def save_result_npy(self):
"""保存结果到 NPY 文件"""
if self.denoised_npy is None:
gr.Error("请先进行图像增强")
raise RuntimeError("请先进行图像增强")
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file:
tmp_file_path = tmp_file.name
np.save(tmp_file_path, self.denoised_npy.astype(np.float32))
return tmp_file_path
def save_result_png(self):
"""保存结果到 PNG 文件"""
if self.denoised_npy is None:
gr.Error("请先进行图像增强")
raise RuntimeError("请先进行图像增强")
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file_png:
tmp_file_path_png = tmp_file_png.name
cv2.imwrite(tmp_file_path_png, np.array(self.denoised_rgb)[:,:,::-1])
return tmp_file_path_png
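# Minimal stand-in for the command-line argument parser: it only carries the
# fields that YOND_anytest reads (run file path, mode, debug/nofig/nohost flags, gpu id).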
class YONDParser():
def __init__(self, yaml_path="runfiles/Gaussian/gru32n_paper_noclip.yml"):
self.runfile = yaml_path
self.mode = 'eval'
self.debug = False
self.nofig = False
self.nohost = False
self.gpu = 0
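# Lightweight evaluation harness: loads the YAML run file, resolves dataset
# paths for the current host, and builds the network from the arch config.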
class YOND_anytest():
def __init__(self, yaml_path, device):
# Initialization
self.device = device
self.parser = YONDParser(yaml_path)
self.initialization()
def initialization(self):
with open(self.parser.runfile, 'r', encoding="utf-8") as f:
self.args = yaml.load(f.read(), Loader=yaml.FullLoader)
self.mode = self.args['mode'] if self.parser.mode is None else self.parser.mode
if self.parser.debug is True:
self.args['num_workers'] = 0
warnings.warn('You are using debug mode, only main worker(cpu) is used!!!')
if 'clip' not in self.args['dst']:
self.args['dst']['clip'] = False
self.save_plot = False if self.parser.nofig else True
self.args['dst']['mode'] = self.mode
self.hostname, self.hostpath, self.multi_gpu = get_host_with_dir()
self.yond_dir = self.args['checkpoint']
if not self.parser.nohost:
for key in self.args:
if 'dst' in key:
self.args[key]['root_dir'] = f"{self.hostpath}/{self.args[key]['root_dir']}"
self.dst = self.args['dst']
self.arch = self.args['arch']
self.pipe = self.args['pipeline']
if self.pipe['bias_corr'] == 'none':
self.pipe['bias_corr'] = None
self.yond_name = self.args['model_name']
self.method_name = self.args['method_name']
self.fast_ckpt = self.args['fast_ckpt']
self.sample_dir = os.path.join(self.args['result_dir'], f"{self.method_name}")
os.makedirs(self.sample_dir, exist_ok=True)
os.makedirs('./logs', exist_ok=True)
#os.makedirs('./metrics', exist_ok=True)
def load_model(self, model_path):
# Build the network and load weights
self.net = globals()[self.arch['name']](self.arch)
model = torch.load(model_path, map_location='cpu')
self.net = load_weights(self.net, model, by_name=False)
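# A minimal, hypothetical usage sketch of the backend (assumptions: the run
# file below exists in this Space, a local RAW sample named 'sample.dng' is
# available, and the checkpoints referenced by the config are present; the
# real UI wiring lives in the Gradio app script). Parameter values such as
# sigsnr=0.8 and patch_size=1024 are illustrative only.
if __name__ == "__main__":
    backend = YOND_Backend()
    backend.load_config("runfiles/Gaussian/gru32n_paper_noclip.yml")
    # Load a RAW file and render a quick sRGB preview
    preview, h, w, bl, wp = backend.process_image("sample.dng", 2160, 3840, 64.0, 1023.0, ratio=1.0, ispgain=1.0)
    # Estimate the Poisson-Gaussian noise parameters (coarse, no RANSAC)
    mask_img, gain, sigma = backend.estimate_noise(double_est=False, ransac=False, patch_size=1024)
    # Denoise with the estimated parameters and save the rendered result
    result = backend.enhance_image(gain, sigma, sigsnr=0.8, ddim_mode=False, patch_size=1024)
    result.save("denoised_preview.png")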