clr commited on
Commit
424bfb6
·
1 Parent(s): 6f56357

Upload 2 files

Browse files
Files changed (2) hide show
  1. ctcalign.py +166 -169
  2. graph.py +3 -5
ctcalign.py CHANGED
@@ -4,108 +4,98 @@ import numpy as np
4
  from dataclasses import dataclass
5
 
6
 
 
 
7
 
 
 
 
 
8
 
9
- #convert frame-numbers to timestamps in seconds
10
- # w2v2 step size is about 20ms, or 50 frames per second
11
- def f2s(fr):
12
- return fr/50
13
 
14
-
15
- #------------------------------------------
16
- # setup wav2vec2
17
- #------------------------------------------
18
 
19
- # important to know for CTC decoding - potentially language/model dependent
20
- #model_word_separator = '|'
21
- #model_blank_token = '[PAD]'
22
- #is_MODEL_PATH="../models/LVL/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
23
 
 
 
 
 
24
 
25
 
26
- class CTCAligner:
27
-
28
- def __init__(self, model_path,model_word_separator, model_blank_token):
29
- #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- torch.random.manual_seed(0)
31
-
32
- self.model = Wav2Vec2ForCTC.from_pretrained(model_path)#.to(self.device)
33
- self.processor = Wav2Vec2Processor.from_pretrained(model_path)
34
-
35
- # build labels dict from a processor where it is not directly accessible
36
- max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
37
- ixs = sorted(list(range(max_labels)),reverse=True)
38
- self.labels_dict = {self.processor.tokenizer.decode(n) or model_word_separator:n for n in ixs}
39
-
40
- self.blank_id = self.labels_dict[model_blank_token]
41
- self.model_word_separator = model_word_separator
42
 
43
 
 
 
 
 
 
44
 
45
 
 
 
 
 
 
 
 
 
 
46
 
47
- #------------------------------------------
48
- # forced alignment with ctc decoder
49
- # based on implementation of
50
- # https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
51
- #------------------------------------------
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # return the label class probability of each audio frame
55
- # wav is the wav data already read in, NOT the file path.
56
- def get_frame_probs(wav,aligner):
57
- with torch.inference_mode(): # similar to with torch.no_grad():
58
- input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
59
- input_values = torch.tensor(input_values).unsqueeze(0)#, device=aligner.device).unsqueeze(0)
60
- emits = aligner.model(input_values).logits
61
- emits = torch.log_softmax(emits, dim=-1)
62
- return emits[0].cpu().detach()
63
 
64
 
65
- def get_trellis(emission, tokens, blank_id):
 
 
 
 
66
 
67
- num_frame = emission.size(0)
68
- num_tokens = len(tokens)
69
- trellis = torch.empty((num_frame + 1, num_tokens + 1))
70
- trellis[0, 0] = 0
71
- trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) # len of this slice of trellis is len of audio frames)
72
- trellis[0, -num_tokens:] = -float("inf") # len of this slice of trellis is len of transcript tokens
73
- trellis[-num_tokens:, 0] = float("inf")
74
- for t in range(num_frame):
75
- trellis[t + 1, 1:] = torch.maximum(
76
- # Score for staying at the same token
77
- trellis[t, 1:] + emission[t, blank_id],
78
- # Score for changing to the next token
79
- trellis[t, :-1] + emission[t, tokens],
80
- )
81
- return trellis
82
-
83
-
84
-
85
- @dataclass
86
- class Point:
87
- token_index: int
88
- time_index: int
89
- score: float
90
 
91
- @dataclass
92
- class Segment:
93
- label: str
94
- start: int
95
- end: int
96
- score: float
97
-
98
- @property
99
- def mfaform(self):
100
- return f"{f2s(self.start)},{f2s(self.end)},{self.label}"
101
-
102
- @property
103
- def length(self):
104
- return self.end - self.start
105
 
106
 
107
-
108
- def backtrack(trellis, emission, tokens, blank_id):
109
  # Note:
110
  # j and t are indices for trellis, which has extra dimensions
111
  # for time and tokens at the beginning.
@@ -113,112 +103,119 @@ def backtrack(trellis, emission, tokens, blank_id):
113
  # the corresponding index in emission is `T-1`.
114
  # Similarly, when referring to token index `J` in trellis,
115
  # the corresponding index in transcript is `J-1`.
116
- j = trellis.size(1) - 1
117
- t_start = torch.argmax(trellis[:, j]).item()
118
 
119
- path = []
120
- for t in range(t_start, 0, -1):
121
  # 1. Figure out if the current position was stay or change
122
  # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
123
  # Score for token staying the same from time frame J-1 to T.
124
- stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
125
  # Score for token changing from C-1 at T-1 to J at T.
126
- changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
127
 
128
  # 2. Store the path with frame-wise probability.
129
- prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
130
  # Return token index and time index in non-trellis coordinate.
131
- path.append(Point(j - 1, t - 1, prob))
132
-
133
- # 3. Update the token
134
- if changed > stayed:
135
- j -= 1
136
- if j == 0:
137
- break
138
- else:
139
- raise ValueError("Failed to align")
140
- return path[::-1]
141
-
142
-
143
- def merge_repeats(path,transcript):
144
- i1, i2 = 0, 0
145
- segments = []
146
- while i1 < len(path):
147
- while i2 < len(path) and path[i1].token_index == path[i2].token_index: # while both path steps point to the same token index
148
- i2 += 1
149
- score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
150
- segments.append( # when i2 finally switches to a different token,
151
- Segment(
152
- transcript[path[i1].token_index],# to the list of segments, append the token from i1
153
- path[i1].time_index, # time of the first path-point of that token
154
- path[i2 - 1].time_index + 1, # time of the final path-point for that token.
155
- score,
156
- )
157
- )
158
- i1 = i2
159
- return segments
160
-
161
-
162
-
163
- def merge_words(segments, separator):
164
- words = []
165
- i1, i2 = 0, 0
166
- while i1 < len(segments):
167
- if i2 >= len(segments) or segments[i2].label == separator:
168
- if i1 != i2:
169
- segs = segments[i1:i2]
170
- word = "".join([seg.label for seg in segs])
171
- score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
172
- words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
173
- i1 = i2 + 1
174
- i2 = i1
175
  else:
176
- i2 += 1
177
- return words
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
 
180
 
181
- #------------------------------------------
182
- # handle etc.
183
- #------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
 
186
- # generate mfa format for character (phone) and word alignments
187
- # skip the word separator as it is not a phone
188
- def mfalike(chars,wds,wsep):
189
- hed = ['Begin,End,Label,Type,Speaker\n']
190
- wlines = [f'{w.mfaform},words,000\n' for w in wds]
191
- slines = [f'{ch.mfaform},phones,000\n' for ch in chars if ch.label != wsep]
192
- return (''.join(hed+wlines+slines))
193
-
194
- # generate basic exportable list format for character OR word alignments
195
- # skip the word separator as it is not a phone
196
- def basic(segs,wsep="|"):
197
- return [[s.label,f2s(s.start),f2s(s.end)] for s in segs if s.label != wsep]
198
-
199
 
200
- # needs pad labels added to correctly time first segment
201
- # and therefore add word sep character as placeholder in transcript
202
- def prep_transcript(xcp, aligner):
203
- xcp = xcp.replace(' ', aligner.model_word_separator)
204
- label_ids = [aligner.labels_dict[c] for c in xcp]
205
- label_ids = [aligner.blank_id] + label_ids + [aligner.blank_id]
206
- xcp = f'{ aligner.model_word_separator}{xcp}{aligner.model_word_separator}'
207
- return xcp,label_ids
208
 
 
 
 
209
 
210
 
211
- def align(wav_data,transcript,aligner):
212
- norm_transcript,rec_label_ids = prep_transcript(transcript, aligner)
213
- emit = get_frame_probs(wav_data,aligner)
214
- trellis = get_trellis(emit, rec_label_ids, aligner.blank_id)
215
- path = backtrack(trellis, emit, rec_label_ids, aligner.blank_id)
 
 
216
 
217
- segments = merge_repeats(path,norm_transcript)
218
- words = merge_words(segments, aligner.model_word_separator)
 
 
219
 
220
- #segments = [s for s in segments if s[0] != model_word_separator]
221
- #return mfalike(segments,words,model_word_separator)
222
- return basic(words,aligner.model_word_separator), basic(segments,aligner.model_word_separator)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
 
 
4
  from dataclasses import dataclass
5
 
6
 
7
+
8
+ def aligner(model_path,model_word_separator = '|', model_blank_token = '[PAD]'):
9
 
10
+ # build labels dict from a processor where it is not directly accessible
11
+ def get_processor_labels(processor,word_sep,max_labels=100):
12
+ ixs = sorted(list(range(max_labels)),reverse=True)
13
+ return {processor.tokenizer.decode(n) or word_sep:n for n in ixs}
14
 
15
+ #------------------------------------------
16
+ # setup wav2vec2
17
+ #------------------------------------------
 
18
 
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ torch.random.manual_seed(0)
21
+ max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
 
22
 
 
 
 
 
23
 
24
+ model = Wav2Vec2ForCTC.from_pretrained(model_path).to(device)
25
+ processor = Wav2Vec2Processor.from_pretrained(model_path)
26
+ labels_dict = get_processor_labels(processor,model_word_separator)
27
+ blank_id = labels_dict[model_blank_token]
28
 
29
 
30
+ #convert frame-numbers to timestamps in seconds
31
+ # w2v2 step size is about 20ms, or 50 frames per second
32
+ def f2s(fr):
33
+ return fr/50
34
+
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
+ #------------------------------------------
38
+ # forced alignment with ctc decoder
39
+ # based on implementation of
40
+ # https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
41
+ #------------------------------------------
42
 
43
 
44
+ # return the label class probability of each audio frame
45
+ # wav is the wav data already read in, NOT the file path.
46
+ def get_frame_probs(wav):
47
+ with torch.inference_mode(): # similar to with torch.no_grad():
48
+ input_values = processor(wav,sampling_rate=16000).input_values[0]
49
+ input_values = torch.tensor(input_values, device=device).unsqueeze(0)
50
+ emits = model(input_values).logits
51
+ emits = torch.log_softmax(emits, dim=-1)
52
+ return emits[0].cpu().detach()
53
 
 
 
 
 
 
54
 
55
+ def get_trellis(emission, tokens, blank_id):
56
+
57
+ num_frame = emission.size(0)
58
+ num_tokens = len(tokens)
59
+ trellis = torch.empty((num_frame + 1, num_tokens + 1))
60
+ trellis[0, 0] = 0
61
+ trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) # len of this slice of trellis is len of audio frames)
62
+ trellis[0, -num_tokens:] = -float("inf") # len of this slice of trellis is len of transcript tokens
63
+ trellis[-num_tokens:, 0] = float("inf")
64
+ for t in range(num_frame):
65
+ trellis[t + 1, 1:] = torch.maximum(
66
+ # Score for staying at the same token
67
+ trellis[t, 1:] + emission[t, blank_id],
68
+ # Score for changing to the next token
69
+ trellis[t, :-1] + emission[t, tokens],
70
+ )
71
+ return trellis
72
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
+ @dataclass
76
+ class Point:
77
+ token_index: int
78
+ time_index: int
79
+ score: float
80
 
81
+ @dataclass
82
+ class Segment:
83
+ label: str
84
+ start: int
85
+ end: int
86
+ score: float
87
+
88
+ @property
89
+ def mfaform(self):
90
+ return f"{f2s(self.start)},{f2s(self.end)},{self.label}"
91
+
92
+ @property
93
+ def length(self):
94
+ return self.end - self.start
 
 
 
 
 
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
+ def backtrack(trellis, emission, tokens, blank_id):
 
99
  # Note:
100
  # j and t are indices for trellis, which has extra dimensions
101
  # for time and tokens at the beginning.
 
103
  # the corresponding index in emission is `T-1`.
104
  # Similarly, when referring to token index `J` in trellis,
105
  # the corresponding index in transcript is `J-1`.
106
+ j = trellis.size(1) - 1
107
+ t_start = torch.argmax(trellis[:, j]).item()
108
 
109
+ path = []
110
+ for t in range(t_start, 0, -1):
111
  # 1. Figure out if the current position was stay or change
112
  # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
113
  # Score for token staying the same from time frame J-1 to T.
114
+ stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
115
  # Score for token changing from C-1 at T-1 to J at T.
116
+ changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
117
 
118
  # 2. Store the path with frame-wise probability.
119
+ prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
120
  # Return token index and time index in non-trellis coordinate.
121
+ path.append(Point(j - 1, t - 1, prob))
122
+
123
+ # 3. Update the token
124
+ if changed > stayed:
125
+ j -= 1
126
+ if j == 0:
127
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  else:
129
+ raise ValueError("Failed to align")
130
+ return path[::-1]
131
+
132
+
133
+ def merge_repeats(path,transcript):
134
+ i1, i2 = 0, 0
135
+ segments = []
136
+ while i1 < len(path):
137
+ while i2 < len(path) and path[i1].token_index == path[i2].token_index: # while both path steps point to the same token index
138
+ i2 += 1
139
+ score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
140
+ segments.append( # when i2 finally switches to a different token,
141
+ Segment(
142
+ transcript[path[i1].token_index],# to the list of segments, append the token from i1
143
+ path[i1].time_index, # time of the first path-point of that token
144
+ path[i2 - 1].time_index + 1, # time of the final path-point for that token.
145
+ score,
146
+ )
147
+ )
148
+ i1 = i2
149
+ return segments
150
 
151
 
152
 
153
+ def merge_words(segments, separator):
154
+ words = []
155
+ i1, i2 = 0, 0
156
+ while i1 < len(segments):
157
+ if i2 >= len(segments) or segments[i2].label == separator:
158
+ if i1 != i2:
159
+ segs = segments[i1:i2]
160
+ word = "".join([seg.label for seg in segs])
161
+ score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
162
+ words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
163
+ i1 = i2 + 1
164
+ i2 = i1
165
+ else:
166
+ i2 += 1
167
+ return words
168
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
 
 
171
 
172
+ #------------------------------------------
173
+ # handle etc.
174
+ #------------------------------------------
175
 
176
 
177
+ # generate mfa format for character (phone) and word alignments
178
+ # skip the word separator as it is not a phone
179
+ def mfalike(chars,wds,wsep):
180
+ hed = ['Begin,End,Label,Type,Speaker\n']
181
+ wlines = [f'{w.mfaform},words,000\n' for w in wds]
182
+ slines = [f'{ch.mfaform},phones,000\n' for ch in chars if ch.label != wsep]
183
+ return (''.join(hed+wlines+slines))
184
 
185
+ # generate basic exportable list format for character OR word alignments
186
+ # skip the word separator as it is not a phone
187
+ def basic(segs,wsep="|"):
188
+ return [[s.label,f2s(s.start),f2s(s.end)] for s in segs if s.label != wsep]
189
 
190
+
191
+ # needs pad labels added to correctly time first segment
192
+ # and therefore add word sep character as placeholder in transcript
193
+ def prep_transcript(xcp):
194
+ xcp = xcp.replace(' ',model_word_separator)
195
+ label_ids = [labels_dict[c] for c in xcp]
196
+ label_ids = [blank_id] + label_ids + [blank_id]
197
+ xcp = f'{model_word_separator}{xcp}{model_word_separator}'
198
+ return xcp,label_ids
199
+
200
+
201
+
202
+ def _align(wav_data,transcript):
203
+
204
+ norm_transcript,rec_label_ids = prep_transcript(transcript)
205
+ emit = get_frame_probs(wav_data)
206
+ trellis = get_trellis(emit, rec_label_ids, blank_id)
207
+ path = backtrack(trellis, emit, rec_label_ids, blank_id)
208
+
209
+ segments = merge_repeats(path,norm_transcript)
210
+ words = merge_words(segments, model_word_separator)
211
+
212
+ #segments = [s for s in segments if s[0] != model_word_separator]
213
+ #return mfalike(segments,words,model_word_separator)
214
+ return basic(words,model_word_separator), basic(segments,model_word_separator)
215
+
216
+ return _align
217
+
218
+
219
+
220
 
221
 
graph.py CHANGED
@@ -4,7 +4,6 @@ from scipy import signal
4
  import librosa
5
  import subprocess
6
  import matplotlib.pyplot as plt
7
- import ctcalign
8
 
9
 
10
 
@@ -44,13 +43,12 @@ def get_pitch_tracks(wav_path):
44
  # transcript could be from a corpus with the wav file,
45
  # input by the user,
46
  # or from a previous speech recognition process
47
- def align_and_graph(wav_path, transcript,lang_aligner):
48
 
49
  # fetch data
50
- #f0_data = get_pitch_tracks(wav_path)
51
  speech = readwav(wav_path)
52
- w_align, seg_align = ctcalign.align(speech,normalise_transcript(transcript),lang_aligner)
53
-
54
 
55
  # set up the graph shape
56
  rec_start = w_align[0][1]
 
4
  import librosa
5
  import subprocess
6
  import matplotlib.pyplot as plt
 
7
 
8
 
9
 
 
43
  # transcript could be from a corpus with the wav file,
44
  # input by the user,
45
  # or from a previous speech recognition process
46
+ def align_and_graph(wav_path, transcript, aligner_function):
47
 
48
  # fetch data
 
49
  speech = readwav(wav_path)
50
+ w_align, seg_align = aligner_function(speech,normalise_transcript(transcript))
51
+
52
 
53
  # set up the graph shape
54
  rec_start = w_align[0][1]