Adding loading bars
This commit upgrades from the forked transformers repository to the latest release as of this commit (3.3.1), adds loading bars so that transcription progress is reported the same way as model-loading progress, reports the transcription speed-up factor (audio seconds transcribed per second of wall-clock time), and fixes vulnerabilities in a package dependency.
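
The speed-up factor shown in the UI is simply audio duration divided by generation time. A minimal sketch of the computation, matching the expression used in App.jsx below:

    // audioLength is the clip duration in seconds; time is generation time in ms.
    const speedup = (audioLength / (time / 1000)).toFixed(2); // e.g. "5.32x transcription!"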
whisper-speaker-diarization/package.json
CHANGED
@@ -10,11 +10,13 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "@
+    "@huggingface/transformers": "^3.3.1",
+    "prop-types": "^15.8.1",
     "react": "^18.3.1",
     "react-dom": "^18.3.1"
   },
   "devDependencies": {
+    "@rollup/plugin-commonjs": "^28.0.1",
     "@types/react": "^18.3.3",
     "@types/react-dom": "^18.3.0",
     "@vitejs/plugin-react": "^4.3.1",
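
A note on the new devDependency: @rollup/plugin-commonjs is presumably needed so the bundler can handle CommonJS modules pulled in by the upgraded dependencies. How it is wired up is not shown in this commit; a hypothetical sketch, assuming it is registered in the project's vite.config.js:

    // vite.config.js (hypothetical wiring, not part of this commit)
    import { defineConfig } from 'vite';
    import react from '@vitejs/plugin-react';
    import commonjs from '@rollup/plugin-commonjs';

    export default defineConfig({
      // Vite accepts Rollup plugins in the `plugins` array.
      plugins: [react(), commonjs()],
    });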
whisper-speaker-diarization/src/App.jsx
CHANGED
@@ -1,218 +1,257 @@
-import { useEffect, useState, useRef, useCallback } from 'react';
-
-import Progress from './components/Progress';
-import MediaInput from './components/MediaInput';
-import Transcript from './components/Transcript';
-import LanguageSelector from './components/LanguageSelector';
-
-
-async function hasWebGPU() {
-  if (!navigator.gpu) {
-    return false;
-  }
-  try {
-    const adapter = await navigator.gpu.requestAdapter();
-    return !!adapter;
-  } catch (e) {
-    return false;
-  }
-}
-
-function App() {
-
-  // Create a reference to the worker object.
-  const worker = useRef(null);
-
-  // Model loading and progress
-  const [status, setStatus] = useState(null);
-  const [loadingMessage, setLoadingMessage] = useState('');
-  const [progressItems, setProgressItems] = useState([]);
-
-  const mediaInputRef = useRef(null);
-  const [audio, setAudio] = useState(null);
-  const [language, setLanguage] = useState('en');
-
-  const [result, setResult] = useState(null);
-  const [time, setTime] = useState(null);
-  const [
-  …
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import MediaInput from './components/MediaInput';
+import Transcript from './components/Transcript';
+import LanguageSelector from './components/LanguageSelector';
+
+
+async function hasWebGPU() {
+  if (!navigator.gpu) {
+    return false;
+  }
+  try {
+    const adapter = await navigator.gpu.requestAdapter();
+    return !!adapter;
+  } catch (e) {
+    return false;
+  }
+}
+
+function App() {
+
+  // Create a reference to the worker object.
+  const worker = useRef(null);
+
+  // Model loading and progress
+  const [status, setStatus] = useState(null);
+  const [loadingMessage, setLoadingMessage] = useState('');
+  const [progressItems, setProgressItems] = useState([]);
+
+  const mediaInputRef = useRef(null);
+  const [audio, setAudio] = useState(null);
+  const [language, setLanguage] = useState('en');
+
+  const [result, setResult] = useState(null);
+  const [time, setTime] = useState(null);
+  const [audioLength, setAudioLength] = useState(null);
+  const [currentTime, setCurrentTime] = useState(0);
+
+  const [device, setDevice] = useState('webgpu'); // Try to use WebGPU first
+  const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
+  useEffect(() => {
+    hasWebGPU().then((b) => {
+      setModelSize(b ? 196 : 77);
+      setDevice(b ? 'webgpu' : 'wasm');
+    });
+  }, []);
+
+  // Create a callback function for messages from the worker thread.
+  const onMessageReceived = (e) => {
+    switch (e.data.status) {
+      case 'loading':
+        // Model loading started: show the overlay status message.
+        setStatus('loading');
+        setLoadingMessage(e.data.data);
+        break;
+
+      case 'initiate':
+        // A model file started loading: add a new progress item to the list.
+        setProgressItems(prev => [...prev, e.data]);
+        break;
+
+      case 'progress':
+        // Model file progress: update one of the progress items.
+        setProgressItems(
+          prev => prev.map(item => {
+            if (item.file === e.data.file) {
+              return { ...item, ...e.data }
+            }
+            return item;
+          })
+        );
+        break;
+
+      case 'done':
+        // Model file loaded: remove the progress item from the list.
+        setProgressItems(
+          prev => prev.filter(item => item.file !== e.data.file)
+        );
+        break;
+
+      case 'loaded':
+        // Pipeline ready: the worker is ready to accept messages.
+        setStatus('ready');
+        break;
+
+      case 'transcribe-progress': {
+        // Update progress for transcription/diarization
+        const { task, progress, total } = e.data.data;
+        setProgressItems(prev => {
+          const existingIndex = prev.findIndex(item => item.file === task);
+          if (existingIndex >= 0) {
+            return prev.map((item, i) =>
+              i === existingIndex ? { ...item, progress, total } : item
+            );
+          }
+          const newItem = { file: task, progress, total };
+          return [...prev, newItem];
+        });
+        break;
+      }
+
+      case 'complete':
+        setResult(e.data.result);
+        setTime(e.data.time);
+        setAudioLength(e.data.audio_length);
+        setStatus('ready');
+        break;
+    }
+  };
+
+  // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+  useEffect(() => {
+    if (!worker.current) {
+      // Create the worker if it does not yet exist.
+      worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+        type: 'module'
+      });
+    }
+
+    // Attach the callback function as an event listener.
+    worker.current.addEventListener('message', onMessageReceived);
+
+    // Define a cleanup function for when the component is unmounted.
+    return () => {
+      worker.current.removeEventListener('message', onMessageReceived);
+    };
+  }, []);
+
+  const handleClick = useCallback(() => {
+    setResult(null);
+    setTime(null);
+    if (status === null) {
+      setStatus('loading');
+      worker.current.postMessage({ type: 'load', data: { device } });
+    } else {
+      setStatus('running');
+      worker.current.postMessage({
+        type: 'run', data: { audio, language }
+      });
+    }
+  }, [status, audio, language, device]);
+
+  return (
+    <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
+
+      {(status === 'loading' || status === 'running') && (
+        <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
+          <div className="w-[500px]">
+            <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
+            {progressItems
+              .sort((a, b) => {
+                // Define the order: transcription -> segmentation -> diarization
+                const order = { 'transcription': 0, 'segmentation': 1, 'diarization': 2 };
+                return (order[a.file] ?? 3) - (order[b.file] ?? 3);
+              })
+              .map(({ file, progress, total }, i) => (
+                <Progress
+                  key={i}
+                  text={file === 'transcription' ? 'Converting speech to text' :
+                    file === 'segmentation' ? 'Detecting word timestamps' :
+                      file === 'diarization' ? 'Identifying speakers' :
+                        file}
+                  percentage={progress}
+                  total={total}
+                />
+              ))
+            }
+          </div>
+        </div>
+      )}
+      <div className="my-auto">
+        <div className="flex flex-col items-center mb-2 text-center">
+          <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
+          <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
+        </div>
+
+        <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
+          {
+            !audio && (
+              <p className="mb-2">
+                You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
+                two powerful models for word-level speech recognition across 100 different languages and for speaker segmentation, respectively.
+                Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
+                <br />
+                Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗 Transformers.js</a> and ONNX Runtime Web,
+                meaning no API calls are made to a server for inference. You can even disconnect from the internet after the models have loaded!
+              </p>
+            )
+          }
+
+          <div className="flex flex-col w-full m-3 max-w-[520px]">
+            <span className="text-sm mb-0.5">Input audio/video</span>
+            <MediaInput
+              ref={mediaInputRef}
+              className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
+              onInputChange={(audio) => {
+                setResult(null);
+                setAudio(audio);
+              }}
+              onTimeUpdate={(time) => setCurrentTime(time)}
+              onMessage={onMessageReceived}
+            />
+          </div>
+
+          <div className="relative w-full flex justify-center items-center">
+            <button
+              className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+              onClick={handleClick}
+              disabled={status === 'running' || (status !== null && audio === null)}
+            >
+              {status === null ? 'Load model' :
+                status === 'running'
+                  ? 'Running...'
+                  : 'Run model'
+              }
+            </button>
+
+            {status !== null &&
+              <div className='absolute right-0 bottom-0'>
+                <span className="text-xs">Language:</span>
+                <br />
+                <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
+              </div>
+            }
+          </div>
+
+          {
+            result && time && (
+              <>
+                <div className="w-full mt-4 border rounded-md">
+                  <Transcript
+                    className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
+                    transcript={result.transcript}
+                    segments={result.segments}
+                    currentTime={currentTime}
+                    setCurrentTime={(time) => {
+                      setCurrentTime(time);
+                      mediaInputRef.current.setMediaTime(time);
+                    }}
+                  />
+                </div>
+                <p className="text-sm text-end p-1">Generation time:{' '}
+                  <span className="font-semibold">{(time / 1000).toLocaleString()} s</span>
+                </p>
+                <p className="text-sm text-end p-1">
+                  <span className="font-semibold">{(audioLength / (time / 1000)).toFixed(2)}x transcription!</span>
+                </p>
+              </>
+            )
+          }
+        </div>
+      </div>
+    </div>
+  )
+}
+
+export default App
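
For reference, onMessageReceived drives the entire overlay from a single message schema. The shapes it consumes, derived from the switch statement above:

    // { status: 'loading',  data: string }                        -> overlay status text
    // { status: 'initiate', file, ... }                           -> add a progress bar
    // { status: 'progress', file, progress, total }               -> update a progress bar
    // { status: 'done',     file }                                -> remove a progress bar
    // { status: 'loaded' }                                        -> models ready to run
    // { status: 'transcribe-progress', data: { task, progress, total } }
    // { status: 'complete', result, time, audio_length }          -> render the transcript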
whisper-speaker-diarization/src/components/MediaInput.jsx
CHANGED
@@ -1,8 +1,8 @@
 import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
-
+import PropTypes from 'prop-types';
 const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
 
-const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
+const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, onMessage, ...props }, ref) => {
   // UI states
   const [dragging, setDragging] = useState(false);
   const fileInputRef = useRef(null);
@@ -89,7 +89,40 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
   const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
 
   try {
+    // Start audio decoding, reusing the worker's message shapes
+    onMessage({
+      data: {
+        status: 'loading',
+        data: 'Decoding audio buffer...'
+      }
+    });
+
+    onMessage({
+      data: {
+        status: 'initiate',
+        name: 'audio-decoder',
+        file: 'audio-buffer'
+      }
+    });
+
     const audioBuffer = await audioContext.decodeAudioData(buffer);
+
+    // Audio decoding complete: remove the progress item
+    onMessage({
+      data: {
+        status: 'done',
+        name: 'audio-decoder',
+        file: 'audio-buffer'
+      }
+    });
+
+    // Dismiss the loading overlay
+    onMessage({
+      data: {
+        status: 'loaded'
+      }
+    });
+
     let audio;
     if (audioBuffer.numberOfChannels === 2) {
       // Merge channels
@@ -145,8 +178,8 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
       onClick={handleClick}
       onDragOver={handleDragOver}
       onDrop={handleDrop}
-      onDragEnter={(
-      onDragLeave={(
+      onDragEnter={() => setDragging(true)}
+      onDragLeave={() => setDragging(false)}
     >
       <input
         type="file"
@@ -189,6 +222,13 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
     </div>
   );
 });
+
+MediaInput.propTypes = {
+  onInputChange: PropTypes.func.isRequired,
+  onTimeUpdate: PropTypes.func.isRequired,
+  onMessage: PropTypes.func.isRequired
+};
+
 MediaInput.displayName = 'MediaInput';
 
 export default MediaInput;
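
A design note: MediaInput never talks to the worker. It fabricates event-like objects in the worker's message schema and passes them to onMessage, so the same App-level handler that renders model-loading progress also renders audio-decoding progress. A sketch of the pattern in isolation, where `handler` stands in for App's onMessageReceived:

    // Any handler that understands the worker schema can consume these synthetic events.
    const handler = (e) => console.log(e.data.status, e.data.file ?? e.data.data ?? '');
    handler({ data: { status: 'initiate', name: 'audio-decoder', file: 'audio-buffer' } });
    handler({ data: { status: 'done', name: 'audio-decoder', file: 'audio-buffer' } });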
whisper-speaker-diarization/src/components/Progress.jsx
CHANGED
@@ -1,3 +1,5 @@
+import PropTypes from 'prop-types';
+
 function formatBytes(size) {
   const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
   return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
@@ -13,3 +15,9 @@ export default function Progress({ text, percentage, total }) {
     </div>
   );
 }
+
+Progress.propTypes = {
+  text: PropTypes.string.isRequired,
+  percentage: PropTypes.number,
+  total: PropTypes.number
+};
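
As a sanity check on formatBytes, which the progress bars use to print download totals: it picks a 1024-based unit and trims the value to at most two decimals.

    formatBytes(0);                 // "0B"
    formatBytes(1536);              // "1.5kB"
    formatBytes(196 * 1024 * 1024); // "196MB", the WebGPU model size shown in App.jsx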
whisper-speaker-diarization/src/worker.js
CHANGED
@@ -1,124 +1,272 @@
-
-import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@
-
-const PER_DEVICE_CONFIG = {
-  webgpu: {
-    dtype: {
-      encoder_model: 'fp32',
-      decoder_model_merged: 'q4',
-    },
-    device: 'webgpu',
-  },
-  wasm: {
-    dtype: 'q8',
-    device: 'wasm',
-  },
-};
-
-/**
- * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
- */
-class PipelineSingeton {
-  static asr_model_id = 'onnx-community/whisper-base_timestamped';
-  static asr_instance = null;
-
-  static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
-  static segmentation_instance = null;
-  static segmentation_processor = null;
-
-  static async getInstance(progress_callback = null, device = 'webgpu') {
-    this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
-      ...PER_DEVICE_CONFIG[device],
-      progress_callback,
-    });
-
-    this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
-      progress_callback,
-    });
-    this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
-      // NOTE: WebGPU is not currently supported for this model
-      // See https://github.com/microsoft/onnxruntime/issues/21386
-      device: 'wasm',
-      dtype: 'fp32',
-      progress_callback,
-    });
-
-    return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
-  }
-}
-
-async function load({ device }) {
-  …
-}
-
-…
+
+import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@huggingface/transformers';
+
+const PER_DEVICE_CONFIG = {
+  webgpu: {
+    dtype: {
+      encoder_model: 'fp32',
+      decoder_model_merged: 'q4',
+    },
+    device: 'webgpu',
+  },
+  wasm: {
+    dtype: 'q8',
+    device: 'wasm',
+  },
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class PipelineSingeton {
+  static asr_model_id = 'onnx-community/whisper-base_timestamped';
+  static asr_instance = null;
+
+  static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
+  static segmentation_instance = null;
+  static segmentation_processor = null;
+
+  static async getInstance(progress_callback = null, device = 'webgpu') {
+    this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
+      ...PER_DEVICE_CONFIG[device],
+      progress_callback,
+    });
+
+    this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
+      progress_callback,
+    });
+    this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
+      // NOTE: WebGPU is not currently supported for this model
+      // See https://github.com/microsoft/onnxruntime/issues/21386
+      device: 'wasm',
+      dtype: 'fp32',
+      progress_callback,
+    });
+
+    return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
+  }
+}
+
+async function load({ device }) {
+  try {
+    const message = {
+      status: 'loading',
+      data: `Loading models (${device})...`
+    };
+    self.postMessage(message);
+
+    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
+      // We also add a progress callback to the pipeline so that we can
+      // track model loading.
+      self.postMessage(x);
+    }, device);
+
+    if (device === 'webgpu') {
+      const warmupMessage = {
+        status: 'loading',
+        data: 'Compiling shaders and warming up model...'
+      };
+
+      self.postMessage(warmupMessage);
+
+      // Run the model once on dummy input to compile the shaders
+      await transcriber(new Float32Array(16_000), {
+        language: 'en',
+      });
+    }
+
+    self.postMessage({ status: 'loaded' });
+  } catch (error) {
+    console.error('Loading error:', error);
+    const errorMessage = {
+      status: 'error',
+      error: error.message || 'Failed to load models'
+    };
+    self.postMessage(errorMessage);
+  }
+}
+
+async function segment(processor, model, audio) {
+  try {
+    // Report start of segmentation
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    // Extract input features for the segmentation model
+    const inputs = await processor(audio);
+
+    // Report segmentation feature extraction progress
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 50,
+        total: audio.length
+      }
+    });
+
+    const { logits } = await model(inputs);
+
+    // Report segmentation completion
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 100,
+        total: audio.length
+      }
+    });
+
+    // Start diarization
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'diarization',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
+
+    // Attach labels and report diarization completion
+    for (const segment of segments) {
+      segment.label = model.config.id2label[segment.id];
+    }
+
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'diarization',
+        progress: 100,
+        total: audio.length
+      }
+    });
+
+    return segments;
+  } catch (error) {
+    console.error('Segmentation error:', error);
+    // Fall back to a single speaker spanning the whole clip
+    return [{
+      id: 0,
+      start: 0,
+      end: (audio.length / 480_000) * 30, // 480,000 samples = 30 s at 16 kHz
+      label: 'SPEAKER_00',
+      confidence: 1.0
+    }];
+  }
+}
+
+async function run({ audio, language }) {
+  try {
+    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
+
+    const audioLengthSeconds = (audio.length / 16000);
+
+    // Initialize transcription progress
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'transcription',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    const start = performance.now();
+    // Process the audio in fixed-size chunks so progress can be reported
+    const CHUNK_SIZE = 3 * 30 * 16000; // 90 seconds (3 x 30 s) at 16,000 samples/second
+    const numChunks = Math.ceil(audio.length / CHUNK_SIZE);
+    let transcriptResults = [];
+
+    for (let i = 0; i < numChunks; i++) {
+      const chunkStart = i * CHUNK_SIZE;
+      const chunkEnd = Math.min((i + 1) * CHUNK_SIZE, audio.length);
+      const chunk = audio.slice(chunkStart, chunkEnd);
+
+      // Process chunk
+      const chunkResult = await transcriber(chunk, {
+        language,
+        return_timestamps: 'word',
+        chunk_length_s: 30,
+      });
+      const progressMessage = {
+        status: 'transcribe-progress',
+        data: {
+          task: 'transcription',
+          progress: Math.round((i + 1) / numChunks * 100),
+          total: audio.length
+        }
+      };
+      self.postMessage(progressMessage);
+
+      // Adjust timestamps for this chunk
+      if (chunkResult.chunks) {
+        chunkResult.chunks.forEach(chunk => {
+          if (chunk.timestamp) {
+            chunk.timestamp[0] += chunkStart / 16000; // Convert samples to seconds
+            chunk.timestamp[1] += chunkStart / 16000;
+          }
+        });
+      }
+
+      transcriptResults.push(chunkResult);
+    }
+
+    // Combine results
+    const transcript = {
+      text: transcriptResults.map(r => r.text).join(''),
+      chunks: transcriptResults.flatMap(r => r.chunks || [])
+    };
+
+    // Run segmentation over the full audio
+    const segments = await segment(segmentation_processor, segmentation_model, audio);
+
+    // Ensure transcription shows as complete
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'transcription',
+        progress: 100,
+        total: audio.length
+      }
+    });
+    const end = performance.now();
+
+    const completeMessage = {
+      status: 'complete',
+      result: { transcript, segments },
+      audio_length: audioLengthSeconds,
+      time: end - start
+    };
+    self.postMessage(completeMessage);
+  } catch (error) {
+    console.error('Processing error:', error);
+    const errorMessage = {
+      status: 'error',
+      error: error.message || 'Failed to process audio'
+    };
+    console.log('Worker sending error:', errorMessage);
+    self.postMessage(errorMessage);
+  }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+  const { type, data } = e.data;
+
+  switch (type) {
+    case 'load':
+      load(data);
+      break;
+
+    case 'run':
+      run(data);
+      break;
+  }
+});
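
To close the loop: the main thread speaks this protocol with exactly two outbound message types, 'load' and 'run'. A minimal sketch of the round trip as App.jsx uses it:

    const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });
    worker.addEventListener('message', (e) => {
      if (e.data.status === 'complete') {
        const { result, time, audio_length } = e.data;
        console.log(`Transcribed ${audio_length}s of audio in ${(time / 1000).toFixed(2)}s`);
      }
    });
    worker.postMessage({ type: 'load', data: { device: 'webgpu' } });
    // Later, once 'loaded' arrives and audio (a 16 kHz Float32Array) is ready:
    // worker.postMessage({ type: 'run', data: { audio, language: 'en' } });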