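// Browser-side depth estimation demo using onnxruntime-web.
// Assumes the `ort` global is provided by an onnxruntime-web <script> include in the
// hosting HTML page, and that the elements referenced below (status, runBtn,
// imageInput, inputCanvas, outputCanvas) exist in that page.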
const MODEL_PATH = '../onnx/model.onnx'; // Relative path to the model
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;
let session = null;
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');
// Initialize ONNX Runtime
async function init() {
    try {
        // Enable verbose logging
        ort.env.debug = true;
        ort.env.logLevel = 'verbose';
        statusElement.textContent = 'Loading model... (this may take a while)';
        // Configure session options for WebGPU
        const options = {
            executionProviders: ['webgpu'],
        };
        session = await ort.InferenceSession.create(MODEL_PATH, options);
        statusElement.textContent = 'Model loaded. Ready.';
        runBtn.disabled = false;
    } catch (e) {
        console.error(e);
        statusElement.textContent = 'Error loading model: ' + e.message;
        // Fall back to WASM if WebGPU fails
        try {
            statusElement.textContent = 'WebGPU failed, trying WASM...';
            session = await ort.InferenceSession.create(MODEL_PATH, { executionProviders: ['wasm'] });
            statusElement.textContent = 'Model loaded (WASM). Ready.';
            runBtn.disabled = false;
        } catch (e2) {
            statusElement.textContent = 'Error loading model (WASM): ' + e2.message;
        }
    }
}
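// Note: onnxruntime-web also accepts a preference-ordered provider list
// (e.g. { executionProviders: ['webgpu', 'wasm'] }), which in many setups would make
// the manual catch-based fallback above unnecessary; the explicit fallback is kept
// here so the status messages stay distinct.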
imageInput.addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (!file) return;
    const img = new Image();
    img.onload = () => {
        inputCanvas.width = INPUT_WIDTH;
        inputCanvas.height = INPUT_HEIGHT;
        inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
        // Clear output
        outputCanvas.width = INPUT_WIDTH;
        outputCanvas.height = INPUT_HEIGHT;
        outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    };
    img.src = URL.createObjectURL(file);
});
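// Note: the object URL created for the preview above is never revoked; for a
// long-running page, calling URL.revokeObjectURL(img.src) at the end of the onload
// handler would release the underlying blob reference.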
runBtn.addEventListener('click', async () => {
    if (!session) return;
    statusElement.textContent = 'Running inference...';
    runBtn.disabled = true;
    try {
        // Preprocess
        const imageData = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
        const tensor = preprocess(imageData);
        // Run inference
        const feeds = { input_image: tensor };
        const results = await session.run(feeds);
        const output = results.depth_map;
        // Postprocess and visualize
        visualize(output.data, INPUT_WIDTH, INPUT_HEIGHT);
        statusElement.textContent = 'Done.';
    } catch (e) {
        console.error(e);
        statusElement.textContent = 'Error running inference: ' + e.message;
    } finally {
        runBtn.disabled = false;
    }
});
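// Note: the feed key 'input_image' and result key 'depth_map' above must match the
// model's actual input/output names; if a different export is used, they can be
// checked via session.inputNames and session.outputNames after session creation.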
function preprocess(imageData) {
    const { data, width, height } = imageData;
    const float32Data = new Float32Array(3 * width * height);
    // The model expects 0-1 inputs and handles normalization internally
    for (let i = 0; i < width * height; i++) {
        const r = data[i * 4] / 255.0;
        const g = data[i * 4 + 1] / 255.0;
        const b = data[i * 4 + 2] / 255.0;
        float32Data[i] = r; // R
        float32Data[width * height + i] = g; // G
        float32Data[2 * width * height + i] = b; // B
    }
    return new ort.Tensor('float32', float32Data, [1, 3, height, width]);
}
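// Note: the [1, 3, height, width] shape above is NCHW (planar RGB), the layout most
// PyTorch-exported ONNX models expect; an NHWC export would need the packing loop
// and the tensor dims adjusted accordingly.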
function visualize(data, width, height) {
    // Find min and max for normalization
    let min = Infinity;
    let max = -Infinity;
    for (let i = 0; i < data.length; i++) {
        if (data[i] < min) min = data[i];
        if (data[i] > max) max = data[i];
    }
    const range = max - min;
    const imageData = outputCtx.createImageData(width, height);
    for (let i = 0; i < data.length; i++) {
        // Normalize to 0-1
        const val = (data[i] - min) / (range || 1);
        // Grayscale, inverted: the output is distance, so smaller values are closer.
        // Mapping min (close) to 255 (white) and max (far) to 0 (black) makes
        // nearer surfaces brighter.
        const pixelVal = Math.floor((1 - val) * 255);
        imageData.data[i * 4] = pixelVal; // R
        imageData.data[i * 4 + 1] = pixelVal; // G
        imageData.data[i * 4 + 2] = pixelVal; // B
        imageData.data[i * 4 + 3] = 255; // Alpha
    }
    outputCtx.putImageData(imageData, 0, 0);
}
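// Optional alternative to the grayscale mapping in visualize() (not wired in): a
// minimal blue-to-red ramp, shown only as a sketch; a perceptual map such as Magma
// would need a proper lookup table instead.
function colorRamp(t) {
    // t in [0, 1], where 0 is the nearest (smallest) depth and 1 the farthest
    const r = Math.floor(255 * (1 - t)); // near -> red
    const b = Math.floor(255 * t);       // far  -> blue
    return [r, 0, b];
}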
init();