"""Export the SphereViT model to ONNX at a fixed input resolution.

The script reads the inference config from ``DA-2-repo/configs/infer.json``,
pins ``min_pixels``/``max_pixels`` to a single H x W size, loads weights from
``model.safetensors`` and writes ``model.onnx``.

All paths are relative to the current working directory. The process exits
with status 1 if the model import, the weight loading, or the ONNX export
fails.
"""
import sys
import os
import torch
import json
import traceback
from safetensors.torch import load_file

# Add src to path so that the ``da2`` package can be imported below.
sys.path.append(os.path.join(os.getcwd(), 'DA-2-repo/src'))

try:
    from da2.model.spherevit import SphereViT
except ImportError as e:
    print(f"Error importing SphereViT: {e}")
    sys.exit(1)

# Load config
config_path = 'DA-2-repo/configs/infer.json'
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Adjust config for fixed size export.
# Using 1092x546 (multiples of 14: 1092 = 78 * 14, 546 = 39 * 14).
# This is closer to the original config's ~600k pixels.
H, W = 546, 1092
config['inference']['min_pixels'] = H * W
config['inference']['max_pixels'] = H * W

print(f"Initializing model with input size {W}x{H}...")

# Instantiate model
model = SphereViT(config)

# Load weights. strict=False tolerates key mismatches, so the number of
# missing / unexpected keys is reported instead of raising.
print("Loading weights from model.safetensors...")
try:
    weights = load_file('model.safetensors')
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if missing:
        print(f"Missing keys: {len(missing)}")
    if unexpected:
        print(f"Unexpected keys: {len(unexpected)}")
except Exception as e:
    print(f"Error loading weights: {e}")
    sys.exit(1)

print("Exporting model in FP32 (full precision)...")
model.eval()

# Dummy input (float32) used to trace the model during export.
dummy_input = torch.randn(1, 3, H, W)

# Export
output_file = "model.onnx"
print(f"Exporting to {output_file}...")

try:
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        opset_version=17,
        input_names=["input_image"],
        output_names=["depth_map"],
        # Only the batch dimension is dynamic; H and W stay fixed.
        dynamic_axes={
            "input_image": {0: "batch_size"},
            "depth_map": {0: "batch_size"},
        },
        export_params=True,
        do_constant_folding=True,
        verbose=False,
    )
except Exception as e:
    print(f"Error exporting to ONNX: {e}")
    traceback.print_exc()
    # Signal failure to the caller, consistent with the import and
    # weight-loading error paths above.
    sys.exit(1)
else:
    print(f"Successfully exported to {output_file}")