nbpe97 committed on
Commit baeb019 · 1 Parent(s): 6076ee9

Adding loading bars

This commit upgrades from the forked transformers.js repository to the latest published release as of this commit (@huggingface/transformers 3.3.1), adds new loading bars so that transcription progress is reported the same way as model-loading progress, reports the transcription speedup factor, and fixes vulnerabilities in a package dependency.

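For reference, the new progress reporting reuses the message shape the UI already handles for model downloads, and the speedup factor is derived from the audio duration and the wall-clock generation time. A minimal sketch, distilled from the diff below (identifiers match the code; the progress value is a placeholder):

    // worker.js: report transcription/diarization progress to the main thread
    self.postMessage({
        status: 'transcribe-progress',
        data: { task: 'transcription', progress: 42, total: audio.length },
    });

    // App.jsx: speedup factor = seconds of audio per second of processing time
    const speedup = audioLength / (time / 1000); // displayed as e.g. "3.20x"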
whisper-speaker-diarization/package.json CHANGED
@@ -10,11 +10,13 @@
10
  "preview": "vite preview"
11
  },
12
  "dependencies": {
13
- "@xenova/transformers": "github:xenova/transformers.js#v3",
 
14
  "react": "^18.3.1",
15
  "react-dom": "^18.3.1"
16
  },
17
  "devDependencies": {
 
18
  "@types/react": "^18.3.3",
19
  "@types/react-dom": "^18.3.0",
20
  "@vitejs/plugin-react": "^4.3.1",
 
10
  "preview": "vite preview"
11
  },
12
  "dependencies": {
13
+ "@huggingface/transformers": "^3.3.1",
14
+ "prop-types": "^15.8.1",
15
  "react": "^18.3.1",
16
  "react-dom": "^18.3.1"
17
  },
18
  "devDependencies": {
19
+ "@rollup/plugin-commonjs": "^28.0.1",
20
  "@types/react": "^18.3.3",
21
  "@types/react-dom": "^18.3.0",
22
  "@vitejs/plugin-react": "^4.3.1",
whisper-speaker-diarization/src/App.jsx CHANGED
@@ -1,218 +1,257 @@
1
- import { useEffect, useState, useRef, useCallback } from 'react';
2
-
3
- import Progress from './components/Progress';
4
- import MediaInput from './components/MediaInput';
5
- import Transcript from './components/Transcript';
6
- import LanguageSelector from './components/LanguageSelector';
7
-
8
-
9
- async function hasWebGPU() {
10
- if (!navigator.gpu) {
11
- return false;
12
- }
13
- try {
14
- const adapter = await navigator.gpu.requestAdapter();
15
- return !!adapter;
16
- } catch (e) {
17
- return false;
18
- }
19
- }
20
-
21
- function App() {
22
-
23
- // Create a reference to the worker object.
24
- const worker = useRef(null);
25
-
26
- // Model loading and progress
27
- const [status, setStatus] = useState(null);
28
- const [loadingMessage, setLoadingMessage] = useState('');
29
- const [progressItems, setProgressItems] = useState([]);
30
-
31
- const mediaInputRef = useRef(null);
32
- const [audio, setAudio] = useState(null);
33
- const [language, setLanguage] = useState('en');
34
-
35
- const [result, setResult] = useState(null);
36
- const [time, setTime] = useState(null);
37
- const [currentTime, setCurrentTime] = useState(0);
38
-
39
- const [device, setDevice] = useState('webgpu'); // Try use WebGPU first
40
- const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
41
- useEffect(() => {
42
- hasWebGPU().then((b) => {
43
- setModelSize(b ? 196 : 77);
44
- setDevice(b ? 'webgpu' : 'wasm');
45
- });
46
- }, []);
47
-
48
- // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
49
- useEffect(() => {
50
- if (!worker.current) {
51
- // Create the worker if it does not yet exist.
52
- worker.current = new Worker(new URL('./worker.js', import.meta.url), {
53
- type: 'module'
54
- });
55
- }
56
-
57
- // Create a callback function for messages from the worker thread.
58
- const onMessageReceived = (e) => {
59
- switch (e.data.status) {
60
- case 'loading':
61
- // Model file start load: add a new progress item to the list.
62
- setStatus('loading');
63
- setLoadingMessage(e.data.data);
64
- break;
65
-
66
- case 'initiate':
67
- setProgressItems(prev => [...prev, e.data]);
68
- break;
69
-
70
- case 'progress':
71
- // Model file progress: update one of the progress items.
72
- setProgressItems(
73
- prev => prev.map(item => {
74
- if (item.file === e.data.file) {
75
- return { ...item, ...e.data }
76
- }
77
- return item;
78
- })
79
- );
80
- break;
81
-
82
- case 'done':
83
- // Model file loaded: remove the progress item from the list.
84
- setProgressItems(
85
- prev => prev.filter(item => item.file !== e.data.file)
86
- );
87
- break;
88
-
89
- case 'loaded':
90
- // Pipeline ready: the worker is ready to accept messages.
91
- setStatus('ready');
92
- break;
93
-
94
- case 'complete':
95
- setResult(e.data.result);
96
- setTime(e.data.time);
97
- setStatus('ready');
98
- break;
99
- }
100
- };
101
-
102
- // Attach the callback function as an event listener.
103
- worker.current.addEventListener('message', onMessageReceived);
104
-
105
- // Define a cleanup function for when the component is unmounted.
106
- return () => {
107
- worker.current.removeEventListener('message', onMessageReceived);
108
- };
109
- }, []);
110
-
111
- const handleClick = useCallback(() => {
112
- setResult(null);
113
- setTime(null);
114
- if (status === null) {
115
- setStatus('loading');
116
- worker.current.postMessage({ type: 'load', data: { device } });
117
- } else {
118
- setStatus('running');
119
- worker.current.postMessage({
120
- type: 'run', data: { audio, language }
121
- });
122
- }
123
- }, [status, audio, language, device]);
124
-
125
- return (
126
- <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
127
-
128
- {status === 'loading' && (
129
- <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
130
- <div className="w-[500px]">
131
- <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
132
- {progressItems.map(({ file, progress, total }, i) => (
133
- <Progress key={i} text={file} percentage={progress} total={total} />
134
- ))}
135
- </div>
136
- </div>
137
- )}
138
- <div className="my-auto">
139
- <div className="flex flex-col items-center mb-2 text-center">
140
- <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
141
- <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
142
- </div>
143
-
144
- <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
145
- {
146
- !audio && (
147
- <p className="mb-2">
148
- You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
149
- two powerful speech recognition models for generating word-level timestamps across 100 different languages and speaker segmentation, respectively.
150
- Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
151
- <br />
152
- Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗&nbsp;Transformers.js</a> and ONNX Runtime Web,
153
- meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
154
- </p>
155
- )
156
- }
157
-
158
- <div className="flex flex-col w-full m-3 max-w-[520px]">
159
- <span className="text-sm mb-0.5">Input audio/video</span>
160
- <MediaInput
161
- ref={mediaInputRef}
162
- className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
163
- onInputChange={(audio) => {
164
- setResult(null);
165
- setAudio(audio);
166
- }}
167
- onTimeUpdate={(time) => setCurrentTime(time)}
168
- />
169
- </div>
170
-
171
- <div className="relative w-full flex justify-center items-center">
172
- <button
173
- className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
174
- onClick={handleClick}
175
- disabled={status === 'running' || (status !== null && audio === null)}
176
- >
177
- {status === null ? 'Load model' :
178
- status === 'running'
179
- ? 'Running...'
180
- : 'Run model'
181
- }
182
- </button>
183
-
184
- {status !== null &&
185
- <div className='absolute right-0 bottom-0'>
186
- <span className="text-xs">Language:</span>
187
- <br />
188
- <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
189
- </div>
190
- }
191
- </div>
192
-
193
- {
194
- result && time && (
195
- <>
196
- <div className="w-full mt-4 border rounded-md">
197
- <Transcript
198
- className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
199
- transcript={result.transcript}
200
- segments={result.segments}
201
- currentTime={currentTime}
202
- setCurrentTime={(time) => {
203
- setCurrentTime(time);
204
- mediaInputRef.current.setMediaTime(time);
205
- }}
206
- />
207
- </div>
208
- <p className="text-sm text-gray-600 text-end p-1">Generation time: <span className="text-gray-800 font-semibold">{time.toFixed(2)}ms</span></p>
209
- </>
210
- )
211
- }
212
- </div>
213
- </div>
214
- </div >
215
- )
216
- }
217
-
218
- export default App
 
1
+ import { useEffect, useState, useRef, useCallback } from 'react';
2
+
3
+ import Progress from './components/Progress';
4
+ import MediaInput from './components/MediaInput';
5
+ import Transcript from './components/Transcript';
6
+ import LanguageSelector from './components/LanguageSelector';
7
+
8
+
9
+ async function hasWebGPU() {
10
+ if (!navigator.gpu) {
11
+ return false;
12
+ }
13
+ try {
14
+ const adapter = await navigator.gpu.requestAdapter();
15
+ return !!adapter;
16
+ } catch (e) {
17
+ return false;
18
+ }
19
+ }
20
+
21
+ function App() {
22
+
23
+ // Create a reference to the worker object.
24
+ const worker = useRef(null);
25
+
26
+ // Model loading and progress
27
+ const [status, setStatus] = useState(null);
28
+ const [loadingMessage, setLoadingMessage] = useState('');
29
+ const [progressItems, setProgressItems] = useState([]);
30
+
31
+ const mediaInputRef = useRef(null);
32
+ const [audio, setAudio] = useState(null);
33
+ const [language, setLanguage] = useState('en');
34
+
35
+ const [result, setResult] = useState(null);
36
+ const [time, setTime] = useState(null);
37
+ const [audioLength, setAudioLength] = useState(null);
38
+ const [currentTime, setCurrentTime] = useState(0);
39
+
40
+ const [device, setDevice] = useState('webgpu'); // Try to use WebGPU first
41
+ const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
42
+ useEffect(() => {
43
+ hasWebGPU().then((b) => {
44
+ setModelSize(b ? 196 : 77);
45
+ setDevice(b ? 'webgpu' : 'wasm');
46
+ });
47
+ }, []);
48
+
49
+ // Create a callback function for messages from the worker thread.
50
+ const onMessageReceived = (e) => {
51
+ switch (e.data.status) {
52
+ case 'loading':
53
+ // Model file start load: add a new progress item to the list.
54
+ setStatus('loading');
55
+ setLoadingMessage(e.data.data);
56
+ break;
57
+
58
+ case 'initiate':
59
+ setProgressItems(prev => [...prev, e.data]);
60
+ break;
61
+
62
+ case 'progress':
63
+ // Model file progress: update one of the progress items.
64
+ setProgressItems(
65
+ prev => prev.map(item => {
66
+ if (item.file === e.data.file) {
67
+ return { ...item, ...e.data }
68
+ }
69
+ return item;
70
+ })
71
+ );
72
+ break;
73
+
74
+ case 'done':
75
+ // Model file loaded: remove the progress item from the list.
76
+ setProgressItems(
77
+ prev => prev.filter(item => item.file !== e.data.file)
78
+ );
79
+ break;
80
+
81
+ case 'loaded':
82
+ // Pipeline ready: the worker is ready to accept messages.
83
+ setStatus('ready');
84
+ break;
85
+
86
+ case 'transcribe-progress': {
87
+ // Update progress for transcription/diarization
88
+ const { task, progress, total } = e.data.data;
89
+ setProgressItems(prev => {
90
+ const existingIndex = prev.findIndex(item => item.file === task);
91
+ if (existingIndex >= 0) {
92
+ return prev.map((item, i) =>
93
+ i === existingIndex ? { ...item, progress, total } : item
94
+ );
95
+ }
96
+ const newItem = { file: task, progress, total };
97
+ return [...prev, newItem];
98
+ });
99
+ break;
100
+ }
101
+
102
+ case 'complete':
103
+ setResult(e.data.result);
104
+ setTime(e.data.time);
105
+ setAudioLength(e.data.audio_length);
106
+ setStatus('ready');
107
+ break;
108
+ }
109
+ };
110
+
111
+ // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
112
+ useEffect(() => {
113
+ if (!worker.current) {
114
+ // Create the worker if it does not yet exist.
115
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
116
+ type: 'module'
117
+ });
118
+ }
119
+
120
+ // Attach the callback function as an event listener.
121
+ worker.current.addEventListener('message', onMessageReceived);
122
+
123
+ // Define a cleanup function for when the component is unmounted.
124
+ return () => {
125
+ worker.current.removeEventListener('message', onMessageReceived);
126
+ };
127
+ }, []);
128
+
129
+ const handleClick = useCallback(() => {
130
+ setResult(null);
131
+ setTime(null);
132
+ if (status === null) {
133
+ setStatus('loading');
134
+ worker.current.postMessage({ type: 'load', data: { device } });
135
+ } else {
136
+ setStatus('running');
137
+ worker.current.postMessage({
138
+ type: 'run', data: { audio, language }
139
+ });
140
+ }
141
+ }, [status, audio, language, device]);
142
+
143
+ return (
144
+ <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
145
+
146
+ {(status === 'loading' || status === 'running') && (
147
+ <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
148
+ <div className="w-[500px]">
149
+ <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
150
+ {[...progressItems]
151
+ .sort((a, b) => {
152
+ // Define the order: transcription -> segmentation -> diarization
153
+ const order = { 'transcription': 0, 'segmentation': 1, 'diarization': 2 };
154
+ return (order[a.file] ?? 3) - (order[b.file] ?? 3);
155
+ })
156
+ .map(({ file, progress, total }, i) => (
157
+ <Progress
158
+ key={i}
159
+ text={file === 'transcription' ? 'Converting speech to text' :
160
+ file === 'segmentation' ? 'Detecting word timestamps' :
161
+ file === 'diarization' ? 'Identifying speakers' :
162
+ file}
163
+ percentage={progress}
164
+ total={total}
165
+ />
166
+ ))
167
+ }
168
+ </div>
169
+ </div>
170
+ )}
171
+ <div className="my-auto">
172
+ <div className="flex flex-col items-center mb-2 text-center">
173
+ <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
174
+ <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
175
+ </div>
176
+
177
+ <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
178
+ {
179
+ !audio && (
180
+ <p className="mb-2">
181
+ You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
182
+ two powerful models used, respectively, for speech recognition with word-level timestamps across 100 different languages and for speaker segmentation.
183
+ Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
184
+ <br />
185
+ Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗&nbsp;Transformers.js</a> and ONNX Runtime Web,
186
+ meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
187
+ </p>
188
+ )
189
+ }
190
+
191
+ <div className="flex flex-col w-full m-3 max-w-[520px]">
192
+ <span className="text-sm mb-0.5">Input audio/video</span>
193
+ <MediaInput
194
+ ref={mediaInputRef}
195
+ className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
196
+ onInputChange={(audio) => {
197
+ setResult(null);
198
+ setAudio(audio);
199
+ }}
200
+ onTimeUpdate={(time) => setCurrentTime(time)}
201
+ onMessage={onMessageReceived}
202
+ />
203
+ </div>
204
+
205
+ <div className="relative w-full flex justify-center items-center">
206
+ <button
207
+ className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
208
+ onClick={handleClick}
209
+ disabled={status === 'running' || (status !== null && audio === null)}
210
+ >
211
+ {status === null ? 'Load model' :
212
+ status === 'running'
213
+ ? 'Running...'
214
+ : 'Run model'
215
+ }
216
+ </button>
217
+
218
+ {status !== null &&
219
+ <div className='absolute right-0 bottom-0'>
220
+ <span className="text-xs">Language:</span>
221
+ <br />
222
+ <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
223
+ </div>
224
+ }
225
+ </div>
226
+
227
+ {
228
+ result && time && (
229
+ <>
230
+ <div className="w-full mt-4 border rounded-md">
231
+ <Transcript
232
+ className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
233
+ transcript={result.transcript}
234
+ segments={result.segments}
235
+ currentTime={currentTime}
236
+ setCurrentTime={(time) => {
237
+ setCurrentTime(time);
238
+ mediaInputRef.current.setMediaTime(time);
239
+ }}
240
+ />
241
+ </div>
242
+ <p className="text-sm text-end p-1">Generation time:
243
+ <span className="font-semibold">{(time / 1000).toLocaleString()} s</span>
244
+ </p>
245
+ <p className="text-sm text-end p-1">
246
+ <span className="font-semibold">{(audioLength / (time / 1000)).toFixed(2)}x transcription!</span>
247
+ </p>
248
+ </>
249
+ )
250
+ }
251
+ </div>
252
+ </div>
253
+ </div >
254
+ )
255
+ }
256
+
257
+ export default App
whisper-speaker-diarization/src/components/MediaInput.jsx CHANGED
@@ -1,8 +1,8 @@
1
  import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
2
-
3
  const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
4
 
5
- const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
6
  // UI states
7
  const [dragging, setDragging] = useState(false);
8
  const fileInputRef = useRef(null);
@@ -89,7 +89,40 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
89
  const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
90
 
91
  try {
 
92
  const audioBuffer = await audioContext.decodeAudioData(buffer);
 
93
  let audio;
94
  if (audioBuffer.numberOfChannels === 2) {
95
  // Merge channels
@@ -145,8 +178,8 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
145
  onClick={handleClick}
146
  onDragOver={handleDragOver}
147
  onDrop={handleDrop}
148
- onDragEnter={(e) => setDragging(true)}
149
- onDragLeave={(e) => setDragging(false)}
150
  >
151
  <input
152
  type="file"
@@ -189,6 +222,13 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
189
  </div>
190
  );
191
  });
 
192
  MediaInput.displayName = 'MediaInput';
193
 
194
  export default MediaInput;
 
1
  import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
2
+ import PropTypes from 'prop-types';
3
  const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
4
 
5
+ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, onMessage, ...props }, ref) => {
6
  // UI states
7
  const [dragging, setDragging] = useState(false);
8
  const fileInputRef = useRef(null);
 
89
  const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
90
 
91
  try {
92
+ // Start audio decoding
93
+ onMessage({
94
+ data: {
95
+ status: 'loading',
96
+ data: 'Decoding audio buffer...'
97
+ }
98
+ });
99
+
100
+ onMessage({
101
+ data: {
102
+ status: 'initiate',
103
+ name: 'audio-decoder',
104
+ file: 'audio-buffer'
105
+ }
106
+ });
107
+
108
  const audioBuffer = await audioContext.decodeAudioData(buffer);
109
+
110
+ // Audio decoding complete
111
+ onMessage({
112
+ data: {
113
+ status: 'done',
114
+ name: 'audio-decoder',
115
+ file: 'audio-buffer'
116
+ }
117
+ });
118
+
119
+ // Signal that decoding is finished so the UI can dismiss the loading overlay
120
+ onMessage({
121
+ data: {
122
+ status: 'loaded'
123
+ }
124
+ });
125
+
126
  let audio;
127
  if (audioBuffer.numberOfChannels === 2) {
128
  // Merge channels
 
178
  onClick={handleClick}
179
  onDragOver={handleDragOver}
180
  onDrop={handleDrop}
181
+ onDragEnter={() => setDragging(true)}
182
+ onDragLeave={() => setDragging(false)}
183
  >
184
  <input
185
  type="file"
 
222
  </div>
223
  );
224
  });
225
+
226
+ MediaInput.propTypes = {
227
+ onInputChange: PropTypes.func.isRequired,
228
+ onTimeUpdate: PropTypes.func.isRequired,
229
+ onMessage: PropTypes.func.isRequired
230
+ };
231
+
232
  MediaInput.displayName = 'MediaInput';
233
 
234
  export default MediaInput;
whisper-speaker-diarization/src/components/Progress.jsx CHANGED
@@ -1,3 +1,5 @@
1
  function formatBytes(size) {
2
  const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
3
  return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
@@ -13,3 +15,9 @@ export default function Progress({ text, percentage, total }) {
13
  </div>
14
  );
15
  }
 
1
+ import PropTypes from 'prop-types';
2
+
3
  function formatBytes(size) {
4
  const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
5
  return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
 
15
  </div>
16
  );
17
  }
18
+
19
+ Progress.propTypes = {
20
+ text: PropTypes.string.isRequired,
21
+ percentage: PropTypes.number,
22
+ total: PropTypes.number
23
+ };
whisper-speaker-diarization/src/worker.js CHANGED
@@ -1,124 +1,272 @@
1
-
2
- import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@xenova/transformers';
3
-
4
- const PER_DEVICE_CONFIG = {
5
- webgpu: {
6
- dtype: {
7
- encoder_model: 'fp32',
8
- decoder_model_merged: 'q4',
9
- },
10
- device: 'webgpu',
11
- },
12
- wasm: {
13
- dtype: 'q8',
14
- device: 'wasm',
15
- },
16
- };
17
-
18
- /**
19
- * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
20
- */
21
- class PipelineSingeton {
22
- static asr_model_id = 'onnx-community/whisper-base_timestamped';
23
- static asr_instance = null;
24
-
25
- static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
26
- static segmentation_instance = null;
27
- static segmentation_processor = null;
28
-
29
- static async getInstance(progress_callback = null, device = 'webgpu') {
30
- this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
31
- ...PER_DEVICE_CONFIG[device],
32
- progress_callback,
33
- });
34
-
35
- this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
36
- progress_callback,
37
- });
38
- this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
39
- // NOTE: WebGPU is not currently supported for this model
40
- // See https://github.com/microsoft/onnxruntime/issues/21386
41
- device: 'wasm',
42
- dtype: 'fp32',
43
- progress_callback,
44
- });
45
-
46
- return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
47
- }
48
- }
49
-
50
- async function load({ device }) {
51
- self.postMessage({
52
- status: 'loading',
53
- data: `Loading models (${device})...`
54
- });
55
-
56
- // Load the pipeline and save it for future use.
57
- const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
58
- // We also add a progress callback to the pipeline so that we can
59
- // track model loading.
60
- self.postMessage(x);
61
- }, device);
62
-
63
- if (device === 'webgpu') {
64
- self.postMessage({
65
- status: 'loading',
66
- data: 'Compiling shaders and warming up model...'
67
- });
68
-
69
- await transcriber(new Float32Array(16_000), {
70
- language: 'en',
71
- });
72
- }
73
-
74
- self.postMessage({ status: 'loaded' });
75
- }
76
-
77
- async function segment(processor, model, audio) {
78
- const inputs = await processor(audio);
79
- const { logits } = await model(inputs);
80
- const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
81
-
82
- // Attach labels
83
- for (const segment of segments) {
84
- segment.label = model.config.id2label[segment.id];
85
- }
86
-
87
- return segments;
88
- }
89
-
90
- async function run({ audio, language }) {
91
- const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
92
-
93
- const start = performance.now();
94
-
95
- // Run transcription and segmentation in parallel
96
- const [transcript, segments] = await Promise.all([
97
- transcriber(audio, {
98
- language,
99
- return_timestamps: 'word',
100
- chunk_length_s: 30,
101
- }),
102
- segment(segmentation_processor, segmentation_model, audio)
103
- ]);
104
- console.table(segments, ['start', 'end', 'id', 'label', 'confidence']);
105
-
106
- const end = performance.now();
107
-
108
- self.postMessage({ status: 'complete', result: { transcript, segments }, time: end - start });
109
- }
110
-
111
- // Listen for messages from the main thread
112
- self.addEventListener('message', async (e) => {
113
- const { type, data } = e.data;
114
-
115
- switch (type) {
116
- case 'load':
117
- load(data);
118
- break;
119
-
120
- case 'run':
121
- run(data);
122
- break;
123
- }
124
- });
 
1
+
2
+ import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@huggingface/transformers';
3
+
4
+ const PER_DEVICE_CONFIG = {
5
+ webgpu: {
6
+ dtype: {
7
+ encoder_model: 'fp32',
8
+ decoder_model_merged: 'q4',
9
+ },
10
+ device: 'webgpu',
11
+ },
12
+ wasm: {
13
+ dtype: 'q8',
14
+ device: 'wasm',
15
+ },
16
+ };
17
+
18
+ /**
19
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
20
+ */
21
+ class PipelineSingeton {
22
+ static asr_model_id = 'onnx-community/whisper-base_timestamped';
23
+ static asr_instance = null;
24
+
25
+ static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
26
+ static segmentation_instance = null;
27
+ static segmentation_processor = null;
28
+
29
+ static async getInstance(progress_callback = null, device = 'webgpu') {
30
+ this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
31
+ ...PER_DEVICE_CONFIG[device],
32
+ progress_callback,
33
+ });
34
+
35
+ this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
36
+ progress_callback,
37
+ });
38
+ this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
39
+ // NOTE: WebGPU is not currently supported for this model
40
+ // See https://github.com/microsoft/onnxruntime/issues/21386
41
+ device: 'wasm',
42
+ dtype: 'fp32',
43
+ progress_callback,
44
+ });
45
+
46
+ return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
47
+ }
48
+ }
49
+
50
+ async function load({ device }) {
51
+ try {
52
+ const message = {
53
+ status: 'loading',
54
+ data: `Loading models (${device})...`
55
+ };
56
+ self.postMessage(message);
57
+
58
+ const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
59
+ // We also add a progress callback to the pipeline so that we can
60
+ // track model loading.
61
+ self.postMessage(x);
62
+ }, device);
63
+
64
+ if (device === 'webgpu') {
65
+ const warmupMessage = {
66
+ status: 'loading',
67
+ data: 'Compiling shaders and warming up model...'
68
+ };
69
+
70
+ self.postMessage(warmupMessage);
71
+
72
+ await transcriber(new Float32Array(16_000), {
73
+ language: 'en',
74
+ });
75
+ }
76
+
77
+ self.postMessage({ status: 'loaded' });
78
+ } catch (error) {
79
+ console.error('Loading error:', error);
80
+ const errorMessage = {
81
+ status: 'error',
82
+ error: error.message || 'Failed to load models'
83
+ };
84
+ self.postMessage(errorMessage);
85
+ }
86
+ }
87
+
88
+ async function segment(processor, model, audio) {
89
+ try {
90
+ // Report start of segmentation
91
+ self.postMessage({
92
+ status: 'transcribe-progress',
93
+ data: {
94
+ task: 'segmentation',
95
+ progress: 0,
96
+ total: audio.length
97
+ }
98
+ });
99
+
100
+ // Extract segmentation features for the full audio clip
101
+ const inputs = await processor(audio);
102
+
103
+ // Report segmentation feature extraction progress
104
+ self.postMessage({
105
+ status: 'transcribe-progress',
106
+ data: {
107
+ task: 'segmentation',
108
+ progress: 50,
109
+ total: audio.length
110
+ }
111
+ });
112
+
113
+ const { logits } = await model(inputs);
114
+
115
+ // Report segmentation completion
116
+ self.postMessage({
117
+ status: 'transcribe-progress',
118
+ data: {
119
+ task: 'segmentation',
120
+ progress: 100,
121
+ total: audio.length
122
+ }
123
+ });
124
+
125
+ // Start diarization
126
+ self.postMessage({
127
+ status: 'transcribe-progress',
128
+ data: {
129
+ task: 'diarization',
130
+ progress: 0,
131
+ total: audio.length
132
+ }
133
+ });
134
+
135
+ const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
136
+
137
+ // Attach labels and report diarization completion
138
+ for (const segment of segments) {
139
+ segment.label = model.config.id2label[segment.id];
140
+ }
141
+
142
+ self.postMessage({
143
+ status: 'transcribe-progress',
144
+ data: {
145
+ task: 'diarization',
146
+ progress: 100,
147
+ total: audio.length
148
+ }
149
+ });
150
+
151
+ return segments;
152
+ } catch (error) {
153
+ console.error('Segmentation error:', error);
154
+ return [{
155
+ id: 0,
156
+ start: 0,
157
+ end: audio.length / 16000, // fall back to one segment spanning the whole clip (duration in seconds)
158
+ label: 'SPEAKER_00',
159
+ confidence: 1.0
160
+ }];
161
+ }
162
+ }
163
+
164
+ async function run({ audio, language }) {
165
+ try {
166
+ const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
167
+
168
+ const audioLengthSeconds = (audio.length / 16000);
169
+
170
+ // Initialize transcription progress
171
+ self.postMessage({
172
+ status: 'transcribe-progress',
173
+ data: {
174
+ task: 'transcription',
175
+ progress: 0,
176
+ total: audio.length
177
+ }
178
+ });
179
+
180
+ const start = performance.now();
181
+ // Split the audio into fixed-size chunks so transcription progress can be reported per chunk
182
+ const CHUNK_SIZE = 3 * 30 * 16000; // three 30-second windows at 16,000 samples/second (90 s per chunk)
183
+ const numChunks = Math.ceil(audio.length / CHUNK_SIZE);
184
+ let transcriptResults = [];
185
+
186
+ for (let i = 0; i < numChunks; i++) {
187
+ const chunkStart = i * CHUNK_SIZE;
188
+ const chunkEnd = Math.min((i + 1) * CHUNK_SIZE, audio.length);
189
+ const chunk = audio.slice(chunkStart, chunkEnd);
190
+
191
+ // Process chunk
192
+ const chunkResult = await transcriber(chunk, {
193
+ language,
194
+ return_timestamps: 'word',
195
+ chunk_length_s: 30,
196
+ });
197
+ const progressMessage = {
198
+ status: 'transcribe-progress',
199
+ data: {
200
+ task: 'transcription',
201
+ progress: Math.round((i+1) / numChunks * 100),
202
+ total: audio.length
203
+ }
204
+ };
205
+ self.postMessage(progressMessage);
206
+
207
+
208
+ // Adjust timestamps for this chunk
209
+ if (chunkResult.chunks) {
210
+ chunkResult.chunks.forEach(word => {
211
+ if (word.timestamp) {
212
+ word.timestamp[0] += chunkStart / 16000; // offset word timestamps by the chunk's start time (in seconds)
213
+ word.timestamp[1] += chunkStart / 16000;
214
+ }
215
+ });
216
+ }
217
+
218
+ transcriptResults.push(chunkResult);
219
+ }
220
+
221
+ // Combine results
222
+ const transcript = {
223
+ text: transcriptResults.map(r => r.text).join(''),
224
+ chunks: transcriptResults.flatMap(r => r.chunks || [])
225
+ };
226
+
227
+ // Run speaker segmentation over the full audio once transcription has finished
228
+ const segments = await segment(segmentation_processor, segmentation_model, audio);
229
+
230
+ // Ensure transcription shows as complete
231
+ self.postMessage({
232
+ status: 'transcribe-progress',
233
+ data: {
234
+ task: 'transcription',
235
+ progress: 100,
236
+ total: audio.length
237
+ }
238
+ });
239
+ const end = performance.now();
240
+
241
+ const completeMessage = {
242
+ status: 'complete',
243
+ result: { transcript, segments },
244
+ audio_length: audioLengthSeconds,
245
+ time: end - start
246
+ };
247
+ self.postMessage(completeMessage);
248
+ } catch (error) {
249
+ console.error('Processing error:', error);
250
+ const errorMessage = {
251
+ status: 'error',
252
+ error: error.message || 'Failed to process audio'
253
+ };
254
+ console.log('Worker sending error:', errorMessage);
255
+ self.postMessage(errorMessage);
256
+ }
257
+ }
258
+
259
+ // Listen for messages from the main thread
260
+ self.addEventListener('message', async (e) => {
261
+ const { type, data } = e.data;
262
+
263
+ switch (type) {
264
+ case 'load':
265
+ load(data);
266
+ break;
267
+
268
+ case 'run':
269
+ run(data);
270
+ break;
271
+ }
272
+ });