|
import React, { useRef, useState } from 'react'; |
|
import Webcam from 'react-webcam'; |
|
import * as ort from 'onnxruntime-web'; |
|
|
|
/**
 * Webcam object-detection component.
 *
 * Captures a frame from the webcam, runs it through an ONNX model
 * (`./model.onnx`, served from the public root), and lists detections
 * whose confidence exceeds 0.5.
 *
 * Assumptions to confirm against the actual model:
 *  - input name is "input", uint8 NHWC, 300x300x3;
 *  - outputs are named "boxes" (flat, 4 values per detection), "scores",
 *    and "classes", index-aligned with each other.
 */
function ObjectDetection() {
  const [results, setResults] = useState([]);
  const [loading, setLoading] = useState(false);
  const webcamRef = useRef(null);
  // Cache the inference session so the model is downloaded/initialized
  // once, not on every button click.
  const sessionRef = useRef(null);

  // Lazily create (and memoize) the ONNX inference session.
  const getSession = async () => {
    if (!sessionRef.current) {
      sessionRef.current = await ort.InferenceSession.create('./model.onnx');
    }
    return sessionRef.current;
  };

  // Capture a frame, run the model, and publish thresholded detections.
  const runInference = async () => {
    if (!webcamRef.current) return;
    setLoading(true);

    try {
      const imageSrc = webcamRef.current.getScreenshot();
      if (!imageSrc) {
        // getScreenshot() returns null before the stream is ready.
        throw new Error('Failed to capture a frame from the webcam');
      }

      const model = await getSession();
      const inputTensor = await preprocessImage(imageSrc);
      const feeds = { input: inputTensor };
      const output = await model.run(feeds);

      setResults(postprocessOutput(output));
    } catch (error) {
      console.error('Error running inference:', error);
    } finally {
      // Always re-enable the button, even if an error was thrown above.
      setLoading(false);
    }
  };

  /**
   * Decode a screenshot data-URL, resize it to the model input size,
   * and convert it to an ort.Tensor.
   *
   * @param {string} imageSrc - data URL produced by getScreenshot().
   * @returns {Promise<ort.Tensor>} uint8 NHWC [1, H, W, 3] tensor
   *   (or float32 NCHW [1, 3, H, W] if the float path is enabled).
   * @throws {Error} if the image cannot be decoded.
   */
  const preprocessImage = async (imageSrc) => {
    const img = new Image();
    // Reject (rather than hang forever) if the data URL fails to decode;
    // set src only after the handlers are attached.
    await new Promise((resolve, reject) => {
      img.onload = resolve;
      img.onerror = () => reject(new Error('Failed to decode screenshot'));
      img.src = imageSrc;
    });

    const modelInputWidth = 300;
    const modelInputHeight = 300;
    const canvas = document.createElement('canvas');
    canvas.width = modelInputWidth;
    canvas.height = modelInputHeight;
    const context = canvas.getContext('2d');
    context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);

    const { data } = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
    const pixelCount = modelInputWidth * modelInputHeight;

    // Model expects uint8 NHWC input — TODO confirm against model metadata.
    const isUint8 = true;

    if (isUint8) {
      // Canvas pixels are RGBA; the tensor shape declares 3 channels, so
      // the alpha channel must be stripped (passing imageData.data directly
      // would be a 4-bytes-per-pixel buffer behind a 3-channel shape).
      const rgbData = new Uint8Array(pixelCount * 3);
      for (let i = 0, j = 0; i < data.length; i += 4) {
        rgbData[j++] = data[i];
        rgbData[j++] = data[i + 1];
        rgbData[j++] = data[i + 2];
      }
      return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
    }

    // Float path: shape [1, 3, H, W] is NCHW, so the buffer must be laid
    // out as three contiguous channel planes, not interleaved RGB.
    const floatData = new Float32Array(pixelCount * 3);
    for (let p = 0; p < pixelCount; p++) {
      floatData[p] = data[p * 4] / 255; // R plane
      floatData[pixelCount + p] = data[p * 4 + 1] / 255; // G plane
      floatData[2 * pixelCount + p] = data[p * 4 + 2] / 255; // B plane
    }
    return new ort.Tensor('float32', floatData, [1, 3, modelInputHeight, modelInputWidth]);
  };

  /**
   * Convert raw model outputs into a list of detections above the
   * 0.5 confidence threshold.
   *
   * @param {Object} output - result of session.run(); expects "boxes",
   *   "scores", and "classes" tensors, index-aligned.
   * @returns {Array<{box: TypedArray, score: number, class: number}>}
   */
  const postprocessOutput = (output) => {
    const boxes = output['boxes'].data;
    const scores = output['scores'].data;
    const classes = output['classes'].data;

    const detections = [];
    for (let i = 0; i < scores.length; i++) {
      if (scores[i] > 0.5) {
        detections.push({
          box: boxes.slice(i * 4, i * 4 + 4), // 4 coords per detection
          score: scores[i],
          class: classes[i],
        });
      }
    }
    return detections;
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection with Webcam'),
    React.createElement(Webcam, {
      audio: false,
      ref: webcamRef,
      screenshotFormat: 'image/jpeg',
      width: 300,
      height: 300,
    }),
    React.createElement(
      'button',
      { onClick: runInference, disabled: loading },
      loading ? 'Detecting...' : 'Capture & Detect'
    ),
    React.createElement(
      'div',
      null,
      React.createElement('h2', null, 'Results:'),
      React.createElement(
        'ul',
        null,
        results.map((result, index) =>
          React.createElement(
            'li',
            { key: index },
            `Class: ${result.class}, Score: ${result.score.toFixed(2)}, Box: ${result.box.join(', ')}`
          )
        )
      )
    )
  );
}
|
|
|
export default ObjectDetection; |
|
|
|
|
|
|
|
|
|
// NOTE(review): this module-level helper appears to be an unused duplicate of
// the preprocessImage closure inside ObjectDetection (differing only in input
// size: 320 vs 300). Nothing in this file references it — consider removing it
// or consolidating with the component's copy.
/**
 * Decode an image data-URL, resize it to 320x320, strip the alpha channel,
 * and return a uint8 NHWC tensor of shape [1, 320, 320, 3].
 *
 * @param {string} imageSrc - image URL or data URL to decode.
 * @returns {Promise<ort.Tensor>} uint8 tensor [1, H, W, 3].
 * @throws {Error} if the image cannot be decoded.
 */
const preprocessImage = async (imageSrc) => {
  const img = new Image();
  // Attach both handlers before setting src so a synchronous failure is not
  // missed; reject on decode error instead of leaving the promise pending
  // forever (the original only wired up onload).
  await new Promise((resolve, reject) => {
    img.onload = resolve;
    img.onerror = () => reject(new Error('Failed to decode image'));
    img.src = imageSrc;
  });

  const modelInputWidth = 320;
  const modelInputHeight = 320;
  const canvas = document.createElement('canvas');
  canvas.width = modelInputWidth;
  canvas.height = modelInputHeight;
  const context = canvas.getContext('2d');
  context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);

  const { data } = context.getImageData(0, 0, modelInputWidth, modelInputHeight);

  // Canvas pixels are RGBA; drop alpha so the buffer matches the declared
  // 3-channel tensor shape.
  const rgbData = new Uint8Array((data.length / 4) * 3);
  for (let i = 0, j = 0; i < data.length; i += 4) {
    rgbData[j++] = data[i];
    rgbData[j++] = data[i + 1];
    rgbData[j++] = data[i + 2];
  }

  return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
};