File size: 4,467 Bytes
0e07d71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from .utils import *

def scale_down(img):
    """Cast ``img`` to float32 and divide by 255.

    Maps 0-255 intensity values onto the 0-1 range.

    Args:
        img: Array-like of intensities (anything ``np.float32`` accepts).

    Returns:
        The float32 data divided by 255.
    """
    as_float = np.float32(img)
    return as_float / 255.

def scale_up(img):
    """Clip ``img`` to [0, 1], multiply by 255 and cast to uint8.

    The cast truncates the fractional part (e.g. 0.5 -> 127), it does not
    round.

    Args:
        img: NumPy array (must provide ``.clip``) of intensities.

    Returns:
        uint8 data in the 0-255 range.
    """
    bounded = img.clip(0, 1)
    scaled = bounded * 255.
    return np.uint8(scaled)

def tensor2im(image_tensor, visualize=False, video=False):
    """Convert a batched tensor to a channels-last NumPy array scaled by 255.

    The indexing and transpose axes imply a 4-D, batch-first,
    channels-second layout (N, C, H, W).

    Args:
        image_tensor: Torch tensor; presumably valued in [0, 1] -- values
            outside that range end up clipped to [0, 255] after scaling.
        visualize: If True, keep only the first three channels.
        video: If True, convert the whole batch to (N, H, W, C); otherwise
            only the first batch element is converted, giving (H, W, C).

    Returns:
        float32 NumPy array clipped to [0, 255] (not cast to uint8).
    """
    frames = image_tensor.detach()

    if visualize:
        frames = frames[:, 0:3, ...]

    if video:
        axes = (0, 2, 3, 1)
    else:
        # Single-image mode: only the first batch entry is used.
        frames = frames[0]
        axes = (1, 2, 0)

    pixels = frames.cpu().float().numpy()
    pixels = np.transpose(pixels, axes) * 255.0

    return np.clip(pixels, 0, 255)

def quality_assess(X, Y, data_range=255):
    """Compute PSNR and SSIM of an estimate against its reference.

    Args:
        X: Estimated image; must be 3-D (SSIM is computed with
            ``channel_axis=-1``, i.e. channels last).
        Y: Reference ("correct") image.
        data_range: Value range forwarded to both metrics.

    Returns:
        Dict with keys ``'PSNR'`` and ``'SSIM'``.

    Raises:
        NotImplementedError: If ``X`` is not 3-dimensional.
    """
    # Only single images are supported.
    if X.ndim != 3:
        raise NotImplementedError

    # NOTE(review): compare_psnr / compare_ssim come from the star import;
    # presumably the scikit-image metrics -- confirm in .utils.
    return {
        'PSNR': compare_psnr(Y, X, data_range=data_range),
        'SSIM': compare_ssim(Y, X, data_range=data_range, channel_axis=-1),
    }

def feature_vis(tensor, name='out', save=False, first_only=True):
    """Return a feature tensor as a channels-last array, optionally saving PNGs.

    Args:
        tensor: 4-D torch tensor; the transpose ``(0, 2, 3, 1)`` implies an
            (N, C, H, W) layout.
        name: Filename stem for saved images.
        save: If True, write ``worklog/feature_vis/{name}_{i}.png`` per sample.
        first_only: If True, only the first sample of the batch is written.

    Returns:
        NumPy array of shape (N, H, W, C). The returned data is never
        clipped; clipping applies only to the images written to disk.
    """
    feature = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)
    if not save:
        return feature

    # Guard on size so an empty batch does not make min()/max() raise.
    if feature.size and (feature.min() < 0 or feature.max() > 1):
        warnings.warn('Signals are clipped to [0, 1] for visualization!!!!')

    out_dir = 'worklog/feature_vis'
    os.makedirs(out_dir, exist_ok=True)

    samples = feature[:1] if first_only else feature
    for i, sample in enumerate(samples):
        # Actually clip, as the warning promises: casting out-of-range floats
        # straight to uint8 would wrap around instead of saturating.
        # The channel axis is reversed for OpenCV (presumably RGB -> BGR).
        bgr = np.clip(sample[:, :, ::-1], 0, 1) * 255
        # Hand OpenCV a contiguous uint8 buffer rather than a reversed view.
        cv2.imwrite(f'{out_dir}/{name}_{i}.png',
                    np.ascontiguousarray(bgr, dtype=np.uint8))
    return feature

def plot_sample(img_lr, img_dn, img_hr, filename='result', model_name='Unet',
                epoch=-1, print_metrics=False, save_plot=True, save_path='./', res=None):
    """Plot noisy / denoised / ground-truth images side by side with metrics.

    Args:
        img_lr: Noisy input image (H, W, C).
        img_dn: Model output image (H, W, C).
        img_hr: Ground-truth image (H, W, C).
        filename: Stem used in the figure title and output filenames.
        model_name: Title of the middle panel.
        epoch: Epoch number shown in the title and the figure filename.
        print_metrics: If True, report the metric lists through ``log``.
        save_plot: If True, write ``{filename}_denoised.png`` and
            ``{filename}-Epoch{epoch}.jpg`` into ``save_path`` and close the
            figure. If False the figure is left open for the caller.
        save_path: Output directory; created if missing.
        res: Optional precomputed metrics laid out as
            ``[psnr_noisy, ssim_noisy, psnr_denoised, ssim_denoised]``.
            When None the metrics are computed here.

    Returns:
        Tuple ``(psnr, ssim, filename)`` where ``psnr`` and ``ssim`` are
        3-element lists ``[noisy, denoised, -1]`` (-1 is the placeholder for
        the ground-truth panel).
    """
    if np.max(img_hr) <= 1:
        # Inputs look like [0, 1] data: convert all three back to uint8.
        img_lr = scale_up(img_lr)
        img_dn = scale_up(img_dn)
        img_hr = scale_up(img_hr)

    # PSNR / SSIM against the ground truth.
    if res is None:
        # data_range is passed explicitly, matching quality_assess' default;
        # for uint8 inputs this equals the range the metric would infer.
        psnr = [
            compare_psnr(img_hr, img_lr, data_range=255),
            compare_psnr(img_hr, img_dn, data_range=255),
            -1,
        ]
        ssim = [
            compare_ssim(img_hr, img_lr, data_range=255, channel_axis=-1),
            compare_ssim(img_hr, img_dn, data_range=255, channel_axis=-1),
            -1,
        ]
    else:
        psnr = [res[0], res[2], -1]
        ssim = [res[1], res[3], -1]

    # A list of pairs (not a dict) so a model_name that equals one of the
    # fixed titles cannot silently drop a panel.
    panels = [
        ('Noisy Image', img_lr),
        (model_name, img_dn),
        ('Ground Truth', img_hr),
    ]

    # Created regardless of save_plot, as before; exist_ok avoids the
    # check-then-create race. An empty path means the current directory.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    fig, axes = plt.subplots(1, 3, figsize=(20, 4.5))
    for ax, (title, img), panel_psnr, panel_ssim in zip(axes, panels, psnr, ssim):
        ax.imshow(img)
        ax.set_title("{}\n{} - psnr:{:.2f} - ssim:{:.4f}".format(
            title, img.shape, panel_psnr, panel_ssim))
        ax.axis('off')
    fig.suptitle('{} - Epoch: {}'.format(filename, epoch))

    if print_metrics:
        log('PSNR:', psnr)
        log('SSIM:', ssim)

    if save_plot:
        savefile = os.path.join(save_path, "{}-Epoch{}.jpg".format(filename, epoch))
        denoisedfile = os.path.join(save_path, "{}_denoised.png".format(filename))
        # Channel axis reversed for OpenCV (presumably RGB -> BGR); copy to a
        # contiguous buffer instead of passing a negative-stride view.
        cv2.imwrite(denoisedfile, np.ascontiguousarray(img_dn[:, :, ::-1]))
        fig.savefig(savefile, bbox_inches='tight')
        # Close this specific figure rather than whatever is "current".
        plt.close(fig)
    return psnr, ssim, filename

def save_picture(img_sr, save_path='./images/test', frame_id='0000'):
    """Save a channels-first image as ``{save_path}/{frame_id}.png``.

    Args:
        img_sr: 3-D array; the transpose ``(1, 2, 0)`` implies a (C, H, W)
            layout. Values are clipped to [0, 1] and scaled by ``scale_up``.
        save_path: Output directory; created if missing.
        frame_id: Filename stem (any value that formats to a string).
    """
    # Back to uint8, channels last.
    img_sr = scale_up(img_sr.transpose(1, 2, 0))
    # The old guard used os._exists, a private helper that looks up names in
    # the os module namespace (not filesystem paths), so it never reflected
    # whether the directory existed. exist_ok=True makes a guard unnecessary.
    os.makedirs(save_path, exist_ok=True)
    plt.imsave(os.path.join(save_path, f'{frame_id}.png'), img_sr)
    # Kept from the original; presumably limits memory growth when many
    # frames are saved in a loop -- TODO confirm it is still needed.
    gc.collect()

def save_tensor_to_numpy(tensor, name, save_path='./debug'):
    """Convert a PyTorch tensor to a NumPy array and save it as a ``.npy`` file.

    Args:
        tensor (torch.Tensor): Tensor to save.
        name (str): Output filename, without directory or extension.
        save_path (str): Output directory, created if missing. Defaults to
            ``'./debug'``.
    """
    # Detached copy moved to the CPU, then viewed as a NumPy array.
    array = tensor.detach().clone().cpu().numpy()

    # Make sure the target directory exists and build the output path.
    os.makedirs(save_path, exist_ok=True)
    full_path = f"{save_path}/{name}.npy"

    # Write the array and report where it went.
    np.save(full_path, array)
    print(f"Tensor saved as NumPy array to: {full_path}")