|
|
import sys |
|
|
import os |
|
|
import torch |
|
|
import json |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
# Make the DA-2 repository's source tree importable as a package root.
da2_src_dir = os.path.join(os.getcwd(), 'DA-2-repo/src')
sys.path.append(da2_src_dir)

# Fail fast with a readable message if the repo layout is wrong or deps
# are missing, instead of a bare traceback later on.
try:
    from da2.model.spherevit import SphereViT
except ImportError as e:
    print(f"Error importing SphereViT: {e}")
    sys.exit(1)
|
|
|
|
|
|
|
|
# Target inference resolution (height x width). Pinning min_pixels ==
# max_pixels == H*W forces the model to this exact input size.
H, W = 546, 1092

# Load the repo's inference config and override the pixel budget.
config_path = 'DA-2-repo/configs/infer.json'
with open(config_path, 'r') as f:
    config = json.load(f)

pixel_count = H * W
config['inference']['min_pixels'] = pixel_count
config['inference']['max_pixels'] = pixel_count

print(f"Initializing model with input size {W}x{H}...")

# Build the model from the (patched) config dict.
model = SphereViT(config)
|
|
|
|
|
|
|
|
print("Loading weights from model.safetensors...")
try:
    # Checkpoint is expected next to this script.
    weights = load_file('model.safetensors')
    # strict=False tolerates key mismatches so a partially matching
    # checkpoint still loads; report the mismatches (with a bounded
    # sample of names) so silent partial loads are diagnosable.
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if missing:
        print(f"Missing keys: {len(missing)}")
        for key in missing[:10]:
            print(f"  - {key}")
    if unexpected:
        print(f"Unexpected keys: {len(unexpected)}")
        for key in unexpected[:10]:
            print(f"  - {key}")
except Exception as e:
    # Any failure here (missing file, corrupt tensors, shape mismatch)
    # makes the export pointless, so abort.
    print(f"Error loading weights: {e}")
    sys.exit(1)
|
|
|
|
|
print("Exporting model in FP32 (full precision)...")
model.eval()

# Fixed-size dummy input at the configured resolution; the batch
# dimension is declared dynamic via dynamic_axes below.
dummy_input = torch.randn(1, 3, H, W)

output_file = "onnx/model.onnx"
# torch.onnx.export does not create parent directories -- ensure the
# target directory exists so export doesn't fail on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

print(f"Exporting to {output_file}...")
try:
    # no_grad: tracing for export needs no autograd graph, and this
    # avoids the extra memory/bookkeeping during the forward pass.
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            output_file,
            opset_version=17,
            input_names=["pixel_values"],
            output_names=["predicted_depth"],
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "predicted_depth": {0: "batch_size"}
            },
            export_params=True,
            do_constant_folding=True,
            verbose=False
        )

    print(f"Successfully exported to {output_file}")

    # Optional post-export dynamic quantization to INT8 weights. A
    # failure here is reported but does not abort: the FP32 model has
    # already been written successfully.
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
        quantized_output_file = "onnx/model_quantized.onnx"
        print(f"Quantizing model to {quantized_output_file}...")
        quantize_dynamic(
            output_file,
            quantized_output_file,
            weight_type=QuantType.QInt8
        )
        print(f"Successfully quantized to {quantized_output_file}")
    except Exception as qe:
        print(f"Error during quantization: {qe}")
        import traceback
        traceback.print_exc()
except Exception as e:
    print(f"Error exporting to ONNX: {e}")
    import traceback
    traceback.print_exc()
|
|
|