File size: 4,783 Bytes
7382c66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Model location and fixed input resolution the ONNX graph was exported with.
const MODEL_PATH = '../onnx/model.onnx'; // Relative path to the model
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

// Global inference session; assigned by init(), read by the run handler.
let session = null;
// Cached references to the page's UI elements (assumed present in the HTML).
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');

// Initialize ONNX Runtime: create an inference session for MODEL_PATH,
// preferring the WebGPU execution provider and falling back to WASM.
// On success, enables the Run button; on total failure, reports the error
// in the status line. Errors are also logged to the console.
async function init() {
    try {
        // Enable verbose logging (helps diagnose EP selection / model-load issues).
        ort.env.debug = true;
        ort.env.logLevel = 'verbose';

        statusElement.textContent = 'Loading model... (this may take a while)';
        // Configure session options for WebGPU
        const options = {
            executionProviders: ['webgpu'],
        };
        session = await ort.InferenceSession.create(MODEL_PATH, options);
        statusElement.textContent = 'Model loaded. Ready.';
        runBtn.disabled = false;
    } catch (e) {
        // The WebGPU failure is logged here; don't write an error status that
        // would be synchronously overwritten by the fallback message below.
        console.error(e);
        // Fallback to wasm if webgpu fails
        try {
            statusElement.textContent = 'WebGPU failed, trying WASM...';
            session = await ort.InferenceSession.create(MODEL_PATH, { executionProviders: ['wasm'] });
            statusElement.textContent = 'Model loaded (WASM). Ready.';
            runBtn.disabled = false;
        } catch (e2) {
            // Log the WASM failure too; previously it was only shown as text.
            console.error(e2);
            statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
        }
    }
}

// When the user selects a file, decode it and draw it (scaled to the model's
// fixed input size) onto the input canvas, clearing any previous output.
imageInput.addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (!file) return;

    const objectUrl = URL.createObjectURL(file);
    const img = new Image();
    img.onload = () => {
        // Release the blob URL once decoded — previously it was never revoked,
        // leaking one object URL per selected file.
        URL.revokeObjectURL(objectUrl);

        inputCanvas.width = INPUT_WIDTH;
        inputCanvas.height = INPUT_HEIGHT;
        inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);

        // Clear output
        outputCanvas.width = INPUT_WIDTH;
        outputCanvas.height = INPUT_HEIGHT;
        outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    };
    img.onerror = () => {
        // Surface decode failures instead of failing silently.
        URL.revokeObjectURL(objectUrl);
        statusElement.textContent = 'Could not load the selected image.';
    };
    img.src = objectUrl;
});

// Run depth inference on the current input-canvas contents and render the
// result to the output canvas. Disables the Run button while inference is
// in flight and re-enables it afterwards, success or failure.
runBtn.addEventListener('click', async () => {
    if (!session) return;

    statusElement.textContent = 'Running inference...';
    runBtn.disabled = true;

    try {
        // Preprocess: canvas RGBA pixels -> NCHW float tensor.
        const pixels = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
        const inputTensor = preprocess(pixels);

        // Run the model; the feed name must match the ONNX graph's input name.
        const { depth_map: output } = await session.run({ input_image: inputTensor });

        // Postprocess and visualize
        visualize(output.data, INPUT_WIDTH, INPUT_HEIGHT);
        statusElement.textContent = 'Done.';
    } catch (e) {
        console.error(e);
        statusElement.textContent = 'Error running inference: ' + e.message;
    } finally {
        // Always restore the button so the user can retry.
        runBtn.disabled = false;
    }
});

/**
 * Convert interleaved RGBA canvas pixels into a planar (NCHW) float32 tensor.
 * Values are scaled from [0, 255] to [0, 1]; the model expects 0-1 inputs and
 * handles normalization internally, so no mean/std scaling is applied here.
 * @param {ImageData} imageData - Pixels read from the input canvas.
 * @returns {ort.Tensor} Float32 tensor of shape [1, 3, height, width].
 */
function preprocess(imageData) {
    const { data, width, height } = imageData;
    const planeSize = width * height;
    const float32Data = new Float32Array(3 * planeSize);

    // Interleaved RGBA -> planar RGB (alpha is dropped).
    for (let px = 0; px < planeSize; px++) {
        const base = px * 4;
        float32Data[px] = data[base] / 255.0;                     // R plane
        float32Data[planeSize + px] = data[base + 1] / 255.0;     // G plane
        float32Data[2 * planeSize + px] = data[base + 2] / 255.0; // B plane
    }

    return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}

/**
 * Render a raw depth array as an inverted grayscale image on the output canvas.
 * Values are min-max normalized per frame; since the values are distances,
 * smaller (closer) depths are mapped to white and larger (farther) to black.
 * @param {Float32Array|number[]} data - Depth values, row-major, length width*height.
 * @param {number} width - Output image width in pixels.
 * @param {number} height - Output image height in pixels.
 */
function visualize(data, width, height) {
    // Single pass to find the value range for normalization.
    let lo = Infinity;
    let hi = -Infinity;
    for (const v of data) {
        if (v < lo) lo = v;
        if (v > hi) hi = v;
    }

    // Guard against a flat depth map: a zero range would divide by zero.
    const span = (hi - lo) || 1;
    const imageData = outputCtx.createImageData(width, height);
    const out = imageData.data;

    for (let i = 0; i < data.length; i++) {
        // Normalize to [0, 1], then invert so near values render bright.
        const t = (data[i] - lo) / span;
        const gray = Math.floor((1 - t) * 255);

        const o = i * 4;
        out[o] = gray;     // R
        out[o + 1] = gray; // G
        out[o + 2] = gray; // B
        out[o + 3] = 255;  // fully opaque
    }

    outputCtx.putImageData(imageData, 0, 0);
}

init();