const MODEL_PATH = '../onnx/model.onnx'; // Relative path to the model
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;
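// The input dimensions are assumed to be fixed by this particular ONNX export;
// adjust them if your model was exported with a different input size.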
let session = null;
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');
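// Assumes onnxruntime-web has already been loaded (e.g. via a <script> tag),
// which exposes the global `ort` used below.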
// Initialize ONNX Runtime
async function init() {
  try {
    // Enable verbose logging
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';
    statusElement.textContent = 'Loading model... (this may take a while)';
    // Configure session options for WebGPU
    const options = {
      executionProviders: ['webgpu'],
    };
    session = await ort.InferenceSession.create(MODEL_PATH, options);
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
  } catch (e) {
    console.error(e);
    // Fall back to WASM if WebGPU is unavailable or fails
    statusElement.textContent = 'WebGPU failed, trying WASM...';
    try {
      session = await ort.InferenceSession.create(MODEL_PATH, { executionProviders: ['wasm'] });
      statusElement.textContent = 'Model loaded (WASM). Ready.';
      runBtn.disabled = false;
    } catch (e2) {
      console.error(e2);
      statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
    }
  }
}
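
// Note: the feed and result names used below ('input_image', 'depth_map') are
// assumed to match this particular ONNX export. If unsure, the actual names can
// be inspected once the session has loaded, e.g.:
//   console.log(session.inputNames, session.outputNames);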
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const img = new Image();
  img.onload = () => {
    URL.revokeObjectURL(img.src); // Free the temporary object URL
    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    // Clear output
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };
  img.src = URL.createObjectURL(file);
});
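
// Note: drawImage above stretches the image to the model's fixed input size,
// ignoring aspect ratio. A hypothetical alternative (not used above) would be
// to letterbox instead, shown here as a sketch only:
function drawLetterboxed(ctx, img, dstW, dstH) {
  const scale = Math.min(dstW / img.width, dstH / img.height);
  const w = Math.round(img.width * scale);
  const h = Math.round(img.height * scale);
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, dstW, dstH); // Pad the unused area with black
  ctx.drawImage(img, (dstW - w) / 2, (dstH - h) / 2, w, h);
}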
runBtn.addEventListener('click', async () => {
  if (!session) return;
  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;
  try {
    // Preprocess
    const imageData = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    const tensor = preprocess(imageData);
    // Run inference
    const feeds = { input_image: tensor };
    const results = await session.run(feeds);
    const output = results.depth_map;
    // Postprocess and visualize
    visualize(output.data, INPUT_WIDTH, INPUT_HEIGHT);
    statusElement.textContent = 'Done.';
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error running inference: ' + e.message;
  } finally {
    runBtn.disabled = false;
  }
});
function preprocess(imageData) {
  const { data, width, height } = imageData;
  const float32Data = new Float32Array(3 * width * height);
  // The model expects 0-1 inputs and handles normalization internally.
  // Convert interleaved RGBA (HWC) to planar RGB (CHW); alpha is ignored.
  for (let i = 0; i < width * height; i++) {
    const r = data[i * 4] / 255.0;
    const g = data[i * 4 + 1] / 255.0;
    const b = data[i * 4 + 2] / 255.0;
    float32Data[i] = r; // R plane
    float32Data[width * height + i] = g; // G plane
    float32Data[2 * width * height + i] = b; // B plane
  }
  return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}
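
// Note: RGB channel order is assumed above; some exports expect BGR, in which
// case the R and B planes would need to be swapped.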
function visualize(data, width, height) {
  // Find min and max for normalization
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < data.length; i++) {
    if (data[i] < min) min = data[i];
    if (data[i] > max) max = data[i];
  }
  const range = max - min;
  const imageData = outputCtx.createImageData(width, height);
  for (let i = 0; i < data.length; i++) {
    // Normalize to 0-1 (guard against a constant depth map, where range is 0)
    const val = (data[i] - min) / (range || 1);
    // Inverted grayscale: the model outputs distance, so small values (close)
    // map to 255 (white) and large values (far) map to 0 (black).
    const pixelVal = Math.floor((1 - val) * 255);
    imageData.data[i * 4] = pixelVal; // R
    imageData.data[i * 4 + 1] = pixelVal; // G
    imageData.data[i * 4 + 2] = pixelVal; // B
    imageData.data[i * 4 + 3] = 255; // Alpha
  }
  outputCtx.putImageData(imageData, 0, 0);
}
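
// Optional (not wired into visualize above): a simple jet-like color map as an
// alternative to grayscale. Maps a normalized value in [0, 1] to an [r, g, b]
// triple in 0-255 along a blue -> green -> red gradient.
function jetColor(t) {
  const clamp01 = (x) => Math.min(Math.max(x, 0), 1);
  const r = Math.round(255 * clamp01(1.5 - Math.abs(4 * t - 3)));
  const g = Math.round(255 * clamp01(1.5 - Math.abs(4 * t - 2)));
  const b = Math.round(255 * clamp01(1.5 - Math.abs(4 * t - 1)));
  return [r, g, b];
}
// Example usage inside the loop above: const [r, g, b] = jetColor(1 - val);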
init();