Spaces:
Runtime error
Runtime error
Commit
·
725545d
1
Parent(s):
3e5a852
update space
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from functools import partial
|
|
| 5 |
from typing import Optional
|
| 6 |
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
| 7 |
from shap_e.diffusion.sample import sample_latents
|
| 8 |
-
from shap_e.models.download import load_model, load_config
|
| 9 |
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
|
| 10 |
import trimesh
|
| 11 |
import torch.nn as nn
|
|
@@ -275,10 +275,25 @@ def main():
|
|
| 275 |
"""
|
| 276 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 277 |
print("device:", device)
|
| 278 |
-
latent_model = load_model('text300M', device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
print("loaded latent model")
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
print("loaded transmitter")
|
|
|
|
|
|
|
| 282 |
diffusion = diffusion_from_config(load_config('diffusion'))
|
| 283 |
freeze_params(xm.parameters())
|
| 284 |
models = dict()
|
|
|
|
| 5 |
from typing import Optional
|
| 6 |
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
| 7 |
from shap_e.diffusion.sample import sample_latents
|
| 8 |
+
from shap_e.models.download import load_model, load_config, load_checkpoint
|
| 9 |
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
|
| 10 |
import trimesh
|
| 11 |
import torch.nn as nn
|
|
|
|
| 275 |
"""
|
| 276 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 277 |
print("device:", device)
|
| 278 |
+
# latent_model = load_model('text300M', device=device)
|
| 279 |
+
|
| 280 |
+
latent_model = model_from_config(load_config('text300M'), device=device)
|
| 281 |
+
# print(model_name, kwargs)
|
| 282 |
+
# print(model)
|
| 283 |
+
latent_model.load_state_dict(load_checkpoint('text300M', device='cpu'))
|
| 284 |
+
latent_model.eval()
|
| 285 |
print("loaded latent model")
|
| 286 |
+
latent_model.to(device)
|
| 287 |
+
# xm = load_model('transmitter', device=device)
|
| 288 |
+
|
| 289 |
+
xm = model_from_config(load_config('transmitter'), device=device)
|
| 290 |
+
# print(model_name, kwargs)
|
| 291 |
+
# print(model)
|
| 292 |
+
xm.load_state_dict(load_checkpoint('transmitter', device='cpu'))
|
| 293 |
+
xm.eval()
|
| 294 |
print("loaded transmitter")
|
| 295 |
+
xm.to(device)
|
| 296 |
+
|
| 297 |
diffusion = diffusion_from_config(load_config('diffusion'))
|
| 298 |
freeze_params(xm.parameters())
|
| 299 |
models = dict()
|