Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import rawpy | |
| from PIL import Image | |
| import torch | |
| import yaml | |
| import gradio as gr | |
| import tempfile | |
| import spaces | |
| # os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
| import time | |
| from torch.optim import Adam, lr_scheduler | |
| from data_process import * | |
| from utils import * | |
| from archs import * | |
| import sys | |
| # 将 dist 目录添加到 Python 搜索路径 | |
| sys.path.append("./dist") | |
| # from dist.isp_algos import * | |
| from isp_algos import VST, inverse_VST, ddim, BiasLUT, SimpleNLF | |
| from bm3d import bm3d | |
| class YOND_Backend: | |
| def __init__(self): | |
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| #self.device = torch.device('cpu') # 强制使用CPU,避免CUDA相关问题 | |
| # 初始化处理参数 | |
| self.p = { | |
| "ratio": 1.0, | |
| "ispgain": 1.0, | |
| "h": 2160, | |
| "w": 3840, | |
| "bl": 64.0, | |
| "wp": 1023.0, | |
| "gain": 0.0, | |
| "sigma": 0.0, | |
| "wb": [2.0, 1.0, 2.0], | |
| "ccm": np.eye(3), | |
| "scale": 959.0, # 1023-64 | |
| "ransac": False, | |
| "ddim_mode": False, | |
| } | |
| # 状态变量 | |
| self.raw_data = None | |
| self.denoised_data = None | |
| self.denoised_npy = None | |
| self.denoised_rgb = None | |
| self.mask_data = None | |
| self.yond = None | |
| self.bias_lut = None | |
| # 新增:释放模型资源的方法 | |
def unload_model(self):
    """Release the loaded network and bias LUT, then free cached GPU memory.

    Does nothing when no model is loaded.
    """
    if self.yond is None:
        return
    # NOTE(review): indentation is lost in the source listing; this assumes
    # every statement of the original sat inside the `if` - confirm.
    del self.yond
    self.yond = None
    self.bias_lut = None
    torch.cuda.empty_cache()  # ask the CUDA allocator to release cached blocks
    notice = "Model has unloaded, please reload config"
    print(notice)
    gr.Success(notice)
| # 新增:清理缓存的方法 | |
def clear_cache(self):
    """Drop every loaded or computed image buffer and run the garbage collector.

    Only image state is reset; the loaded model (``self.yond``) and the bias
    LUT are left untouched.
    """
    # Imported locally because this file never imports gc at module level;
    # the call otherwise only works if a star import happens to leak it.
    import gc

    # Forget the raw input and every intermediate or final result.
    self.raw_data = None
    self.denoised_data = None
    self.denoised_npy = None
    self.denoised_rgb = None
    self.mask_data = None
    gc.collect()
    print("Images has clear, please reload images")
    gr.Success("Images has clear, please reload images")
def update_param(self, param, value):
    """Store ``value`` under ``self.p[param]`` after numeric conversion.

    'h' and 'w' are stored as ``int``, every other parameter as ``float``.
    Changing 'wp' or 'bl' also refreshes the derived ``scale = wp - bl``.

    Raises:
        ValueError: if ``value`` cannot be converted to the target type.
    """
    convert = int if param in ('h', 'w') else float
    try:
        self.p[param] = convert(value)
        # Keep the derived dynamic range in sync with its two inputs.
        if param in ('wp', 'bl'):
            self.p['scale'] = self.p['wp'] - self.p['bl']
    except (ValueError, TypeError) as e:
        gr.Error(f"参数更新失败: {str(e)}")
        raise ValueError(f"无效的参数值: {value}") from e
def load_config(self, config_path):
    """Build the YOND runner from a YAML config and load its checkpoint.

    The config's ``pipeline`` section (if any) is merged into ``self.p`` on
    top of the DDIM defaults, then the checkpoint
    ``<fast_ckpt>/<model_name>_last_model.pth`` is loaded via ``load_model``.

    Raises:
        RuntimeError: if the runner cannot be constructed from the config,
            or (from ``load_model``) if the weights cannot be loaded.
    """
    try:
        self.yond = YOND_anytest(config_path, self.device)
        gr.Success(f"配置加载成功: {config_path}", duration=2)
        gr.Success(f"当前设备: {self.device}", duration=2)
    except Exception as e:
        gr.Error(f"配置加载失败: {str(e)}")
        # Chain the cause so the original traceback is kept, as every other
        # handler in this class already does.
        raise RuntimeError(f"配置加载失败: {str(e)}") from e
    args = self.yond.args
    # Previously the DDIM defaults were applied only when the whole
    # 'pipeline' section was missing; a pipeline that omitted one of them
    # left denoise() to fail with KeyError. Defaults first, config on top.
    ddim_defaults = {'epoch': 10, 'sigma_t': 0.8, 'eta_t': 0.85}
    pipeline = args['pipeline'] if 'pipeline' in args else {}
    self.p.update({**ddim_defaults, **pipeline})
    model_path = f"{self.yond.fast_ckpt}/{self.yond.yond_name}_last_model.pth"
    self.load_model(model_path)
def load_model(self, model_path):
    """Load pretrained weights into the runner and create the bias LUT.

    Raises:
        RuntimeError: wrapping any exception raised while loading.
    """
    lut_file = 'checkpoints/bias_lut_2d.npy'
    try:
        # Network weights first, then the bias-correction lookup table.
        self.yond.load_model(model_path)
        self.bias_lut = BiasLUT(lut_path=lut_file)
        lut_missing = self.bias_lut is None
        if lut_missing:
            # NOTE(review): gr.Error is built but not raised here, so this
            # branch has no visible effect - confirm the intent.
            gr.Error(f"BiasLUT加载失败: {os.path.exists(lut_file)}")
        gr.Success(f"模型加载成功: {model_path}", duration=2)
    except Exception as err:
        failure = f"模型加载失败: {str(err)}"
        gr.Error(failure)
        raise RuntimeError(failure) from err
def process_image(self, file_path, h, w, bl, wp, ratio, ispgain):
    """Load a raw image file and return an RGB preview plus its geometry.

    Supported inputs, selected by file extension (case-insensitive):

    * ``.arw/.dng/.nef/.cr2`` - decoded with rawpy; white balance, colour
      matrix, size, black level and white level come from the file.
    * ``.raw/.npy`` - read as a flat uint16 buffer and reshaped to the
      current ``h`` x ``w``; on failure ``rawread`` is tried instead and its
      reported geometry and levels replace the current ones.
    * ``.mat`` - dataset ``'x'`` read via h5py and rescaled with
      ``* 959 + 64``; a sibling file with 'NOISY' replaced by 'METADATA' is
      parsed into ``self.meta`` when it exists.

    Returns:
        ``(preview, h, w, bl, wp)`` where ``preview`` is the image produced
        by ``_generate_preview`` and the rest are the effective parameters.

    Raises:
        RuntimeError: wrapping any failure, including an unsupported format.
    """
    try:
        gr.Warning("正在可视化图像")
        # Take over the values currently entered in the UI.
        self.update_param('h', h)
        self.update_param('w', w)
        self.update_param('bl', bl)
        self.update_param('wp', wp)
        self.update_param('ratio', ratio)
        self.update_param('ispgain', ispgain)
        # Reset everything that belongs to the previously loaded image.
        self.raw_data = None
        self.denoised_data = None
        self.mask_data = None
        self.p.update({'wb': [2, 1, 2], 'ccm': np.eye(3)})
        # Lower-case once instead of once per branch.
        lowered = file_path.lower()
        if lowered.endswith(('.arw', '.dng', '.nef', '.cr2')):
            with rawpy.imread(str(file_path)) as raw:
                # astype() copies, so the data outlives the rawpy handle.
                self.raw_data = raw.raw_image_visible.astype(np.uint16)
                wb, ccm = self._extract_color_params(raw)
                h, w = self.raw_data.shape
                bl, wp = raw.black_level_per_channel[0], raw.white_level
            scale = wp - bl
            self.p.update({'wb': wb, 'ccm': ccm, 'h': h, 'w': w,
                           'bl': bl, 'wp': wp, 'scale': scale})
        elif lowered.endswith(('.raw', '.npy')):
            try:
                # NOTE(review): np.fromfile reads the bytes verbatim, so a
                # real .npy file would include its header here - presumably
                # such files are expected to reach the fallback; confirm.
                self.raw_data = np.fromfile(file_path, dtype=np.uint16)
                self.raw_data = self.raw_data.reshape(
                    self.p['h'], self.p['w']
                )
            except Exception as e:
                gr.Warning(f"默认参数读取失败: {e}, 尝试使用魔↑术↓技↑巧↓")
                info = rawread(file_path)
                self.raw_data = info['raw']
                self.p.update({
                    'h': info['h'], 'w': info['w'],
                    'bl': info['bl'], 'wp': info['wp'],
                    'scale': info['wp'] - info['bl'],
                })
                gr.Success('基于 魔↑术↓技↑巧↓,参数已更新...', duration=2)
        elif lowered.endswith('.mat'):
            # MATLAB (HDF5-based) input.
            with h5py.File(file_path, 'r') as f:
                self.raw_data = np.array(f['x']).astype(np.float32) * 959 + 64
            # Optional metadata stored next to the noisy image.
            meta_path = file_path.replace('NOISY', 'METADATA')
            if os.path.exists(meta_path):
                self.meta = read_metadata(scipy.io.loadmat(meta_path))
            self.p.update({
                'h': self.raw_data.shape[0], 'w': self.raw_data.shape[1],
                'bl': 64, 'wp': 1023, 'scale': 959,
            })
        else:
            gr.Error("不支持的格式")
            raise ValueError("不支持的格式")
        # Build the preview from float data.
        self.raw_data = self.raw_data.astype(np.float32)
        # 'clip' is only present when a loaded config supplied it; a missing
        # key used to raise KeyError here, so treat it as "do not clip".
        if self.p.get('clip', False):
            self.raw_data = self.raw_data.clip(self.p['bl'], self.p['wp'])
        preview = self._generate_preview()
        return preview, self.p['h'], self.p['w'], self.p['bl'], self.p['wp']
    except Exception as e:
        gr.Error(f"图像可视化失败: {str(e)}")
        raise RuntimeError(f"图像处理失败: {str(e)}") from e
def update_image(self, bl, wp, ratio, ispgain):
    """Apply new level/gain settings and re-render the preview.

    Stores ``bl``, ``wp``, ``ratio`` and ``ispgain`` via ``update_param``,
    discards any previous denoising result and noise mask, and returns a
    fresh preview of the already loaded raw image.

    Raises:
        RuntimeError: if no image has been loaded yet, or wrapping any other
            failure while updating.
    """
    try:
        log("更新图像参数...")
        gr.Success("更新图像参数...", duration=2)
        # The original computed a local `update_image_flag` from a
        # comparison of old and new values but never read it; that dead
        # branch is removed and the parameters are stored unconditionally.
        self.update_param('bl', bl)
        self.update_param('wp', wp)
        self.update_param('ratio', ratio)
        self.update_param('ispgain', ispgain)
        # Results computed with the old parameters are no longer valid.
        self.denoised_data = None
        self.mask_data = None
        if self.raw_data is not None:
            gr.Success("图像可视化中...", duration=2)
            preview = self._generate_preview()
            return preview
        else:
            gr.Error("请先加载图像")
            raise RuntimeError("请先加载图像")
    except Exception as e:
        gr.Error(f"图像更新失败: {str(e)}")
        raise RuntimeError(f"图像更新失败: {str(e)}") from e
def denoise(self, raw_vst, patch_size, nsr):
    """Run the network patch-wise over a VST-domain image and merge the result.

    ``raw_vst`` is a 3-D NumPy array; its axes are reordered with
    ``permute(2, 0, 1)`` and a leading batch axis is added before tiling.
    ``patch_size`` is the tile size handed to ``big_image_split`` and ``nsr``
    is multiplied by ``self.p['sigsnr']`` to form the conditioning tensor.

    Returns:
        NumPy array taken from the first merged batch element with the
        channel axis moved back to the end.
    """
    # Re-detect the device on every call and move the network onto it.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.yond.net = self.yond.net.to(self.device)
    raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2, 0, 1)[None,]
    # NOTE(review): `t` is bound only when 'guided' is in self.yond.arch but
    # is used unconditionally below; any other arch would hit
    # UnboundLocalError - confirm every supported arch is guided.
    if 'guided' in self.yond.arch:
        t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device).view(-1, 1, 1, 1)
    # Tile the image into overlapping patches (1/8 overlap).
    tile_size = patch_size
    tile_overlap = 1/8
    raw_inp, metadata = big_image_split(raw_vst, tile_size, tile_overlap)
    # Output buffer keeps the first four channels of each patch.
    raw_dn = torch.zeros_like(raw_inp[:, :4])
    with torch.no_grad():
        use_ddim = self.p['ddim_mode']
        for idx in range(raw_inp.shape[0]):
            if use_ddim:
                print(f'Patch: {idx+1}/{len(raw_dn)}')
                raw_dn[idx] = ddim(raw_inp[idx][None,].clip(None, 2), self.yond.net, t,
                                   epoch=self.p['epoch'], sigma_t=self.p['sigma_t'],
                                   eta=self.p['eta_t'], sigma_corr=1.00)
            else:
                patch = raw_inp[idx][None,].clip(None, 2)
                raw_dn[idx] = self.yond.net(patch, t).clamp(0, None)
    # Stitch the patches back together, averaging the overlaps.
    raw_dn = big_image_merge(raw_dn, metadata, blend_mode='avg')
    # Back to a channel-last NumPy array on the CPU.
    return raw_dn[0].permute(1, 2, 0).detach().cpu().numpy()
def estimate_noise(self, double_est, ransac, patch_size):
    """Estimate the noise parameters of the loaded raw image.

    A coarse estimate is always computed with ``SimpleNLF`` in 'self' mode.
    When ``double_est`` is truthy, a refined estimate follows in 'collab'
    mode, using a denoised image as reference; that image is computed here
    (VST -> network -> inverse VST) if ``self.denoised_npy`` is still None.

    Args:
        double_est: run the refined second estimation pass.
        ransac: stored in ``self.p['ransac']`` and forwarded to SimpleNLF.
        patch_size: tile size forwarded to ``denoise``.

    Returns:
        ``(mask_img, gain, sigma)`` - the rendered noise mask and the two
        estimates rounded to two decimals.

    Raises:
        RuntimeError: if no model is loaded, or wrapping any failure during
            estimation.
    """
    if not self.yond:
        # NOTE(review): gr.Error is constructed but not raised; only the
        # RuntimeError below propagates.
        gr.Error("请先加载模型")
        raise RuntimeError("请先加载模型")
    try:
        gr.Warning("正在估计噪声...")
        log('开始估计噪声')
        self.p['ransac'] = ransac
        # Pre-process: subtract black level, normalise by scale, apply ratio.
        processed = (self.raw_data - self.p['bl']) / self.p['scale']
        lr_raw = bayer2rggb(processed) * self.p['ratio']
        # Coarse estimate from the noisy image alone.
        reg, self.mask_data = SimpleNLF(
            rggb2bayer(lr_raw),
            k=19,
            eps=1e-3,
            setting={'mode': 'self', 'thr_mode':'score2', 'ransac': self.p['ransac']}
        )
        # reg[0] feeds gain; reg[1] is clamped at 0 before the square root,
        # so presumably it is a variance-like term - confirm with SimpleNLF.
        self.p['gain'] = reg[0] * self.p['scale']
        self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
        if double_est:
            log(" 使用精估计")
            if self.denoised_npy is None:
                # No denoised reference yet: denoise first, then estimate.
                log(" 之前没去噪,先去噪再估计")
                lr_raw_np = lr_raw * self.p['scale']
                # Subtract the LUT-provided bias from the VST output.
                bias_base = np.maximum(lr_raw_np, 0)
                bias = self.bias_lut.get_lut(bias_base, K=self.p['gain'], sigGs=self.p['sigma'])
                raw_vst = VST(lr_raw_np, self.p['sigma'], gain=self.p['gain'])
                raw_vst = raw_vst - bias
                # Normalise the VST output by the values at 0 and at scale.
                lower = VST(0, self.p['sigma'], gain=self.p['gain'])
                upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
                nsr = 1 / (upper - lower)
                raw_vst = (raw_vst - lower) / (upper - lower)
                # Denoise in the normalised VST domain.
                raw_dn = self.denoise(raw_vst, patch_size, nsr)
                # Undo the normalisation, then the VST itself.
                raw_dn = raw_dn * (upper - lower) + lower
                self.denoised_data = inverse_VST(raw_dn, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
                self.denoised_npy = rggb2bayer(self.denoised_data)
            # Refined estimate using the denoised image as reference.
            reg, self.mask_data = SimpleNLF(rggb2bayer(lr_raw), self.denoised_npy, k=13,
                                            setting={'mode':'collab', 'thr_mode':'score3', 'ransac': self.p['ransac']})
            self.p['gain'] = reg[0] * self.p['scale']
            self.p['sigma'] = np.sqrt(max(reg[1], 0)) * self.p['scale']
        # Render the mask for display.
        mask_img = self._visualize_mask()
        log(f"噪声估计完成: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}")
        gr.Success(f"噪声估计完成: gain={self.p['gain']:.2f}, sigma={self.p['sigma']:.2f}", duration=2)
        return mask_img, float(f"{self.p['gain']:.2f}"), float(f"{self.p['sigma']:.2f}")
    except Exception as e:
        gr.Error(f"噪声估计失败: {str(e)}")
        raise RuntimeError(f"噪声估计失败: {str(e)}") from e
def enhance_image(self, gain, sigma, sigsnr, ddim_mode, patch_size):
    """Denoise the loaded raw image and return the rendered RGB result.

    ``gain``, ``sigma`` and ``sigsnr`` are stored through ``update_param``;
    ``ddim_mode`` selects the sampling path inside ``denoise`` and
    ``patch_size`` is its tile size. The denoised data is kept in
    ``self.denoised_data`` and ``self.denoised_npy``.

    Raises:
        RuntimeError: if no model is loaded, or wrapping any failure during
            enhancement.
    """
    if not self.yond:
        log('请先加载模型')
        raise RuntimeError("请先加载模型")
    try:
        gr.Warning("正在增强图像...")
        log('正在增强图像...')
        # Take over the UI settings.
        self.p['ddim_mode'] = ddim_mode
        for key, val in (('gain', gain), ('sigma', sigma), ('sigsnr', sigsnr)):
            self.update_param(key, val)
        # Black-level subtraction, normalisation and exposure ratio.
        normalised = (self.raw_data - self.p['bl']) / self.p['scale']
        packed = bayer2rggb(normalised) * self.p['ratio']
        packed_dn = packed * self.p['scale']
        # Bias looked up on the non-negative part, removed after the VST.
        lut_bias = self.bias_lut.get_lut(np.maximum(packed_dn, 0),
                                         K=self.p['gain'], sigGs=self.p['sigma'])
        vst_img = VST(packed_dn, self.p['sigma'], gain=self.p['gain']) - lut_bias
        # Normalise by the VST values at 0 and at full scale.
        lower = VST(0, self.p['sigma'], gain=self.p['gain'])
        upper = VST(self.p['scale'], self.p['sigma'], gain=self.p['gain'])
        nsr = 1 / (upper - lower)
        vst_img = (vst_img - lower) / (upper - lower)
        # Network inference in the normalised VST domain.
        restored = self.denoise(vst_img, patch_size, nsr)
        # Undo the normalisation and the VST.
        restored = restored * (upper - lower) + lower
        self.denoised_data = inverse_VST(restored, self.p['sigma'], gain=self.p['gain']) / self.p['scale']
        self.denoised_npy = rggb2bayer(self.denoised_data)
        # Render the result for display.
        rendered = self._generate_result()
        log("图像增强完成,请查看结果")
        gr.Success("图像增强完成,请查看结果")
        return rendered
    except Exception as e:
        gr.Error(f"图像增强失败: {str(e)}")
        raise RuntimeError(f"增强失败: {str(e)}") from e
| # 私有工具方法 ------------------------------------------------------------ | |
| def _extract_color_params(self, raw): | |
| """从RAW文件中提取颜色参数""" | |
| wb = np.array(raw.camera_whitebalance) / raw.camera_whitebalance[1] | |
| ccm = raw.color_matrix[:3, :3].astype(np.float32) | |
| return wb, ccm if ccm[0,0] != 0 else np.eye(3) | |
def _generate_preview(self):
    """Render the loaded raw image to an 8-bit RGB PIL image."""
    params = self.p
    # Subtract black level and normalise by the dynamic range.
    normalised = (self.raw_data - params['bl']) / params['scale']
    # Apply exposure ratio and display gain, then run the quick ISP.
    isp_input = bayer2rggb(normalised)*params['ratio']*params['ispgain']
    rgb = FastISP(isp_input, params['wb'], params['ccm'])
    rgb_8bit = (rgb.clip(0, 1) * 255).astype(np.uint8)
    return Image.fromarray(rgb_8bit)
def _visualize_mask(self):
    """Colour-map ``self.mask_data`` with viridis and resize it to ``w`` x ``h``.

    Raises:
        ValueError: if the mask is not a 2-D array.
    """
    from matplotlib import pyplot as plt
    mask = self.mask_data
    # Only single-channel masks can be colour-mapped.
    if mask.ndim != 2:
        gr.Error("Input mask must be a 2D array")
        raise ValueError("Input mask must be a 2D array")
    # 256-entry RGB lookup table sampled from the viridis colormap.
    samples = np.linspace(0, 1, 256)
    palette = (plt.cm.viridis(samples)[:, :3] * 255).astype(np.uint8)
    # Clamp to [0, 1], scale to 0-255 and use the values as table indices.
    indices = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
    coloured = palette[indices]
    # Resize to the image size stored in the parameters.
    target_size = (self.p['w'], self.p['h'])
    coloured = cv2.resize(coloured, target_size, interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(coloured)
def _generate_result(self):
    """Render the denoised data to RGB, cache it in ``self.denoised_rgb`` and return it."""
    params = self.p
    gained = self.denoised_data*params['ispgain']
    rgb = FastISP(gained, params['wb'], params['ccm'])
    # Clamp to [0, 1] and quantise to 8 bits for display / export.
    rgb_8bit = (rgb.clip(0, 1) * 255).astype(np.uint8)
    self.denoised_rgb = Image.fromarray(rgb_8bit)
    return self.denoised_rgb
def save_result_npy(self):
    """Write the denoised array to a temporary ``.npy`` file and return its path.

    The array is stored as float32. The file is not deleted automatically.

    Raises:
        RuntimeError: if no denoised result is available.
    """
    result = self.denoised_npy
    if result is None:
        gr.Error("请先进行图像增强")
        raise RuntimeError("请先进行图像增强")
    # delete=False keeps the file on disk after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as handle:
        out_path = handle.name
        np.save(out_path, result.astype(np.float32))
    return out_path
def save_result_png(self):
    """Write the rendered RGB result to a temporary ``.png`` file and return its path.

    The file is not deleted automatically.

    Raises:
        RuntimeError: if no rendered result is available.
    """
    # Guard on the buffer that is actually written. The previous check
    # looked at denoised_npy, which estimate_noise() can populate on its own
    # while denoised_rgb is still None, so the export crashed on indexing.
    if self.denoised_rgb is None:
        gr.Error("请先进行图像增强")
        raise RuntimeError("请先进行图像增强")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file_png:
        tmp_file_path_png = tmp_file_png.name
    # Reverse the channel axis before writing with OpenCV.
    cv2.imwrite(tmp_file_path_png, np.array(self.denoised_rgb)[:, :, ::-1])
    return tmp_file_path_png
class YONDParser:
    """Fixed stand-in for command-line options: only the YAML path varies."""

    def __init__(self, yaml_path="runfiles/Gaussian/gru32n_paper_noclip.yml"):
        # Path of the run configuration file.
        self.runfile = yaml_path
        # Always evaluation mode, on GPU index 0.
        self.mode = 'eval'
        self.gpu = 0
        # All optional switches are off.
        self.debug = self.nofig = self.nohost = False
| class YOND_anytest(): | |
def __init__(self, yaml_path, device):
    """Remember the target device, wrap the YAML path in a parser and load the config."""
    self.parser = YONDParser(yaml_path)
    self.device = device
    # Reads the YAML file and fills in every config-derived attribute.
    self.initialization()
def initialization(self):
    """Read the YAML run file and derive every config-dependent attribute.

    Side effects: creates the result directory
    ``<result_dir>/<method_name>`` and ``./logs`` if they do not exist.

    Raises:
        KeyError: if a required top-level entry ('dst', 'checkpoint',
            'arch', 'model_name', 'method_name', 'fast_ckpt', 'result_dir')
            is missing from the config.
    """
    with open(self.parser.runfile, 'r', encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct more than plain data;
        # acceptable for trusted local run files, but prefer yaml.safe_load
        # if a config can ever come from an untrusted source.
        self.args = yaml.load(f, Loader=yaml.FullLoader)
    # The parser's mode wins over the one stored in the config.
    self.mode = self.args['mode'] if self.parser.mode is None else self.parser.mode
    if self.parser.debug is True:
        self.args['num_workers'] = 0
        warnings.warn('You are using debug mode, only main worker(cpu) is used!!!')
    if 'clip' not in self.args['dst']:
        self.args['dst']['clip'] = False
    self.save_plot = not self.parser.nofig
    self.args['dst']['mode'] = self.mode
    self.hostname, self.hostpath, self.multi_gpu = get_host_with_dir()
    self.yond_dir = self.args['checkpoint']
    if not self.parser.nohost:
        # Prefix the dataset root of every '*dst*' section with the host path.
        for key in self.args:
            if 'dst' in key:
                self.args[key]['root_dir'] = f"{self.hostpath}/{self.args[key]['root_dir']}"
    self.dst = self.args['dst']
    self.arch = self.args['arch']
    # Callers treat 'pipeline' as optional (they test `'pipeline' in args`),
    # so a missing section or a missing 'bias_corr' must not raise here.
    self.pipe = self.args.get('pipeline', {})
    if self.pipe.get('bias_corr') == 'none':
        self.pipe['bias_corr'] = None
    self.yond_name = self.args['model_name']
    self.method_name = self.args['method_name']
    self.fast_ckpt = self.args['fast_ckpt']
    self.sample_dir = os.path.join(self.args['result_dir'], f"{self.method_name}")
    os.makedirs(self.sample_dir, exist_ok=True)
    os.makedirs('./logs', exist_ok=True)
| def load_model(self, model_path): | |
| # 模型加载 | |
| self.net = globals()[self.arch['name']](self.arch) | |
| model = torch.load(model_path, map_location='cpu') | |
| self.net = load_weights(self.net, model, by_name=False) |