|
|
// Path handed to ort.InferenceSession.create() when the model is loaded.
const MODEL_PATH = '../onnx/model.onnx';

// Pixel size used for both canvases and for the image data read back for inference.
const [INPUT_WIDTH, INPUT_HEIGHT] = [1092, 546];

// Inference session; stays null until init() has loaded the model.
let session = null;

// Page elements, looked up once by id (same lookup order as the ids listed here).
const [statusElement, runBtn, imageInput, inputCanvas, outputCanvas] = [
  'status',
  'runBtn',
  'imageInput',
  'inputCanvas',
  'outputCanvas',
].map((id) => document.getElementById(id));

// 2D drawing contexts for the input (source image) and output (result) canvases.
const [inputCtx, outputCtx] = [inputCanvas, outputCanvas].map((canvas) =>
  canvas.getContext('2d'),
);
|
|
|
|
|
|
|
|
/**
 * Loads the ONNX model and enables the run button once a session exists.
 *
 * The WebGPU execution provider is tried first. If creating that session fails
 * for any reason, a second attempt is made with the WASM execution provider.
 * Progress and the final outcome are written to the status element, and every
 * failure is logged to the console so its stack trace is not lost.
 *
 * @returns {Promise<void>} Resolves once loading has succeeded or both attempts
 *   have failed; load errors are reported through the status element rather
 *   than by rejecting.
 */
async function init() {
  try {
    // Verbose ONNX Runtime diagnostics, enabled before any session is created.
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';

    statusElement.textContent = 'Loading model... (this may take a while)';

    session = await ort.InferenceSession.create(MODEL_PATH, {
      executionProviders: ['webgpu'],
    });
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
    return;
  } catch (webgpuError) {
    // Keep the reason in the console; the status line moves on to the fallback.
    console.error(webgpuError);
  }

  try {
    statusElement.textContent = 'WebGPU failed, trying WASM...';
    session = await ort.InferenceSession.create(MODEL_PATH, {
      executionProviders: ['wasm'],
    });
    statusElement.textContent = 'Model loaded (WASM). Ready.';
    runBtn.disabled = false;
  } catch (wasmError) {
    // Log the fallback failure too, so it can be diagnosed beyond its message.
    console.error(wasmError);
    statusElement.textContent = `Error loading model (WASM): ${wasmError.message}`;
  }
}
|
|
|
|
|
// When the user picks a file, draw it (stretched to INPUT_WIDTH x INPUT_HEIGHT)
// onto the input canvas and reset the output canvas to a blank surface of the
// same size.
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  // The blob URL keeps the file referenced until it is revoked, so release it
  // as soon as the image has been drawn or has failed to load.
  const objectUrl = URL.createObjectURL(file);

  const img = new Image();
  img.onload = () => {
    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    URL.revokeObjectURL(objectUrl);

    // Assigning width/height resets the canvas; clearRect makes the blank state explicit.
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };
  img.onerror = () => {
    // Selected file could not be loaded as an image; report it instead of failing silently.
    URL.revokeObjectURL(objectUrl);
    statusElement.textContent = 'Error loading image.';
  };
  img.src = objectUrl;
});
|
|
|
|
|
// Click handler: read the input canvas, run the model on it and draw the
// result on the output canvas. The button is disabled while a run is in flight.
runBtn.addEventListener('click', async () => {
  // No session means the model has not been loaded; ignore the click.
  if (!session) {
    return;
  }

  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;

  try {
    // Pixels currently on the input canvas, converted to the model's input tensor.
    const pixels = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    const inputTensor = preprocess(pixels);

    // The model is fed under the name `input_image` and read back from `depth_map`.
    const results = await session.run({ input_image: inputTensor });
    const depthMap = results.depth_map;

    visualize(depthMap.data, INPUT_WIDTH, INPUT_HEIGHT);
    statusElement.textContent = 'Done.';
  } catch (err) {
    console.error(err);
    statusElement.textContent = `Error running inference: ${err.message}`;
  } finally {
    // Re-enable the button whether the run succeeded or failed.
    runBtn.disabled = false;
  }
});
|
|
|
|
|
/**
 * Converts interleaved RGBA pixel data into a planar float32 tensor.
 *
 * Each colour byte is divided by 255 and written into its own plane (all red
 * values, then all green, then all blue); the alpha byte is ignored.
 *
 * @param {{data: ArrayLike<number>, width: number, height: number}} imageData
 *   Pixel data with four entries (R, G, B, A) per pixel.
 * @returns {ort.Tensor} A 'float32' tensor with dims [1, 3, height, width].
 */
function preprocess(imageData) {
  const { data, width, height } = imageData;
  const planeSize = width * height;
  const planar = new Float32Array(3 * planeSize);

  // Views over the three consecutive channel planes of the same buffer.
  const redPlane = planar.subarray(0, planeSize);
  const greenPlane = planar.subarray(planeSize, 2 * planeSize);
  const bluePlane = planar.subarray(2 * planeSize);

  // `src` walks the RGBA bytes four at a time while `px` walks the pixels.
  for (let px = 0, src = 0; px < planeSize; px += 1, src += 4) {
    redPlane[px] = data[src] / 255;
    greenPlane[px] = data[src + 1] / 255;
    bluePlane[px] = data[src + 2] / 255;
  }

  return new ort.Tensor('float32', planar, [1, 3, height, width]);
}
|
|
|
|
|
/**
 * Draws a numeric map onto the output canvas as an inverted grayscale image.
 *
 * Samples are min-max normalised and then inverted, so the smallest value is
 * drawn white (255) and the largest black (0). When all samples are equal the
 * image is drawn white.
 *
 * Only finite samples take part in the min/max search: a single Infinity or
 * -Infinity would otherwise make the range infinite and flatten the whole
 * image to one colour. Non-finite samples are still drawn: +Infinity as
 * black, -Infinity as white and NaN as black.
 *
 * @param {ArrayLike<number>} data One sample per pixel, in row order.
 * @param {number} width Width in pixels of the image to draw.
 * @param {number} height Height in pixels of the image to draw.
 * @returns {void}
 */
function visualize(data, width, height) {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < data.length; i++) {
    const sample = data[i];
    if (!Number.isFinite(sample)) continue;
    if (sample < min) min = sample;
    if (sample > max) max = sample;
  }

  // No finite sample at all: fall back to a degenerate [0, 0] range.
  if (min > max) {
    min = 0;
    max = 0;
  }

  // A zero range (flat input) divides by 1 instead, mapping every sample to 0.
  const scale = max - min || 1;
  const imageData = outputCtx.createImageData(width, height);
  const pixels = imageData.data;

  // Never index past the image even if `data` holds more samples than pixels.
  const pixelCount = Math.min(data.length, width * height);

  for (let i = 0; i < pixelCount; i++) {
    // Clamp so that +/-Infinity land on the ends of the scale; NaN stays NaN
    // and is stored as 0 by the clamped pixel array.
    const val = Math.min(1, Math.max(0, (data[i] - min) / scale));
    const pixelVal = Math.floor((1 - val) * 255);

    const offset = i * 4;
    pixels[offset] = pixelVal;
    pixels[offset + 1] = pixelVal;
    pixels[offset + 2] = pixelVal;
    pixels[offset + 3] = 255;
  }

  outputCtx.putImageData(imageData, 0, 0);
}
|
|
|
|
|
// Start loading the model as soon as the script runs. init() reports load
// failures itself; the handler only catches errors that escape it, so the
// promise is never left floating.
init().catch((err) => {
  console.error(err);
});