const MODEL_PATH = '../onnx/model.onnx'; // Relative path to the model
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

let session = null;

const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');

// Initialize ONNX Runtime
async function init() {
  try {
    // Enable verbose logging
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';

    statusElement.textContent = 'Loading model... (this may take a while)';

    // Configure session options for WebGPU
    const options = {
      executionProviders: ['webgpu'],
    };

    session = await ort.InferenceSession.create(MODEL_PATH, options);

    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error loading model: ' + e.message;

    // Fall back to WASM if WebGPU fails
    try {
      statusElement.textContent = 'WebGPU failed, trying WASM...';
      session = await ort.InferenceSession.create(MODEL_PATH, {
        executionProviders: ['wasm'],
      });
      statusElement.textContent = 'Model loaded (WASM). Ready.';
      runBtn.disabled = false;
    } catch (e2) {
      statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
    }
  }
}

imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const img = new Image();
  img.onload = () => {
    URL.revokeObjectURL(img.src); // Release the object URL once the image is decoded
    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);

    // Clear the output canvas
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };
  img.src = URL.createObjectURL(file);
});

runBtn.addEventListener('click', async () => {
  if (!session) return;

  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;

  try {
    // Preprocess
    const imageData = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    const tensor = preprocess(imageData);

    // Run inference
    const feeds = { input_image: tensor };
    const results = await session.run(feeds);
    const output = results.depth_map;

    // Postprocess and visualize
    visualize(output.data, INPUT_WIDTH, INPUT_HEIGHT);

    statusElement.textContent = 'Done.';
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error running inference: ' + e.message;
  } finally {
    runBtn.disabled = false;
  }
});

function preprocess(imageData) {
  const { data, width, height } = imageData;
  const float32Data = new Float32Array(3 * width * height);

  // The model expects 0-1 inputs and handles normalization internally.
  // Convert interleaved RGBA (HWC) to planar RGB (NCHW).
  for (let i = 0; i < width * height; i++) {
    const r = data[i * 4] / 255.0;
    const g = data[i * 4 + 1] / 255.0;
    const b = data[i * 4 + 2] / 255.0;
    float32Data[i] = r; // R plane
    float32Data[width * height + i] = g; // G plane
    float32Data[2 * width * height + i] = b; // B plane
  }

  return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}

function visualize(data, width, height) {
  // Find min and max for normalization
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < data.length; i++) {
    if (data[i] < min) min = data[i];
    if (data[i] > max) max = data[i];
  }
  const range = max - min;

  const imageData = outputCtx.createImageData(width, height);
  for (let i = 0; i < data.length; i++) {
    // Normalize to 0-1 (guard against a flat map where range === 0)
    const val = (data[i] - min) / (range || 1);
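    // A color map (Magma-like or similar) is an alternative to the grayscale
    // used below. A minimal commented-out sketch of a simple two-ramp heatmap;
    // the coefficients are illustrative, not the real Magma lookup table:
    //   const t = 1 - val; // invert so close pixels read as "hot"
    //   imageData.data[i * 4]     = Math.floor(255 * Math.min(1, 2 * t)); // red ramps up first
    //   imageData.data[i * 4 + 1] = Math.floor(255 * Math.max(0, 2 * t - 1)); // green joins near the peak
    //   imageData.data[i * 4 + 2] = Math.floor(80 * (1 - t)); // faint blue at the far range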
    // Grayscale for simplicity. The model outputs distance, so closer pixels
    // have smaller values; mapping min (close) to 255 (white) and max (far)
    // to 0 (black) gives the usual inverted-depth look.
    const pixelVal = Math.floor((1 - val) * 255);
    imageData.data[i * 4] = pixelVal; // R
    imageData.data[i * 4 + 1] = pixelVal; // G
    imageData.data[i * 4 + 2] = pixelVal; // B
    imageData.data[i * 4 + 3] = 255; // Alpha
  }

  outputCtx.putImageData(imageData, 0, 0);
}

init();
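// For reference, this script assumes host markup along these lines. This is a
// minimal sketch, not the actual page: the IDs must match the getElementById
// calls above, and the ort script file name varies by onnxruntime-web version
// (some versions ship WebGPU support in a separate ort.webgpu.min.js bundle):
//
//   <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
//   <p id="status">Loading...</p>
//   <input type="file" id="imageInput" accept="image/*">
//   <button id="runBtn" disabled>Run</button>
//   <canvas id="inputCanvas"></canvas>
//   <canvas id="outputCanvas"></canvas>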