// onnx.js — Car_VS_Rest: webcam object detection demo (React + onnxruntime-web).
// (Header reconstructed from file-viewer page residue: "Nekshay's picture",
// "Update onnx.js", commit 1958a45, "raw / history blame", "5.1 kB" — that
// text was not valid JavaScript and has been converted into this comment.)
import React, { useRef, useState } from 'react';
import Webcam from 'react-webcam';
import * as ort from 'onnxruntime-web';
/**
 * React component: captures a frame from the webcam, runs it through an
 * ONNX object-detection model ('./model.onnx'), and lists detections
 * (class, score, box) with confidence > 0.5.
 */
function ObjectDetection() {
  // Detections from the most recent inference run.
  const [results, setResults] = useState([]);
  // True while a capture + inference pass is in flight (disables the button).
  const [loading, setLoading] = useState(false);
  const webcamRef = useRef(null);
  // Cache the InferenceSession so the model is fetched/compiled once,
  // not on every button click.
  const sessionRef = useRef(null);

  const runInference = async () => {
    if (!webcamRef.current) return;
    setLoading(true);
    try {
      // Capture image from webcam; getScreenshot() returns null until the
      // video stream is ready.
      const imageSrc = webcamRef.current.getScreenshot();
      if (imageSrc == null) return;
      // Load the ONNX model lazily, once.
      if (sessionRef.current == null) {
        sessionRef.current = await ort.InferenceSession.create('./model.onnx');
      }
      // Preprocess the image into the tensor the model expects.
      const inputTensor = await preprocessImage(imageSrc);
      // 'input' must match the model's declared input name.
      const feeds = { input: inputTensor };
      // Run inference.
      const output = await sessionRef.current.run(feeds);
      // Postprocess raw outputs into {box, score, class} records.
      setResults(postprocessOutput(output));
    } catch (error) {
      console.error('Error running inference:', error);
    } finally {
      // Always clear the busy flag, even on error or early return.
      setLoading(false);
    }
  };

  /**
   * Decode a data-URL screenshot, resize it to the model's input size, and
   * return an ort.Tensor in the layout the model expects.
   * @param {string} imageSrc - data URL from react-webcam.
   * @returns {Promise<ort.Tensor>}
   */
  const preprocessImage = async (imageSrc) => {
    const img = new Image();
    img.src = imageSrc;
    // Reject (instead of hanging forever) if the data URL fails to decode.
    await new Promise((resolve, reject) => {
      img.onload = resolve;
      img.onerror = () => reject(new Error('Failed to decode webcam image'));
    });
    const canvas = document.createElement('canvas');
    const context = canvas.getContext('2d');
    // Resize to model input size
    const modelInputWidth = 300; // Replace with your model's input width
    const modelInputHeight = 300; // Replace with your model's input height
    canvas.width = modelInputWidth;
    canvas.height = modelInputHeight;
    context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
    const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
    const numPixels = modelInputWidth * modelInputHeight;
    // Check the required data type
    const isUint8 = true; // Set to true if your model expects uint8, false for float32
    if (isUint8) {
      // BUG FIX: imageData.data is RGBA (4 bytes/pixel) but the tensor shape
      // declares 3 channels — strip the alpha channel so the data length
      // matches [1, H, W, 3].
      const rgbData = new Uint8Array(numPixels * 3);
      for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
        rgbData[j++] = imageData.data[i]; // R
        rgbData[j++] = imageData.data[i + 1]; // G
        rgbData[j++] = imageData.data[i + 2]; // B
        // alpha skipped
      }
      return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
    } else {
      // BUG FIX: shape [1, 3, H, W] is planar (NCHW) — write contiguous
      // R, G, B planes instead of interleaved RGB, normalized to [0, 1].
      const floatData = new Float32Array(numPixels * 3);
      for (let p = 0; p < numPixels; p++) {
        floatData[p] = imageData.data[p * 4] / 255; // R plane
        floatData[numPixels + p] = imageData.data[p * 4 + 1] / 255; // G plane
        floatData[2 * numPixels + p] = imageData.data[p * 4 + 2] / 255; // B plane
      }
      return new ort.Tensor('float32', floatData, [1, 3, modelInputHeight, modelInputWidth]);
    }
  };

  /**
   * Convert raw model outputs into an array of detections above the 0.5
   * confidence threshold.
   * NOTE(review): output names 'boxes'/'scores'/'classes' and the
   * 4-values-per-box layout are assumptions about this model — verify
   * against the exported graph.
   * @param {Object} output - result map from InferenceSession.run().
   * @returns {{box: TypedArray, score: number, class: number}[]}
   */
  const postprocessOutput = (output) => {
    const boxes = output['boxes'].data; // Replace 'boxes' with your model's output name
    const scores = output['scores'].data; // Replace 'scores' with your model's output name
    const classes = output['classes'].data; // Replace 'classes' with your model's output name
    const detections = [];
    for (let i = 0; i < scores.length; i++) {
      if (scores[i] > 0.5) { // Confidence threshold
        detections.push({
          box: boxes.slice(i * 4, i * 4 + 4),
          score: scores[i],
          class: classes[i],
        });
      }
    }
    return detections;
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection with Webcam'),
    React.createElement(Webcam, {
      audio: false,
      ref: webcamRef,
      screenshotFormat: 'image/jpeg',
      width: 300,
      height: 300,
    }),
    React.createElement(
      'button',
      { onClick: runInference, disabled: loading },
      loading ? 'Detecting...' : 'Capture & Detect'
    ),
    React.createElement(
      'div',
      null,
      React.createElement('h2', null, 'Results:'),
      React.createElement(
        'ul',
        null,
        results.map((result, index) =>
          React.createElement(
            'li',
            { key: index },
            `Class: ${result.class}, Score: ${result.score.toFixed(2)}, Box: ${result.box.join(', ')}`
          )
        )
      )
    )
  );
}
export default ObjectDetection;
// NOTE(review): this module-level preprocessImage appears to be an unused
// duplicate — ObjectDetection defines and uses its own inner preprocessImage
// (300x300), while this one targets 320x320. Unlike the inner uint8 path,
// this version correctly strips the alpha channel before building the
// [1, H, W, 3] uint8 tensor. Confirm no external caller imports it before
// removing; it is also dead weight after the `export default` above.
//
// Decode a data-URL image, resize to 320x320 on a canvas, convert RGBA->RGB,
// and return a uint8 ort.Tensor of shape [1, 320, 320, 3].
const preprocessImage = async (imageSrc) => {
const img = new Image();
img.src = imageSrc;
// Wait for the browser to decode the data URL before drawing it.
// NOTE(review): no onerror handler — this promise never settles if the
// image fails to decode.
await new Promise((resolve) => (img.onload = resolve));
const canvas = document.createElement('canvas');
const context = canvas.getContext('2d');
// Resize to model input size
const modelInputWidth = 320; // Replace with your model's input width
const modelInputHeight = 320; // Replace with your model's input height
canvas.width = modelInputWidth;
canvas.height = modelInputHeight;
context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
// Convert RGBA to RGB
const rgbData = new Uint8Array((imageData.data.length / 4) * 3); // 3 channels for RGB
for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
rgbData[j++] = imageData.data[i]; // R
rgbData[j++] = imageData.data[i + 1]; // G
rgbData[j++] = imageData.data[i + 2]; // B
// Skip A (alpha) channel
}
// Create a tensor with shape [1, 320, 320, 3]
return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
};