hansen97 commited on
Commit
1f3466c
·
1 Parent(s): c173203

net.to(self.device) update

Browse files
Files changed (1) hide show
  1. app_function.py +5 -2
app_function.py CHANGED
@@ -260,6 +260,8 @@ class YOND_Backend:
260
  raw_vst = (raw_vst - lower) / (upper - lower)
261
 
262
  ################# 准备去噪 #################
 
 
263
  raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
264
  if 'guided' in self.yond.arch:
265
  sigma_corr = 1.03
@@ -333,6 +335,8 @@ class YOND_Backend:
333
  raw_vst = (raw_vst - lower) / (upper - lower)
334
 
335
  ################# 准备去噪 #################
 
 
336
  raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
337
  if 'guided' in self.yond.arch:
338
  t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device)
@@ -493,5 +497,4 @@ class YOND_anytest():
493
  # 模型加载
494
  self.net = globals()[self.arch['name']](self.arch)
495
  model = torch.load(model_path, map_location='cpu')
496
- self.net = load_weights(self.net, model, by_name=False)
497
- self.net = self.net.to(self.device)
 
260
  raw_vst = (raw_vst - lower) / (upper - lower)
261
 
262
  ################# 准备去噪 #################
263
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
264
+ self.yond.net = self.yond.net.to(self.device)
265
  raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
266
  if 'guided' in self.yond.arch:
267
  sigma_corr = 1.03
 
335
  raw_vst = (raw_vst - lower) / (upper - lower)
336
 
337
  ################# 准备去噪 #################
338
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
339
+ self.yond.net = self.yond.net.to(self.device)
340
  raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
341
  if 'guided' in self.yond.arch:
342
  t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device)
 
497
  # 模型加载
498
  self.net = globals()[self.arch['name']](self.arch)
499
  model = torch.load(model_path, map_location='cpu')
500
+ self.net = load_weights(self.net, model, by_name=False)