File size: 9,285 Bytes
171bc27
 
641eeeb
163fe9c
171bc27
 
 
 
 
 
4e0c3e5
 
 
171bc27
 
722014e
747d7b4
4e0c3e5
722014e
4e0c3e5
171bc27
 
4e0c3e5
79892f2
 
 
4e3512b
 
2b0365b
 
 
 
195413c
722014e
171bc27
e298aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171bc27
9a5cdc0
 
 
4e0c3e5
9a5cdc0
4e0c3e5
 
 
 
9a5cdc0
 
 
4e0c3e5
 
9a5cdc0
 
67b38a5
c5e38e6
9a5cdc0
 
 
e6750e3
9a5cdc0
 
080cf2e
9a5cdc0
 
 
 
0758244
9a5cdc0
 
 
 
 
 
 
 
 
 
c5e38e6
9a5cdc0
 
 
 
 
 
 
 
 
 
 
 
4e0c3e5
9a5cdc0
 
4e0c3e5
 
9a5cdc0
 
4e0c3e5
 
9a5cdc0
4e0c3e5
 
9a5cdc0
 
4e0c3e5
 
 
ca8ee1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import torch
from speechbrain.inference.interfaces import Pretrained
import librosa
import numpy as np


class ASR(Pretrained):
    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments unchanged to ``Pretrained.__init__``."""
        super().__init__(*args, **kwargs)

    def encode_batch(self, device, wavs, wav_lens=None, normalize=False):
        wavs = wavs.to(device)
        wav_lens = wav_lens.to(device)

        # Forward pass
        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())
        # append
        tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(device)
        embedded_tokens = self.mods.embedding(tokens_bos)
        decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, wav_lens)

        # Output layer for seq2seq log-probabilities
        predictions = self.hparams.test_search(encoded_outputs, wav_lens)[0]
        # predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
        predicted_words = []
        for prediction in predictions:
            prediction = [token for token in prediction if token != 0]
            predicted_words.append(self.hparams.tokenizer.decode_ids(prediction).split(" "))
        prediction = []
        for sent in predicted_words:
            sent = self.filter_repetitions(sent, 3)
            prediction.append(sent)
        predicted_words = prediction
        return predicted_words

    def filter_repetitions(self, seq, max_repetition_length):
        """Drop excess repeated tokens / token patterns from a sequence.

        Candidate pattern lengths ``n`` are tried from ``len(seq) // 2`` down
        to 1. For each ``n`` the sequence is scanned with ``n`` round-robin
        buffers (the token at index ``i`` belongs to buffer ``i % n``); a new
        token is compared with the last token of its buffer to detect a
        repeat with period ``n``. When the buffers are flushed, only tokens
        in the first ``max(max_repetition_length // n, 1)`` rounds of ``n``
        tokens are kept. The filtered result is then used as the input for
        the next, smaller ``n``.

        Args:
            seq: Iterable of tokens; it is copied into a list first. Tokens
                are compared with ``!=``. ``None`` is used internally as a
                deletion marker, so tokens should not be ``None``.
            max_repetition_length: Integer budget from which the number of
                kept repetitions per pattern length is derived. Sequences no
                longer than this value are returned unfiltered.

        Returns:
            list: The filtered token list.
        """
        seq = list(seq)
        output = []
        max_n = len(seq) // 2
        for n in range(max_n, 0, -1):
            # Number of rounds (of n tokens each) kept on a flush; at least one.
            max_repetitions = max(max_repetition_length // n, 1)
            # Don't need to iterate over impossible n values:
            # len(seq) can change a lot during iteration
            if (len(seq) <= n*2) or (len(seq) <= max_repetition_length):
                continue
            iterator = enumerate(seq)
            # Fill first buffers: buffer k starts with the k-th token of seq.
            buffers = [[next(iterator)[1]] for _ in range(n)]
            for seq_index, token in iterator:
                current_buffer = seq_index % n
                if token != buffers[current_buffer][-1]:
                    # No repeat, we can flush some tokens
                    buf_len = sum(map(len, buffers))
                    flush_start = (current_buffer-buf_len) % n
                    # Keep n-1 tokens, but possibly mark some for removal
                    for flush_index in range(buf_len - buf_len%n):
                        if (buf_len - flush_index) > n-1:
                            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                        else:
                            # Token stays in the buffers; None stands in for it.
                            to_flush = None
                        # Here, repetitions get removed:
                        if (flush_index // n < max_repetitions) and to_flush is not None:
                            output.append(to_flush)
                        elif (flush_index // n >= max_repetitions) and to_flush is None:
                            # Emit a None marker; it cancels one later token
                            # when seq is rebuilt below.
                            output.append(to_flush)
                buffers[current_buffer].append(token)
            # At the end, final flush
            # current_buffer still holds the buffer index of the last token
            # processed by the loop above.
            current_buffer += 1
            buf_len = sum(map(len, buffers))
            flush_start = (current_buffer-buf_len) % n
            for flush_index in range(buf_len):
                to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                # Here, repetitions just get removed:
                if flush_index // n < max_repetitions:
                    output.append(to_flush)
            # Rebuild seq from output: every None marker drops one of the
            # non-None tokens that follow it.
            seq = []
            to_delete = 0
            for token in output:
                if token is None:
                    to_delete += 1
                elif to_delete > 0:
                    to_delete -= 1
                else:
                    seq.append(token)
            # Start the next (smaller) n with an empty output list.
            output = []
        return seq
    

    # def classify_file(self, path):
    #     # waveform = self.load_audio(path)
    #     waveform, sr = librosa.load(path, sr=16000)
        # waveform = torch.tensor(waveform)

        # # Fake a batch:
        # batch = waveform.unsqueeze(0)
        # rel_length = torch.tensor([1.0])
        # outputs = self.encode_batch(batch, rel_length)
       
    #     return outputs


    def classify_file(self, path, device):
        """Transcribe an audio file, yielding one result per processed chunk.

        The file is loaded with ``librosa`` at 16 kHz. Audio shorter than 20
        seconds is transcribed in a single pass. Longer audio is split at
        silences (``librosa.effects.split`` with ``top_db=20``) and the
        non-silent intervals are packed greedily, in order, into segments of
        at most 20 seconds; each segment is transcribed on its own. A single
        non-silent interval that is itself longer than 20 seconds is kept
        whole and becomes its own segment.

        Args:
            path: Path of the audio file, passed to ``librosa.load``.
            device: Device the tensors are moved to before inference.

        Yields:
            The return value of ``encode_batch`` for each segment (called with
            a one-element batch), in segment order. Nothing is yielded when a
            long file contains no non-silent interval.
        """
        waveform, sr = librosa.load(path, sr=16000)

        # Audio length in seconds.
        audio_length = len(waveform) / sr

        if audio_length < 20:
            # Short enough: fake a batch of one and transcribe in one go.
            batch = torch.tensor(waveform).to(device).unsqueeze(0)
            rel_length = torch.tensor([1.0]).to(device)
            yield self.encode_batch(device, batch, rel_length)
            return

        print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
        # Detect non-silent intervals (top_db controls the sensitivity).
        non_silent_intervals = librosa.effects.split(waveform, top_db=20)

        max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
        segments = []
        current_segment = []
        current_length = 0

        for start, end in non_silent_intervals:
            segment_part = waveform[start:end]

            # Close the running segment when the next part would overflow it.
            # The emptiness check matters: np.concatenate([]) raises
            # ValueError, which used to happen whenever a part longer than
            # max_duration arrived while nothing had been accumulated yet.
            if current_segment and current_length + len(segment_part) > max_duration:
                segments.append(np.concatenate(current_segment))
                current_segment = []
                current_length = 0

            current_segment.append(segment_part)
            current_length += len(segment_part)

        # Append the last segment if it's not empty.
        if current_segment:
            segments.append(np.concatenate(current_segment))

        for i, segment in enumerate(segments):
            print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")

            # Fake a batch of one for the segment.
            batch = torch.tensor(segment).to(device).unsqueeze(0)
            rel_length = torch.tensor([1.0]).to(device)

            yield self.encode_batch(device, batch, rel_length)


    def classify_file_whisper(self, path, pipe, device, language="macedonian"):
        """Transcribe an audio file with a caller-supplied pipeline.

        Args:
            path: Path of the audio file, passed to ``librosa.load`` (loaded
                at 16 kHz).
            pipe: Callable invoked as ``pipe(waveform, generate_kwargs=...)``;
                its result must support ``["text"]``.
            device: Unused here; kept so existing callers keep working.
            language: Value passed as ``generate_kwargs["language"]``.
                Defaults to ``"macedonian"``, the previously hard-coded value.

        Returns:
            The ``"text"`` entry of the pipeline result.
        """
        waveform, _ = librosa.load(path, sr=16000)
        result = pipe(waveform, generate_kwargs={"language": language})
        return result["text"]
       

    def classify_file_mms(self, path, processor, model, device):
        """Transcribe an audio file with a processor/model pair, chunk by chunk.

        The file is loaded with ``librosa`` at 16 kHz. Audio shorter than 20
        seconds is transcribed in a single pass. Longer audio is split at
        silences (``librosa.effects.split`` with ``top_db=20``) and the
        non-silent intervals are packed greedily, in order, into segments of
        at most 20 seconds; each segment is transcribed on its own. A single
        non-silent interval that is itself longer than 20 seconds is kept
        whole and becomes its own segment.

        Args:
            path: Path of the audio file, passed to ``librosa.load``.
            processor: Callable turning raw audio into model inputs
                (``processor(audio, sampling_rate=..., return_tensors="pt")``)
                and providing ``decode(ids)``.
            model: Callable invoked as ``model(**inputs)``; its result must
                expose ``.logits``.
            device: Device the processed inputs are moved to.

        Yields:
            The decoded text for each segment, in segment order. Nothing is
            yielded when a long file contains no non-silent interval.
        """

        def transcribe(audio):
            # The processor gets the raw numpy audio; only the processed
            # batch is moved to the target device (a tensor already placed on
            # an accelerator cannot be converted back to numpy by the
            # processor).
            inputs = processor(audio, sampling_rate=16_000, return_tensors="pt").to(device)
            # Inference only, so skip building an autograd graph.
            with torch.no_grad():
                logits = model(**inputs).logits
            # Most likely id per frame, for the first (only) batch item.
            ids = torch.argmax(logits, dim=-1)[0]
            return processor.decode(ids)

        waveform, sr = librosa.load(path, sr=16000)

        # Audio length in seconds.
        audio_length = len(waveform) / sr

        if audio_length < 20:
            yield transcribe(waveform)
            return

        print(f"MMS Audio is too long ({audio_length:.2f} seconds), splitting into segments")
        # Detect non-silent intervals (top_db controls the sensitivity).
        non_silent_intervals = librosa.effects.split(waveform, top_db=20)

        max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
        segments = []
        current_segment = []
        current_length = 0

        for start, end in non_silent_intervals:
            segment_part = waveform[start:end]

            # Close the running segment when the next part would overflow it.
            # The emptiness check matters: np.concatenate([]) raises
            # ValueError, which used to happen whenever a part longer than
            # max_duration arrived while nothing had been accumulated yet.
            if current_segment and current_length + len(segment_part) > max_duration:
                segments.append(np.concatenate(current_segment))
                current_segment = []
                current_length = 0

            current_segment.append(segment_part)
            current_length += len(segment_part)

        # Append the last segment if it's not empty.
        if current_segment:
            segments.append(np.concatenate(current_segment))

        for i, segment in enumerate(segments):
            print(f"MMS Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
            yield transcribe(segment)