import os
import gc
import sys
import time
import tempfile
import warnings

import numpy as np
import rawpy
from PIL import Image
import torch
import yaml
import gradio as gr
import spaces
import h5py
import scipy.io
import cv2
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from torch.optim import Adam, lr_scheduler

from data_process import *
from utils import *
from archs import *

# Add the dist directory to the Python search path
sys.path.append("./dist")
# from dist.isp_algos import *
from isp_algos import VST, inverse_VST, ddim, BiasLUT, SimpleNLF
from bm3d import bm3d


class YOND_Backend:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')  # Force CPU to avoid CUDA-related issues
        # Initialize processing parameters
        self.p = {
            "ratio": 1.0,
            "ispgain": 1.0,
            "h": 2160,
            "w": 3840,
            "bl": 64.0,
            "wp": 1023.0,
            "gain": 0.0,
            "sigma": 0.0,
            "wb": [2.0, 1.0, 2.0],
            "ccm": np.eye(3),
            "scale": 959.0,  # 1023 - 64
            "ransac": False,
            "ddim_mode": False,
        }
        # State variables
        self.raw_data = None
        self.denoised_data = None
        self.denoised_npy = None
        self.denoised_rgb = None
        self.mask_data = None
        self.yond = None
        self.bias_lut = None

    # New: release model resources
    def unload_model(self):
        if self.yond is not None:
            del self.yond
            self.yond = None
            self.bias_lut = None
            torch.cuda.empty_cache()  # Free the GPU cache
            print("Model has been unloaded, please reload the config")
            gr.Success("Model has been unloaded, please reload the config")

    # New: clear cached image data
    def clear_cache(self):
        # Drop temporary caches and intermediate results from processing
        self.raw_data = None
        self.denoised_data = None
        self.denoised_npy = None
        self.denoised_rgb = None
        self.mask_data = None
        gc.collect()
        print("Images have been cleared, please reload the images")
        gr.Success("Images have been cleared, please reload the images")

    def update_param(self, param, value):
        """Update a processing parameter."""
        try:
            if param in ['h', 'w']:
                self.p[param] = int(value)
            else:
                self.p[param] = float(value)
            # Refresh derived parameters automatically
            if param in ['wp', 'bl']:
                self.p['scale'] = self.p['wp'] - self.p['bl']
        except (ValueError, TypeError) as e:
            gr.Error(f"Failed to update parameter: {str(e)}")
            raise ValueError(f"Invalid parameter value: {value}") from e

    def load_config(self, config_path):
        """Load the config file and the model it points to."""
        try:
            self.yond = YOND_anytest(config_path, self.device)
            gr.Success(f"Config loaded: {config_path}", duration=2)
            gr.Success(f"Current device: {self.device}", duration=2)
        except Exception as e:
            gr.Error(f"Failed to load config: {str(e)}")
            raise RuntimeError(f"Failed to load config: {str(e)}")
        args = self.yond.args
        if 'pipeline' in args:
            self.p.update(args['pipeline'])
        else:
            self.p.update({'epoch': 10, 'sigma_t': 0.8, 'eta_t': 0.85})
        model_path = f"{self.yond.fast_ckpt}/{self.yond.yond_name}_last_model.pth"
        self.load_model(model_path)
        # return model_path

    def load_model(self, model_path):
        """Load the pretrained model and the bias LUT."""
        try:
            # Load model weights
            self.yond.load_model(model_path)
            self.bias_lut = BiasLUT(lut_path='checkpoints/bias_lut_2d.npy')
            if self.bias_lut is None:
                gr.Error(f"Failed to load BiasLUT: {os.path.exists('checkpoints/bias_lut_2d.npy')}")
            gr.Success(f"Model loaded: {model_path}", duration=2)
        except Exception as e:
            gr.Error(f"Failed to load model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def process_image(self, file_path, h, w, bl, wp, ratio, ispgain):
        """Load a raw image file and generate a preview."""
        try:
            gr.Warning("Visualizing image...")
            # Update processing parameters
            self.update_param('h', h)
            self.update_param('w', w)
            self.update_param('bl', bl)
            self.update_param('wp', wp)
            self.update_param('ratio', ratio)
            self.update_param('ispgain', ispgain)
            # Reset state
            self.raw_data = None
            self.denoised_data = None
            self.mask_data = None
            self.p.update({'wb': [2, 1, 2], 'ccm': np.eye(3)})
            if file_path.lower().endswith(('.arw', '.dng', '.nef', '.cr2')):
                with rawpy.imread(str(file_path)) as raw:
                    self.raw_data = raw.raw_image_visible.astype(np.uint16)
                    wb, ccm = self._extract_color_params(raw)
                    h, w = self.raw_data.shape
                    bl, wp = raw.black_level_per_channel[0], raw.white_level
                    scale = wp - bl
                    self.p.update({'wb': wb, 'ccm': ccm, 'h': h, 'w': w, 'bl': bl, 'wp': wp, 'scale': scale})
            elif file_path.lower().endswith(('.raw', '.npy')):
                try:
                    self.raw_data = np.fromfile(file_path, dtype=np.uint16)
                    self.raw_data = self.raw_data.reshape(self.p['h'], self.p['w'])
                except Exception as e:
                    gr.Warning(f"Reading with the default parameters failed: {e}, falling back to magic tricks")
                    info = rawread(file_path)
                    self.raw_data = info['raw']
                    self.p.update({
                        'h': info['h'], 'w': info['w'],
                        'bl': info['bl'], 'wp': info['wp'],
                        'scale': info['wp'] - info['bl']
                    })
                    gr.Success('Parameters updated based on magic tricks...', duration=2)
            # MATLAB format
            elif file_path.lower().endswith('.mat'):
                with h5py.File(file_path, 'r') as f:
                    self.raw_data = np.array(f['x']).astype(np.float32) * 959 + 64
                # Try to read the metadata
                meta_path = file_path.replace('NOISY', 'METADATA')
                if os.path.exists(meta_path):
                    self.meta = read_metadata(scipy.io.loadmat(meta_path))
                self.p.update({
                    'h': self.raw_data.shape[0],
                    'w': self.raw_data.shape[1],
                    'bl': 64, 'wp': 1023, 'scale': 959
                })
            else:
                gr.Error("Unsupported format")
                raise ValueError("Unsupported format")
            # Generate the preview image
            self.raw_data = self.raw_data.astype(np.float32)
            if self.p.get('clip', False):  # 'clip' may be absent from the config; default to no clipping
                self.raw_data = self.raw_data.clip(self.p['bl'], self.p['wp'])
            preview = self._generate_preview()
            return preview, self.p['h'], self.p['w'], self.p['bl'], self.p['wp']
        except Exception as e:
            gr.Error(f"Image visualization failed: {str(e)}")
            raise RuntimeError(f"Image processing failed: {str(e)}") from e

    def update_image(self, bl, wp, ratio, ispgain):
        """Update the preview with new image parameters."""
        try:
            log("Updating image parameters...")
            gr.Success("Updating image parameters...", duration=2)
            # Update processing parameters
            if ispgain != self.p['ispgain'] and (bl != self.p['bl'] and wp != self.p['wp'] and ratio != self.p['ratio']):
                update_image_flag = True  # flag currently unused
            self.update_param('bl', bl)
            self.update_param('wp', wp)
            self.update_param('ratio', ratio)
            self.update_param('ispgain', ispgain)
            # Reset derived state
            self.denoised_data = None
            self.mask_data = None
            if self.raw_data is not None:
                gr.Success("Rendering image preview...", duration=2)
                preview = self._generate_preview()
                return preview
            else:
                gr.Error("Please load an image first")
                raise RuntimeError("Please load an image first")
        except Exception as e:
            gr.Error(f"Failed to update image: {str(e)}")
            raise RuntimeError(f"Failed to update image: {str(e)}") from e

    @spaces.GPU
    def denoise(self, raw_vst, patch_size, nsr):
        ################# Prepare for denoising #################
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.yond.net = self.yond.net.to(self.device)
        raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2, 0, 1)[None,]
        if 'guided' in self.yond.arch:
            t = torch.tensor(nsr * self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device).view(-1, 1, 1, 1)
        # Denoise & pad
        target_size = patch_size  # Target patch size; adjust as needed (GPU: 1024)
        overlap_ratio = 1 / 8  # Overlap ratio between patches; adjust as needed
        # Use the improved big_image_split function
        raw_inp, metadata = big_image_split(raw_vst, target_size, overlap_ratio)
        raw_dn = torch.zeros_like(raw_inp[:, :4])
        with torch.no_grad():
            if self.p['ddim_mode']:
                for i in range(raw_inp.shape[0]):  # Process every patch
                    print(f'Patch: {i+1}/{len(raw_dn)}')
                    raw_dn[i] = ddim(raw_inp[i][None,].clip(None, 2), self.yond.net, t,
                                     epoch=self.p['epoch'], sigma_t=self.p['sigma_t'],
                                     eta=self.p['eta_t'], sigma_corr=1.00)
            else:
                for i in range(raw_inp.shape[0]):  # Process every patch
                    input_tensor = raw_inp[i][None,].clip(None, 2)
                    raw_dn[i] = self.yond.net(input_tensor, t).clamp(0, None)
        # Use the improved big_image_merge function
        raw_dn = big_image_merge(raw_dn, metadata, blend_mode='avg')
        ################# Inverse VST #################
        raw_dn = raw_dn[0].permute(1, 2, 0).detach().cpu().numpy()
        return raw_dn

    def estimate_noise(self, double_est, ransac, patch_size):
        """Run noise estimation."""
        if not self.yond:
            gr.Error("Please load the model first")
            raise RuntimeError("Please load the model first")
        try:
            gr.Warning("Estimating noise...")
            log('Start estimating noise')
            self.p['ransac'] = ransac
            # Preprocess the data
            processed = (self.raw_data - self.p['bl']) / self.p['scale']
            lr_raw = bayer2rggb(processed) * self.p['ratio']
            # Coarse estimation
            reg, self.mask_data = SimpleNLF(
                rggb2bayer(lr_raw), k=19, eps=1e-3,
                setting={'mode': 'self', 'thr_mode': 'score2', 'ransac': self.p['ransac']}
            )
            self.p['gain'] = reg[0] * self.p['scale']
            self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
            if double_est:
                log("  Using refined estimation")
                if self.denoised_npy is None:
                    log("  No previous denoising result, denoise first and then estimate")
                    lr_raw_np = lr_raw * self.p['scale']
                    ######## EM-VST: correct the expected bias of the VST-domain noisy image ########
                    bias_base = np.maximum(lr_raw_np, 0)
                    bias = self.bias_lut.get_lut(bias_base, K=self.p['gain'], sigGs=self.p['sigma'])
                    raw_vst = VST(lr_raw_np, self.p['sigma'], gain=self.p['gain'])
                    raw_vst = raw_vst - bias
                    ################# VST transform #################
                    lower = VST(0, self.p['sigma'], gain=self.p['gain'])
                    upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
                    nsr = 1 / (upper - lower)
                    raw_vst = (raw_vst - lower) / (upper - lower)
                    ################# Denoise #################
                    raw_dn = self.denoise(raw_vst, patch_size, nsr)
                    ################# Inverse VST #################
                    raw_dn = raw_dn * (upper - lower) + lower
                    self.denoised_data = inverse_VST(raw_dn, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
                    self.denoised_npy = rggb2bayer(self.denoised_data)
                # Refined estimation guided by the denoised image
                reg, self.mask_data = SimpleNLF(
                    rggb2bayer(lr_raw), self.denoised_npy, k=13,
                    setting={'mode': 'collab', 'thr_mode': 'score3', 'ransac': self.p['ransac']}
                )
                self.p['gain'] = reg[0] * self.p['scale']
                self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
            # Visualize the estimation mask
            mask_img = self._visualize_mask()
            log(f"Noise estimation finished: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}")
            gr.Success(f"Noise estimation finished: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}", duration=2)
            return mask_img, float(f"{self.p['gain']:.2f}"), float(f"{self.p['sigma']:.2f}")
        except Exception as e:
            gr.Error(f"Noise estimation failed: {str(e)}")
            raise RuntimeError(f"Noise estimation failed: {str(e)}") from e

    def enhance_image(self, gain, sigma, sigsnr, ddim_mode, patch_size):
        """Run image enhancement (denoising + rendering)."""
        if not self.yond:
            log('Please load the model first')
            raise RuntimeError("Please load the model first")
        try:
            gr.Warning("Enhancing image...")
            log('Enhancing image...')
            # Update processing parameters
            self.p['ddim_mode'] = ddim_mode
            self.update_param('gain', gain)
            self.update_param('sigma', sigma)
            self.update_param('sigsnr', sigsnr)
            # Preprocess the data
            processed = ((self.raw_data - self.p['bl']) / self.p['scale'])
            lr_raw = bayer2rggb(processed) * self.p['ratio']
            lr_raw_np = lr_raw * self.p['scale']
            bias_base = np.maximum(lr_raw_np, 0)
            bias = self.bias_lut.get_lut(bias_base, K=self.p['gain'], sigGs=self.p['sigma'])
            raw_vst = VST(lr_raw_np, self.p['sigma'], gain=self.p['gain'])
            raw_vst = raw_vst - bias
            ################# VST transform #################
            lower = VST(0, self.p['sigma'], gain=self.p['gain'])
            upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
            nsr = 1 / (upper - lower)
            raw_vst = (raw_vst - lower) / (upper - lower)
            ################# Prepare for denoising #################
            raw_dn = self.denoise(raw_vst, patch_size, nsr)
            ################# Inverse VST #################
            raw_dn = raw_dn * (upper - lower) + lower
            self.denoised_data = inverse_VST(raw_dn, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
            self.denoised_npy = rggb2bayer(self.denoised_data)
            # Render the result
            result = self._generate_result()
            log("Image enhancement finished, please check the result")
            gr.Success("Image enhancement finished, please check the result")
            return result
        except Exception as e:
            gr.Error(f"Image enhancement failed: {str(e)}")
            raise RuntimeError(f"Enhancement failed: {str(e)}") from e

    # Private helper methods -----------------------------------------------------
    def _extract_color_params(self, raw):
        """Extract white balance and CCM from the RAW file."""
        wb = np.array(raw.camera_whitebalance) / raw.camera_whitebalance[1]
        ccm = raw.color_matrix[:3, :3].astype(np.float32)
        return wb, ccm if ccm[0, 0] != 0 else np.eye(3)

    def _generate_preview(self):
        """Generate a preview image."""
        processed = (self.raw_data - self.p['bl']) / self.p['scale']
        rgb = FastISP(bayer2rggb(processed) * self.p['ratio'] * self.p['ispgain'],
                      self.p['wb'], self.p['ccm'])
        rgb = (rgb.clip(0, 1) * 255).astype(np.uint8)
        preview_img = Image.fromarray(rgb)
        return preview_img

    def _visualize_mask(self):
        """Visualize the noise-estimation mask."""
        from matplotlib import pyplot as plt
        # The mask must be single channel
        if self.mask_data.ndim != 2:
            gr.Error("Input mask must be a 2D array")
            raise ValueError("Input mask must be a 2D array")
        # Build a lookup table from the viridis colormap
        cmap = plt.cm.viridis
        x = np.linspace(0, 1, 256)
        lut = (cmap(x)[:, :3] * 255).astype(np.uint8)
        # Scale mask values to 0-255 and convert to integer indices
        mask_indices = (np.clip(self.mask_data, 0, 1) * 255).astype(np.uint8)
        # Vectorized mapping via advanced indexing
        rgb_img = lut[mask_indices]
        # Resize and convert to a PIL image
        rgb_img = cv2.resize(rgb_img, (self.p['w'], self.p['h']), interpolation=cv2.INTER_LINEAR)
        mask_img = Image.fromarray(rgb_img)
        return mask_img

    def _generate_result(self):
        """Render the final RGB result."""
        rgb = FastISP(self.denoised_data * self.p['ispgain'], self.p['wb'], self.p['ccm'])
        self.denoised_rgb = Image.fromarray((rgb.clip(0, 1) * 255).astype(np.uint8))
        return self.denoised_rgb

    def save_result_npy(self):
        """Save the result to an NPY file."""
        if self.denoised_npy is None:
            gr.Error("Please run image enhancement first")
            raise RuntimeError("Please run image enhancement first")
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file:
            tmp_file_path = tmp_file.name
        np.save(tmp_file_path, self.denoised_npy.astype(np.float32))
        return tmp_file_path

    def save_result_png(self):
        """Save the result to a PNG file."""
        if self.denoised_npy is None:
            gr.Error("Please run image enhancement first")
            raise RuntimeError("Please run image enhancement first")
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file_png:
            tmp_file_path_png = tmp_file_png.name
        cv2.imwrite(tmp_file_path_png, np.array(self.denoised_rgb)[:, :, ::-1])
        return tmp_file_path_png
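
# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original app): a minimal sketch of the
# VST-domain normalization that estimate_noise()/enhance_image() perform above,
# assuming VST/inverse_VST behave as they are called there. The demo name and
# all numeric values are placeholders for illustration only.
# ---------------------------------------------------------------------------
def _vst_normalization_demo(scale=959.0, gain=4.0, sigma=2.0):
    """Round-trip a toy raw value through the normalized VST domain."""
    x = np.array([100.0])                          # toy raw value in [0, scale]
    lower = VST(0, sigma, gain=gain)               # VST of the black level
    upper = VST(scale, sigma, gain=gain)           # VST of the white level
    x_vst = (VST(x, sigma, gain=gain) - lower) / (upper - lower)  # map to ~[0, 1]
    # ... in the app, the denoising network operates on x_vst at this point ...
    x_rec = inverse_VST(x_vst * (upper - lower) + lower, sigma, gain=gain)
    return x_vst, x_rec
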
class YONDParser():
    def __init__(self, yaml_path="runfiles/Gaussian/gru32n_paper_noclip.yml"):
        self.runfile = yaml_path
        self.mode = 'eval'
        self.debug = False
        self.nofig = False
        self.nohost = False
        self.gpu = 0


class YOND_anytest():
    def __init__(self, yaml_path, device):
        # Initialization
        self.device = device
        self.parser = YONDParser(yaml_path)
        self.initialization()

    def initialization(self):
        with open(self.parser.runfile, 'r', encoding="utf-8") as f:
            self.args = yaml.load(f.read(), Loader=yaml.FullLoader)
        self.mode = self.args['mode'] if self.parser.mode is None else self.parser.mode
        if self.parser.debug is True:
            self.args['num_workers'] = 0
            warnings.warn('You are using debug mode, only main worker(cpu) is used!!!')
        if 'clip' not in self.args['dst']:
            self.args['dst']['clip'] = False
        self.save_plot = False if self.parser.nofig else True
        self.args['dst']['mode'] = self.mode
        self.hostname, self.hostpath, self.multi_gpu = get_host_with_dir()
        self.yond_dir = self.args['checkpoint']
        if not self.parser.nohost:
            for key in self.args:
                if 'dst' in key:
                    self.args[key]['root_dir'] = f"{self.hostpath}/{self.args[key]['root_dir']}"
        self.dst = self.args['dst']
        self.arch = self.args['arch']
        self.pipe = self.args['pipeline']
        if self.pipe['bias_corr'] == 'none':
            self.pipe['bias_corr'] = None
        self.yond_name = self.args['model_name']
        self.method_name = self.args['method_name']
        self.fast_ckpt = self.args['fast_ckpt']
        self.sample_dir = os.path.join(self.args['result_dir'], f"{self.method_name}")
        os.makedirs(self.sample_dir, exist_ok=True)
        os.makedirs('./logs', exist_ok=True)
        # os.makedirs('./metrics', exist_ok=True)

    def load_model(self, model_path):
        # Build the network from the arch config and load the checkpoint weights
        self.net = globals()[self.arch['name']](self.arch)
        model = torch.load(model_path, map_location='cpu')
        self.net = load_weights(self.net, model, by_name=False)
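
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original app): a minimal, non-Gradio
# driver showing how the YOND_Backend methods chain together. The config path
# is YONDParser's default; the input file, patch size and sigsnr value are
# placeholders, and running this requires the repo's checkpoints (a GPU is
# strongly recommended).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    backend = YOND_Backend()
    # Load the run config; this also loads the network weights and the bias LUT.
    backend.load_config("runfiles/Gaussian/gru32n_paper_noclip.yml")

    # Read a raw file and build a preview (h/w/bl/wp are only needed for .raw/.npy inputs).
    preview, h, w, bl, wp = backend.process_image(
        "example.dng",  # placeholder input path
        h=2160, w=3840, bl=64, wp=1023, ratio=1.0, ispgain=1.0)

    # Estimate the noise parameters (gain, sigma) from the noisy raw image.
    mask_img, gain, sigma = backend.estimate_noise(double_est=True, ransac=False, patch_size=512)

    # Denoise with the estimated parameters and save the rendered result.
    result = backend.enhance_image(gain, sigma, sigsnr=1.0, ddim_mode=False, patch_size=512)
    print("PNG saved to:", backend.save_result_png())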