Porjaz committed (verified)
Commit ca8ee1d · Parent(s): 81fa3cc

Update custom_interface.py

Files changed (1):
  1. custom_interface.py +64 -0
custom_interface.py CHANGED
@@ -156,3 +156,67 @@ class ASR(Pretrained):
             rel_length = torch.tensor([1.0]).to(device)
             outputs = self.encode_batch(device, batch, rel_length)
             yield outputs
+
+
+    def classify_file_whisper(self, path, pipe, device):
+        waveform, sr = librosa.load(path, sr=16000)
+        transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
+        return transcription
+
+
+    def classify_file_mms(self, path, processor, model, device):
+        # Load the audio file
+        waveform, sr = librosa.load(path, sr=16000)
+
+        # Get audio length in seconds
+        audio_length = len(waveform) / sr
+
+        if audio_length >= 20:
+            print(f"MMS Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
+
+            segments = []
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
+
+
+            for interval in non_silent_intervals:
+                start, end = interval
+                segment_part = waveform[start:end]
+
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"MMS Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                segment_tensor = torch.tensor(segment).to(device)
+
+                # Pass the segment through the ASR model
+                inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
+                outputs = model(**inputs).logits
+                ids = torch.argmax(outputs, dim=-1)[0]
+                segment_output = processor.decode(ids)
+                yield segment_output
+        else:
+            waveform = torch.tensor(waveform).to(device)
+            inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
+            outputs = model(**inputs).logits
+            ids = torch.argmax(outputs, dim=-1)[0]
+            transcription = processor.decode(ids)
+            yield transcription
+
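
For orientation, a minimal usage sketch for the two new helpers follows. It assumes the Whisper pipeline is built with transformers.pipeline and the MMS pair with AutoProcessor / Wav2Vec2ForCTC; the checkpoint IDs (openai/whisper-large-v3, facebook/mms-1b-all), the "mkd" target language, the sample.wav path, and loading this interface through SpeechBrain's foreign_class are illustrative assumptions, not part of this commit.

# Usage sketch (illustrative, not part of this commit): calling the two new
# helpers with a Whisper pipeline and an MMS CTC model from transformers.
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC, pipeline
from speechbrain.inference.interfaces import foreign_class  # SpeechBrain >= 1.0; older releases expose it under speechbrain.pretrained

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the custom interface defined in custom_interface.py; the source repo ID is a placeholder.
asr_model = foreign_class(
    source="<this-repo-id>",                 # assumption: HF repo containing custom_interface.py
    pymodule_file="custom_interface.py",
    classname="ASR",
    run_opts={"device": device},
)

# Whisper path: classify_file_whisper expects a ready-made ASR pipeline and returns a string.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",          # assumption: any multilingual Whisper checkpoint
    device=0 if device == "cuda" else -1,
)
print(asr_model.classify_file_whisper("sample.wav", whisper_pipe, device))

# MMS path: classify_file_mms expects a processor/model pair and yields one
# transcription per ~20-second chunk (a single chunk for short audio).
mms_id = "facebook/mms-1b-all"                # assumption: MMS checkpoint with a Macedonian (mkd) adapter
processor = AutoProcessor.from_pretrained(mms_id, target_lang="mkd")
mms_model = Wav2Vec2ForCTC.from_pretrained(
    mms_id, target_lang="mkd", ignore_mismatched_sizes=True
).to(device)

for chunk_text in asr_model.classify_file_mms("sample.wav", processor, mms_model, device):
    print(chunk_text)

Note that classify_file_mms yields in both branches, so it is always consumed with a loop even for short files, while classify_file_whisper simply returns the transcribed string.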