Spaces:
Running
on
Zero
Running
on
Zero
net.to(self.device) update
Browse files- app_function.py +5 -2
app_function.py
CHANGED
|
@@ -260,6 +260,8 @@ class YOND_Backend:
|
|
| 260 |
raw_vst = (raw_vst - lower) / (upper - lower)
|
| 261 |
|
| 262 |
################# 准备去噪 #################
|
|
|
|
|
|
|
| 263 |
raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
|
| 264 |
if 'guided' in self.yond.arch:
|
| 265 |
sigma_corr = 1.03
|
|
@@ -333,6 +335,8 @@ class YOND_Backend:
|
|
| 333 |
raw_vst = (raw_vst - lower) / (upper - lower)
|
| 334 |
|
| 335 |
################# 准备去噪 #################
|
|
|
|
|
|
|
| 336 |
raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
|
| 337 |
if 'guided' in self.yond.arch:
|
| 338 |
t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device)
|
|
@@ -493,5 +497,4 @@ class YOND_anytest():
|
|
| 493 |
# 模型加载
|
| 494 |
self.net = globals()[self.arch['name']](self.arch)
|
| 495 |
model = torch.load(model_path, map_location='cpu')
|
| 496 |
-
self.net = load_weights(self.net, model, by_name=False)
|
| 497 |
-
self.net = self.net.to(self.device)
|
|
|
|
| 260 |
raw_vst = (raw_vst - lower) / (upper - lower)
|
| 261 |
|
| 262 |
################# 准备去噪 #################
|
| 263 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 264 |
+
self.yond.net = self.yond.net.to(self.device)
|
| 265 |
raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
|
| 266 |
if 'guided' in self.yond.arch:
|
| 267 |
sigma_corr = 1.03
|
|
|
|
| 335 |
raw_vst = (raw_vst - lower) / (upper - lower)
|
| 336 |
|
| 337 |
################# 准备去噪 #################
|
| 338 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 339 |
+
self.yond.net = self.yond.net.to(self.device)
|
| 340 |
raw_vst = torch.from_numpy(raw_vst).float().to(self.device).permute(2,0,1)[None,]
|
| 341 |
if 'guided' in self.yond.arch:
|
| 342 |
t = torch.tensor(nsr*self.p['sigsnr'], dtype=raw_vst.dtype, device=self.device)
|
|
|
|
| 497 |
# 模型加载
|
| 498 |
self.net = globals()[self.arch['name']](self.arch)
|
| 499 |
model = torch.load(model_path, map_location='cpu')
|
| 500 |
+
self.net = load_weights(self.net, model, by_name=False)
|
|
|