"""Export the DA-2 SphereViT depth model to ONNX at full (FP32) precision."""

import json
import os
import sys

import torch
from safetensors.torch import load_file

# Make the DA-2 source tree importable (assumes the repo is cloned to ./DA-2-repo
# under the current working directory).
sys.path.append(os.path.join(os.getcwd(), 'DA-2-repo/src'))

try:
    from da2.model.spherevit import SphereViT
except ImportError as e:
    print(f"Error importing SphereViT: {e}")
    sys.exit(1)

# Load the inference configuration shipped with the DA-2 repository.
config_path = 'DA-2-repo/configs/infer.json'
with open(config_path, 'r') as f:
    config = json.load(f)

# Target resolution; pin the config's pixel budget to exactly H * W.
H, W = 546, 1092
config['inference']['min_pixels'] = H * W
config['inference']['max_pixels'] = H * W

print(f"Initializing model with input size {W}x{H}...") |
|
|
|
|
|
model = SphereViT(config) |
|
|
|
|
|
|
|
|
print("Loading weights from model.safetensors...") |
|
|
try: |
|
|
weights = load_file('model.safetensors') |
|
|
missing, unexpected = model.load_state_dict(weights, strict=False) |
|
|
if missing: |
|
|
print(f"Missing keys: {len(missing)}") |
|
|
|
|
|
if unexpected: |
|
|
print(f"Unexpected keys: {len(unexpected)}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error loading weights: {e}") |
|
|
sys.exit(1) |
|
|
|
|
|
print("Exporting model in FP32 (full precision)...") |
|
|
model.eval() |
|
|
|
|
|
|
|
|
dummy_input = torch.randn(1, 3, H, W) |
|
|
|
|
|
|
|
|
output_file = "model.onnx"
print(f"Exporting to {output_file}...")
try:
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        opset_version=17,
        input_names=["input_image"],
        output_names=["depth_map"],
        dynamic_axes={
            "input_image": {0: "batch_size"},
            "depth_map": {0: "batch_size"},
        },
        export_params=True,
        do_constant_folding=True,
        verbose=False,
    )
    print(f"Successfully exported to {output_file}")
except Exception as e:
    print(f"Error exporting to ONNX: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
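
# Optional sanity check: a minimal sketch, assuming onnxruntime is installed and that the
# exported graph accepts a single (1, 3, H, W) float32 tensor named "input_image".
if os.path.exists(output_file):
    try:
        import onnxruntime as ort

        session = ort.InferenceSession(output_file, providers=["CPUExecutionProvider"])
        ort_outputs = session.run(None, {"input_image": dummy_input.numpy()})
        print(f"ONNX sanity check passed; depth_map shape: {ort_outputs[0].shape}")
    except ImportError:
        print("onnxruntime not installed; skipping the ONNX sanity check.")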