Arnab Das committed
Commit 1e4d7e8
Parent(s): 6c204c1

reverting changes
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
 import models as MOD
 import process_data as PD
 from transformers import pipeline
-from manipulate_model.utils import get_config_and_model, infere
 
 model_master = {
     "SSL-AASIST (Trained on ASV-Spoof5)": {"eer_threshold": 3.3330237865448,
@@ -23,8 +22,6 @@ model.load_state_dict(torch.load("ssl_aasist_epoch_7.pth", map_location="cpu"))
 model.eval()
 loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
 
-manpulate_config, manipulate_model = get_config_and_model()
-
 def process(file, type):
     global model
     global loaded_model
@@ -98,32 +95,7 @@ transcribe_proc = gr.Interface(
     allow_flagging="never"
 )
 
-#############################################################################################
-#For manipulation detection interface
-
-def detect_manipulation(inputs):
-    global manipulate_model
-    global manpulate_config
-    out = infere(manipulate_model, inputs, manpulate_config)
-    out = out.tolist()
-    return str(out)
-
-manipulate_proc = gr.Interface(
-    fn = detect_manipulation,
-    inputs=[
-        gr.Audio(type="filepath", label="Speech file (<30s)", max_length=30, sources=["microphone", "upload"], show_download_button=True)
-    ],
-    outputs=[
-        gr.Text(label="Predicted manipulations", info="Manipulation detection is performed automatically."),
-    ],
-    title="Find the manipulated segments",
-    description=(
-        "Automatactic manipulation detection service. Upload a audio file."
-    ),
-    allow_flagging="never"
-)
-
 with demo:
-    gr.TabbedInterface([file_proc, transcribe_proc, manipulate_proc], ["Analyze Audio File", "Transcribe Audio File", "Manipulation Detection"])
+    gr.TabbedInterface([file_proc, transcribe_proc], ["Analyze Audio File", "Transcribe Audio File"])
 demo.queue(max_size=10)
 demo.launch(share=True)
app_backup.py ADDED
@@ -0,0 +1,129 @@
+import torch
+import gradio as gr
+import models as MOD
+import process_data as PD
+from transformers import pipeline
+from manipulate_model.utils import get_config_and_model, infere
+
+model_master = {
+    "SSL-AASIST (Trained on ASV-Spoof5)": {"eer_threshold": 3.3330237865448,
+                                           "data_process_func": "process_ssl_assist_input",
+                                           "note": "This model is trained only on ASVSpoof 2024 training data.",
+                                           "model_class": "Model",
+                                           "model_checkpoint": "ssl_aasist_epoch_7.pth"},
+    "AASIST": {"eer_threshold": 1.8018419742584229,
+               "data_process_func": "process_assist_input",
+               "note": "This model is trained on ASVSpoof 2024 training data.",
+               "model_class":"AASIST_Model",
+               "model_checkpoint": "orig_aasist_epoch_1.pth"}
+}
+
+model = MOD.Model(None, "cpu")
+model.load_state_dict(torch.load("ssl_aasist_epoch_7.pth", map_location="cpu"))
+model.eval()
+loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
+
+manpulate_config, manipulate_model = get_config_and_model()
+
+def process(file, type):
+    global model
+    global loaded_model
+    inp = getattr(PD, model_master[type]["data_process_func"])(file)
+    if not loaded_model == type:
+        model = getattr(MOD, model_master[type]["model_class"])(None, "cpu")
+        model.load_state_dict(torch.load(model_master[type]["model_checkpoint"], map_location="cpu"))
+        model.eval()
+        loaded_model = type
+
+    op = model(inp).detach().squeeze()[1].item()
+
+    response_text = "Decision score: {} \nDecision threshold: {} \nNotes: 1. Any score below threshold is indicative of fake. \n2. {} ".format(
+        str(op), str(model_master[type]["eer_threshold"]), model_master[type]["note"])
+    return response_text
+
+
+demo = gr.Blocks()
+file_proc = gr.Interface(
+    fn=process,
+    inputs=[
+        gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
+        gr.Radio(["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"], label="Select Model", type="value"),
+    ],
+    outputs="text",
+    title="Find the Fake: Analyze 'Real' or 'Fake'.",
+    description=(
+        "Analyze fake or real with a click of a button. Upload a .wav or .flac file."
+    ),
+    examples=[
+        ["./bonafide.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
+        ["./fake.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
+        ["./bonafide.flac", "AASIST"],
+        ["./fake.flac", "AASIST"],
+    ],
+    cache_examples=True,
+    allow_flagging="never",
+)
+#####################################################################################
+# For ASR interface
+pipe = pipeline(
+    task="automatic-speech-recognition",
+    model="openai/whisper-large-v3",
+    chunk_length_s=30,
+    device="cpu",
+)
+
+def transcribe(inputs):
+    if inputs is None:
+        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
+
+    op = pipe(inputs, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=False, return_language=True)
+    lang = op["chunks"][0]["language"]
+    text = op["text"]
+
+    return lang, text
+
+transcribe_proc = gr.Interface(
+    fn = transcribe,
+    inputs = [
+        gr.Audio(type="filepath", label="Speech file (<30s)", max_length=30, sources=["microphone", "upload"], show_download_button=True)
+    ],
+    outputs=[
+        gr.Text(label="Predicted Language", info="Language identification is performed automatically."),
+        gr.Text(label="Predicted transcription", info="Best hypothesis."),
+    ],
+    title="Transcribe Anything.",
+    description=(
+        "Automatactic language identification and transcription service by Whisper Large V3. Upload a .wav or .flac file."
+    ),
+    allow_flagging="never"
+)
+
+#############################################################################################
+#For manipulation detection interface
+
+def detect_manipulation(inputs):
+    global manipulate_model
+    global manpulate_config
+    out = infere(manipulate_model, inputs, manpulate_config)
+    out = out.tolist()
+    return str(out)
+
+manipulate_proc = gr.Interface(
+    fn = detect_manipulation,
+    inputs=[
+        gr.Audio(type="filepath", label="Speech file (<30s)", max_length=30, sources=["microphone", "upload"], show_download_button=True)
+    ],
+    outputs=[
+        gr.Text(label="Predicted manipulations", info="Manipulation detection is performed automatically."),
+    ],
+    title="Find the manipulated segments",
+    description=(
+        "Automatactic manipulation detection service. Upload a audio file."
+    ),
+    allow_flagging="never"
+)
+
+with demo:
+    gr.TabbedInterface([file_proc, transcribe_proc, manipulate_proc], ["Analyze Audio File", "Transcribe Audio File", "Manipulation Detection"])
+demo.queue(max_size=10)
+demo.launch(share=True)
manipulate_demo_files/fake.mp4 DELETED
Binary file (242 kB)
 
manipulate_demo_files/real.mp4 DELETED
Binary file (238 kB)
 
manipulate_model/decoder/aasist/aasist.py DELETED
@@ -1,617 +0,0 @@
1
- import random
2
- from typing import Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch import Tensor
9
-
10
- # import fairseq
11
-
12
-
13
- ___author__ = "Hemlata Tak"
14
- __email__ = "[email protected]"
15
-
16
- ############################
17
- ## FOR fine-tuned SSL MODEL
18
- ############################
19
-
20
-
21
- # class SSLModel(nn.Module):
22
- # def __init__(self, device):
23
- # super(SSLModel, self).__init__()
24
-
25
- # cp_path = "xlsr2_300m.pt" # Change the pre-trained XLSR model path.
26
- # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
27
- # [cp_path]
28
- # )
29
- # self.model = model[0]
30
- # self.device = device
31
- # self.out_dim = 1024
32
- # return
33
-
34
- # def extract_feat(self, input_data):
35
-
36
- # # put the model to GPU if it not there
37
- # if (
38
- # next(self.model.parameters()).device != input_data.device
39
- # or next(self.model.parameters()).dtype != input_data.dtype
40
- # ):
41
- # self.model.to(input_data.device, dtype=input_data.dtype)
42
- # self.model.train()
43
-
44
- # if True:
45
- # # input should be in shape (batch, length)
46
- # if input_data.ndim == 3:
47
- # input_tmp = input_data[:, :, 0]
48
- # else:
49
- # input_tmp = input_data
50
-
51
- # # [batch, length, dim]
52
- # emb = self.model(input_tmp, mask=False, features_only=True)["x"]
53
- # return emb
54
-
55
-
56
- # ---------AASIST back-end------------------------#
57
- """ Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans.
58
- AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks.
59
- In Proc. ICASSP 2022, pp: 6367--6371."""
60
-
61
-
62
- class GraphAttentionLayer(nn.Module):
63
- def __init__(self, in_dim, out_dim, **kwargs):
64
- super().__init__()
65
-
66
- # attention map
67
- self.att_proj = nn.Linear(in_dim, out_dim)
68
- self.att_weight = self._init_new_params(out_dim, 1)
69
-
70
- # project
71
- self.proj_with_att = nn.Linear(in_dim, out_dim)
72
- self.proj_without_att = nn.Linear(in_dim, out_dim)
73
-
74
- # batch norm
75
- self.bn = nn.BatchNorm1d(out_dim)
76
-
77
- # dropout for inputs
78
- self.input_drop = nn.Dropout(p=0.2)
79
-
80
- # activate
81
- self.act = nn.SELU(inplace=True)
82
-
83
- # temperature
84
- self.temp = 1.0
85
- if "temperature" in kwargs:
86
- self.temp = kwargs["temperature"]
87
-
88
- def forward(self, x):
89
- """
90
- x :(#bs, #node, #dim)
91
- """
92
- # apply input dropout
93
- x = self.input_drop(x)
94
-
95
- # derive attention map
96
- att_map = self._derive_att_map(x)
97
-
98
- # projection
99
- x = self._project(x, att_map)
100
-
101
- # apply batch norm
102
- x = self._apply_BN(x)
103
- x = self.act(x)
104
- return x
105
-
106
- def _pairwise_mul_nodes(self, x):
107
- """
108
- Calculates pairwise multiplication of nodes.
109
- - for attention map
110
- x :(#bs, #node, #dim)
111
- out_shape :(#bs, #node, #node, #dim)
112
- """
113
-
114
- nb_nodes = x.size(1)
115
- x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
116
- x_mirror = x.transpose(1, 2)
117
-
118
- return x * x_mirror
119
-
120
- def _derive_att_map(self, x):
121
- """
122
- x :(#bs, #node, #dim)
123
- out_shape :(#bs, #node, #node, 1)
124
- """
125
- att_map = self._pairwise_mul_nodes(x)
126
- # size: (#bs, #node, #node, #dim_out)
127
- att_map = torch.tanh(self.att_proj(att_map))
128
- # size: (#bs, #node, #node, 1)
129
- att_map = torch.matmul(att_map, self.att_weight)
130
-
131
- # apply temperature
132
- att_map = att_map / self.temp
133
-
134
- att_map = F.softmax(att_map, dim=-2)
135
-
136
- return att_map
137
-
138
- def _project(self, x, att_map):
139
- x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
140
- x2 = self.proj_without_att(x)
141
-
142
- return x1 + x2
143
-
144
- def _apply_BN(self, x):
145
- org_size = x.size()
146
- x = x.view(-1, org_size[-1])
147
- x = self.bn(x)
148
- x = x.view(org_size)
149
-
150
- return x
151
-
152
- def _init_new_params(self, *size):
153
- out = nn.Parameter(torch.FloatTensor(*size))
154
- nn.init.xavier_normal_(out)
155
- return out
156
-
157
-
158
- class HtrgGraphAttentionLayer(nn.Module):
159
- def __init__(self, in_dim, out_dim, **kwargs):
160
- super().__init__()
161
-
162
- self.proj_type1 = nn.Linear(in_dim, in_dim)
163
- self.proj_type2 = nn.Linear(in_dim, in_dim)
164
-
165
- # attention map
166
- self.att_proj = nn.Linear(in_dim, out_dim)
167
- self.att_projM = nn.Linear(in_dim, out_dim)
168
-
169
- self.att_weight11 = self._init_new_params(out_dim, 1)
170
- self.att_weight22 = self._init_new_params(out_dim, 1)
171
- self.att_weight12 = self._init_new_params(out_dim, 1)
172
- self.att_weightM = self._init_new_params(out_dim, 1)
173
-
174
- # project
175
- self.proj_with_att = nn.Linear(in_dim, out_dim)
176
- self.proj_without_att = nn.Linear(in_dim, out_dim)
177
-
178
- self.proj_with_attM = nn.Linear(in_dim, out_dim)
179
- self.proj_without_attM = nn.Linear(in_dim, out_dim)
180
-
181
- # batch norm
182
- self.bn = nn.BatchNorm1d(out_dim)
183
-
184
- # dropout for inputs
185
- self.input_drop = nn.Dropout(p=0.2)
186
-
187
- # activate
188
- self.act = nn.SELU(inplace=True)
189
-
190
- # temperature
191
- self.temp = 1.0
192
- if "temperature" in kwargs:
193
- self.temp = kwargs["temperature"]
194
-
195
- def forward(self, x1, x2, master=None):
196
- """
197
- x1 :(#bs, #node, #dim)
198
- x2 :(#bs, #node, #dim)
199
- """
200
- # print('x1',x1.shape)
201
- # print('x2',x2.shape)
202
- num_type1 = x1.size(1)
203
- num_type2 = x2.size(1)
204
- # print('num_type1',num_type1)
205
- # print('num_type2',num_type2)
206
- x1 = self.proj_type1(x1)
207
- # print('proj_type1',x1.shape)
208
- x2 = self.proj_type2(x2)
209
- # print('proj_type2',x2.shape)
210
- x = torch.cat([x1, x2], dim=1)
211
- # print('Concat x1 and x2',x.shape)
212
-
213
- if master is None:
214
- master = torch.mean(x, dim=1, keepdim=True)
215
- # print('master',master.shape)
216
- # apply input dropout
217
- x = self.input_drop(x)
218
-
219
- # derive attention map
220
- att_map = self._derive_att_map(x, num_type1, num_type2)
221
- # print('master',master.shape)
222
- # directional edge for master node
223
- master = self._update_master(x, master)
224
- # print('master',master.shape)
225
- # projection
226
- x = self._project(x, att_map)
227
- # print('proj x',x.shape)
228
- # apply batch norm
229
- x = self._apply_BN(x)
230
- x = self.act(x)
231
-
232
- x1 = x.narrow(1, 0, num_type1)
233
- # print('x1',x1.shape)
234
- x2 = x.narrow(1, num_type1, num_type2)
235
- # print('x2',x2.shape)
236
- return x1, x2, master
237
-
238
- def _update_master(self, x, master):
239
-
240
- att_map = self._derive_att_map_master(x, master)
241
- master = self._project_master(x, master, att_map)
242
-
243
- return master
244
-
245
- def _pairwise_mul_nodes(self, x):
246
- """
247
- Calculates pairwise multiplication of nodes.
248
- - for attention map
249
- x :(#bs, #node, #dim)
250
- out_shape :(#bs, #node, #node, #dim)
251
- """
252
-
253
- nb_nodes = x.size(1)
254
- x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
255
- x_mirror = x.transpose(1, 2)
256
-
257
- return x * x_mirror
258
-
259
- def _derive_att_map_master(self, x, master):
260
- """
261
- x :(#bs, #node, #dim)
262
- out_shape :(#bs, #node, #node, 1)
263
- """
264
- att_map = x * master
265
- att_map = torch.tanh(self.att_projM(att_map))
266
-
267
- att_map = torch.matmul(att_map, self.att_weightM)
268
-
269
- # apply temperature
270
- att_map = att_map / self.temp
271
-
272
- att_map = F.softmax(att_map, dim=-2)
273
-
274
- return att_map
275
-
276
- def _derive_att_map(self, x, num_type1, num_type2):
277
- """
278
- x :(#bs, #node, #dim)
279
- out_shape :(#bs, #node, #node, 1)
280
- """
281
- att_map = self._pairwise_mul_nodes(x)
282
- # size: (#bs, #node, #node, #dim_out)
283
- att_map = torch.tanh(self.att_proj(att_map))
284
- # size: (#bs, #node, #node, 1)
285
-
286
- att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
287
-
288
- att_board[:, :num_type1, :num_type1, :] = torch.matmul(
289
- att_map[:, :num_type1, :num_type1, :], self.att_weight11
290
- )
291
- att_board[:, num_type1:, num_type1:, :] = torch.matmul(
292
- att_map[:, num_type1:, num_type1:, :], self.att_weight22
293
- )
294
- att_board[:, :num_type1, num_type1:, :] = torch.matmul(
295
- att_map[:, :num_type1, num_type1:, :], self.att_weight12
296
- )
297
- att_board[:, num_type1:, :num_type1, :] = torch.matmul(
298
- att_map[:, num_type1:, :num_type1, :], self.att_weight12
299
- )
300
-
301
- att_map = att_board
302
-
303
- # apply temperature
304
- att_map = att_map / self.temp
305
-
306
- att_map = F.softmax(att_map, dim=-2)
307
-
308
- return att_map
309
-
310
- def _project(self, x, att_map):
311
- x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
312
- x2 = self.proj_without_att(x)
313
-
314
- return x1 + x2
315
-
316
- def _project_master(self, x, master, att_map):
317
-
318
- x1 = self.proj_with_attM(torch.matmul(att_map.squeeze(-1).unsqueeze(1), x))
319
- x2 = self.proj_without_attM(master)
320
-
321
- return x1 + x2
322
-
323
- def _apply_BN(self, x):
324
- org_size = x.size()
325
- x = x.view(-1, org_size[-1])
326
- x = self.bn(x)
327
- x = x.view(org_size)
328
-
329
- return x
330
-
331
- def _init_new_params(self, *size):
332
- out = nn.Parameter(torch.FloatTensor(*size))
333
- nn.init.xavier_normal_(out)
334
- return out
335
-
336
-
337
- class GraphPool(nn.Module):
338
- def __init__(self, k: float, in_dim: int, p: Union[float, int]):
339
- super().__init__()
340
- self.k = k
341
- self.sigmoid = nn.Sigmoid()
342
- self.proj = nn.Linear(in_dim, 1)
343
- self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
344
- self.in_dim = in_dim
345
-
346
- def forward(self, h):
347
- Z = self.drop(h)
348
- weights = self.proj(Z)
349
- scores = self.sigmoid(weights)
350
- new_h = self.top_k_graph(scores, h, self.k)
351
-
352
- return new_h
353
-
354
- def top_k_graph(self, scores, h, k):
355
- """
356
- args
357
- =====
358
- scores: attention-based weights (#bs, #node, 1)
359
- h: graph data (#bs, #node, #dim)
360
- k: ratio of remaining nodes, (float)
361
- returns
362
- =====
363
- h: graph pool applied data (#bs, #node', #dim)
364
- """
365
- _, n_nodes, n_feat = h.size()
366
- n_nodes = max(int(n_nodes * k), 1)
367
- _, idx = torch.topk(scores, n_nodes, dim=1)
368
- idx = idx.expand(-1, -1, n_feat)
369
-
370
- h = h * scores
371
- h = torch.gather(h, 1, idx)
372
-
373
- return h
374
-
375
-
376
- class Residual_block(nn.Module):
377
- def __init__(self, nb_filts, first=False):
378
- super().__init__()
379
- self.first = first
380
-
381
- if not self.first:
382
- self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
383
- self.conv1 = nn.Conv2d(
384
- in_channels=nb_filts[0],
385
- out_channels=nb_filts[1],
386
- kernel_size=(2, 3),
387
- padding=(1, 1),
388
- stride=1,
389
- )
390
- self.selu = nn.SELU(inplace=True)
391
-
392
- self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
393
- self.conv2 = nn.Conv2d(
394
- in_channels=nb_filts[1],
395
- out_channels=nb_filts[1],
396
- kernel_size=(2, 3),
397
- padding=(0, 1),
398
- stride=1,
399
- )
400
-
401
- if nb_filts[0] != nb_filts[1]:
402
- self.downsample = True
403
- self.conv_downsample = nn.Conv2d(
404
- in_channels=nb_filts[0],
405
- out_channels=nb_filts[1],
406
- padding=(0, 1),
407
- kernel_size=(1, 3),
408
- stride=1,
409
- )
410
-
411
- else:
412
- self.downsample = False
413
-
414
- def forward(self, x):
415
- identity = x
416
- if not self.first:
417
- out = self.bn1(x)
418
- out = self.selu(out)
419
- else:
420
- out = x
421
-
422
- # print('out',out.shape)
423
- out = self.conv1(x)
424
-
425
- # print('aft conv1 out',out.shape)
426
- out = self.bn2(out)
427
- out = self.selu(out)
428
- # print('out',out.shape)
429
- out = self.conv2(out)
430
- # print('conv2 out',out.shape)
431
-
432
- if self.downsample:
433
- identity = self.conv_downsample(identity)
434
-
435
- out += identity
436
- # out = self.mp(out)
437
- return out
438
-
439
-
440
- class AASIST(nn.Module):
441
- def __init__(self, config):
442
- super().__init__()
443
- self.config = config
444
-
445
- # AASIST parameters
446
- filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
447
- gat_dims = [64, 32]
448
- pool_ratios = [0.5, 0.5, 0.5, 0.5]
449
- temperatures = [2.0, 2.0, 100.0, 100.0]
450
-
451
- ####
452
- # create network wav2vec 2.0
453
- ####
454
- # self.ssl_model = SSLModel(self.device)
455
- self.LL = nn.Linear(self.config.model.decoder.encoding_dim, 128)
456
-
457
- self.first_bn = nn.BatchNorm2d(num_features=1)
458
- self.first_bn1 = nn.BatchNorm2d(num_features=64)
459
- self.drop = nn.Dropout(0.5, inplace=True)
460
- self.drop_way = nn.Dropout(0.2, inplace=True)
461
- self.selu = nn.SELU(inplace=True)
462
-
463
- # RawNet2 encoder
464
- self.encoder = nn.Sequential(
465
- nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
466
- nn.Sequential(Residual_block(nb_filts=filts[2])),
467
- nn.Sequential(Residual_block(nb_filts=filts[3])),
468
- nn.Sequential(Residual_block(nb_filts=filts[4])),
469
- nn.Sequential(Residual_block(nb_filts=filts[4])),
470
- nn.Sequential(Residual_block(nb_filts=filts[4])),
471
- )
472
-
473
- self.attention = nn.Sequential(
474
- nn.Conv2d(64, 128, kernel_size=(1, 1)),
475
- nn.SELU(inplace=True),
476
- nn.BatchNorm2d(128),
477
- nn.Conv2d(128, 64, kernel_size=(1, 1)),
478
- )
479
- # position encoding
480
- self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
481
-
482
- self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
483
- self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
484
-
485
- # Graph module
486
- self.GAT_layer_S = GraphAttentionLayer(
487
- filts[-1][-1], gat_dims[0], temperature=temperatures[0]
488
- )
489
- self.GAT_layer_T = GraphAttentionLayer(
490
- filts[-1][-1], gat_dims[0], temperature=temperatures[1]
491
- )
492
- # HS-GAL layer
493
- self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
494
- gat_dims[0], gat_dims[1], temperature=temperatures[2]
495
- )
496
- self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
497
- gat_dims[1], gat_dims[1], temperature=temperatures[2]
498
- )
499
- self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
500
- gat_dims[0], gat_dims[1], temperature=temperatures[2]
501
- )
502
- self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
503
- gat_dims[1], gat_dims[1], temperature=temperatures[2]
504
- )
505
-
506
- # Graph pooling layers
507
- self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
508
- self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
509
- self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
510
- self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
511
-
512
- self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
513
- self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
514
-
515
- if self.config.model.task == "audio-video":
516
- self.out_layer = nn.Linear(5 * gat_dims[1], 4)
517
- else:
518
- self.out_layer = nn.Linear(5 * gat_dims[1], 2)
519
-
520
- def forward(self, x):
521
- # -------pre-trained Wav2vec model fine tunning ------------------------##
522
- # x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
523
- x = self.LL(x) # (bs,frame_number,feat_out_dim)
524
-
525
- # post-processing on front-end features
526
- x = x.transpose(1, 2) # (bs,feat_out_dim,frame_number)
527
- x = x.unsqueeze(dim=1) # add channel
528
- x = F.max_pool2d(x, (3, 3))
529
- x = self.first_bn(x)
530
- x = self.selu(x)
531
-
532
- # RawNet2-based encoder
533
- x = self.encoder(x)
534
- x = self.first_bn1(x)
535
- x = self.selu(x)
536
-
537
- w = self.attention(x)
538
-
539
- # ------------SA for spectral feature-------------#
540
- w1 = F.softmax(w, dim=-1)
541
- m = torch.sum(x * w1, dim=-1)
542
- e_S = m.transpose(1, 2) + self.pos_S
543
-
544
- # graph module layer
545
- gat_S = self.GAT_layer_S(e_S)
546
- out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
547
-
548
- # ------------SA for temporal feature-------------#
549
- w2 = F.softmax(w, dim=-2)
550
- m1 = torch.sum(x * w2, dim=-2)
551
-
552
- e_T = m1.transpose(1, 2)
553
-
554
- # graph module layer
555
- gat_T = self.GAT_layer_T(e_T)
556
- out_T = self.pool_T(gat_T)
557
-
558
- # learnable master node
559
- master1 = self.master1.expand(x.size(0), -1, -1)
560
- master2 = self.master2.expand(x.size(0), -1, -1)
561
-
562
- # inference 1
563
- out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
564
- out_T, out_S, master=self.master1
565
- )
566
-
567
- out_S1 = self.pool_hS1(out_S1)
568
- out_T1 = self.pool_hT1(out_T1)
569
-
570
- out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
571
- out_T1, out_S1, master=master1
572
- )
573
- out_T1 = out_T1 + out_T_aug
574
- out_S1 = out_S1 + out_S_aug
575
- master1 = master1 + master_aug
576
-
577
- # inference 2
578
- out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
579
- out_T, out_S, master=self.master2
580
- )
581
- out_S2 = self.pool_hS2(out_S2)
582
- out_T2 = self.pool_hT2(out_T2)
583
-
584
- out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
585
- out_T2, out_S2, master=master2
586
- )
587
- out_T2 = out_T2 + out_T_aug
588
- out_S2 = out_S2 + out_S_aug
589
- master2 = master2 + master_aug
590
-
591
- out_T1 = self.drop_way(out_T1)
592
- out_T2 = self.drop_way(out_T2)
593
- out_S1 = self.drop_way(out_S1)
594
- out_S2 = self.drop_way(out_S2)
595
- master1 = self.drop_way(master1)
596
- master2 = self.drop_way(master2)
597
-
598
- out_T = torch.max(out_T1, out_T2)
599
- out_S = torch.max(out_S1, out_S2)
600
- master = torch.max(master1, master2)
601
-
602
- # Readout operation
603
- T_max, _ = torch.max(torch.abs(out_T), dim=1)
604
- T_avg = torch.mean(out_T, dim=1)
605
-
606
- S_max, _ = torch.max(torch.abs(out_S), dim=1)
607
- S_avg = torch.mean(out_S, dim=1)
608
-
609
- last_hidden = torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
610
-
611
- last_hidden = self.drop(last_hidden)
612
- output = self.out_layer(last_hidden)
613
-
614
- if self.config.model.task == "audio-video":
615
- output = output.view(-1, 2, 2)
616
-
617
- return output
manipulate_model/decoder/decoder.py DELETED
@@ -1,20 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class Decoder(nn.Module):
-    def __init__(self, config):
-        super(Decoder, self).__init__()
-        self.config = config
-
-        self.decoder = None
-
-        if config.model.decoder.name.lower() == "aasist":
-            from manipulate_model.decoder.aasist.aasist import AASIST
-
-            self.decoder = AASIST(config)
-        else:
-            raise ValueError("Invalid decoder name")
-
-    def forward(self, x):
-        return self.decoder(x)
manipulate_model/demo-model/audio/config.yaml DELETED
@@ -1,44 +0,0 @@
-model:
-  task: audio
-  encoder:
-    name: wavlm
-    version: base
-    pretrained: true
-    pretrained_path: manipulate_model/encoder_checkpoints/wavlm/WavLM-Base+.pt
-    output_layer: 3
-    encoder_freeze: false
-  decoder:
-    name: aasist
-    version: default
-    output_size: 2
-  online_encoding: true
-data:
-  name: av1m
-  train_parts: all
-  val_parts: all
-  test_parts: all
-  train_size: -1
-  val_size: -1
-  test_size: -1
-  shape:
-    - 3
-    - 224
-    - 224
-  sr: 16000
-  fps: 25
-  center_transition: true
-  window_size: 4
-  sliding_window: false
-train:
-  num_workers: 16
-  batch_size: 64
-  num_epochs: 15
-  optimizer: adam
-  scheduler: step
-  lr: 0.0001
-  step_size: 1
-  gamma: 0.1
-  loss: bce
-  log_interval: 100
-  shuffle: true
-  debug: false
manipulate_model/demo-model/audio/weights.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:03879151a2f0fb09431c0889537b1590fc9f2bba0faa1a01cbed84186b69b916
-size 379398821
manipulate_model/encoder/encoder.py DELETED
@@ -1,40 +0,0 @@
-import torch
-import torch.nn as nn
-
-from torch.nn.functional import pad
-from collections import OrderedDict
-
-
-class Encoder(nn.Module):
-    def __init__(self, config):
-        super(Encoder, self).__init__()
-        self.config = config
-
-        self.encoder = None
-        self.succeeding_layers = None
-
-        # AUDIO
-        if self.config.model.task == "audio":
-            if self.config.model.encoder.name.lower() == "wavlm":
-                from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig
-
-                ckpt = torch.load(
-                    config.model.encoder.pretrained_path, map_location="cpu"
-                )
-                cfg = WavLMConfig(ckpt["cfg"])
-                self.encoder = WavLM(cfg)
-
-
-    def forward(self, x):
-        if self.config.model.encoder.name.lower() == "wavlm":
-            return self.encoder(x, output_layer=self.config.model.encoder.output_layer)
-        elif self.config.model.encoder.name.lower() == "videomamba":
-            return self.encoder(x)
-
-        return self.encoder(x)
-
-    def get_encoding_dim(self):
-        return self.encoder.get_encoding_dim()
-
-    def get_temporal_dim(self):
-        return self.encoder.get_temporal_dim(window_size=self.config.data.window_size)
manipulate_model/encoder/wavlm/WavLM.py DELETED
@@ -1,795 +0,0 @@
1
- # --------------------------------------------------------
2
- # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
- # Copyright (c) 2021 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
-
10
- import math
11
- import logging
12
- from typing import List, Optional, Tuple
13
-
14
- import numpy as np
15
-
16
- import torch
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
- from torch.nn import LayerNorm
20
- from manipulate_model.encoder.wavlm.modules import (
21
- Fp32GroupNorm,
22
- Fp32LayerNorm,
23
- GradMultiply,
24
- MultiheadAttention,
25
- SamePad,
26
- init_bert_params,
27
- get_activation_fn,
28
- TransposeLast,
29
- GLU_Linear,
30
- )
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- def compute_mask_indices(
36
- shape: Tuple[int, int],
37
- padding_mask: Optional[torch.Tensor],
38
- mask_prob: float,
39
- mask_length: int,
40
- mask_type: str = "static",
41
- mask_other: float = 0.0,
42
- min_masks: int = 0,
43
- no_overlap: bool = False,
44
- min_space: int = 0,
45
- ) -> np.ndarray:
46
- """
47
- Computes random mask spans for a given shape
48
-
49
- Args:
50
- shape: the the shape for which to compute masks.
51
- should be of size 2 where first element is batch size and 2nd is timesteps
52
- padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
- mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
- number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
- however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
- mask_type: how to compute mask lengths
57
- static = fixed size
58
- uniform = sample from uniform distribution [mask_other, mask_length*2]
59
- normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
- poisson = sample from possion distribution with lambda = mask length
61
- min_masks: minimum number of masked spans
62
- no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
- min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
- """
65
-
66
- bsz, all_sz = shape
67
- mask = np.full((bsz, all_sz), False)
68
-
69
- all_num_mask = int(
70
- # add a random number for probabilistic rounding
71
- mask_prob * all_sz / float(mask_length)
72
- + np.random.rand()
73
- )
74
-
75
- all_num_mask = max(min_masks, all_num_mask)
76
-
77
- mask_idcs = []
78
- for i in range(bsz):
79
- if padding_mask is not None:
80
- sz = all_sz - padding_mask[i].long().sum().item()
81
- num_mask = int(
82
- # add a random number for probabilistic rounding
83
- mask_prob * sz / float(mask_length)
84
- + np.random.rand()
85
- )
86
- num_mask = max(min_masks, num_mask)
87
- else:
88
- sz = all_sz
89
- num_mask = all_num_mask
90
-
91
- if mask_type == "static":
92
- lengths = np.full(num_mask, mask_length)
93
- elif mask_type == "uniform":
94
- lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
- elif mask_type == "normal":
96
- lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
- lengths = [max(1, int(round(x))) for x in lengths]
98
- elif mask_type == "poisson":
99
- lengths = np.random.poisson(mask_length, size=num_mask)
100
- lengths = [int(round(x)) for x in lengths]
101
- else:
102
- raise Exception("unknown mask selection " + mask_type)
103
-
104
- if sum(lengths) == 0:
105
- lengths[0] = min(mask_length, sz - 1)
106
-
107
- if no_overlap:
108
- mask_idc = []
109
-
110
- def arrange(s, e, length, keep_length):
111
- span_start = np.random.randint(s, e - length)
112
- mask_idc.extend(span_start + i for i in range(length))
113
-
114
- new_parts = []
115
- if span_start - s - min_space >= keep_length:
116
- new_parts.append((s, span_start - min_space + 1))
117
- if e - span_start - keep_length - min_space > keep_length:
118
- new_parts.append((span_start + length + min_space, e))
119
- return new_parts
120
-
121
- parts = [(0, sz)]
122
- min_length = min(lengths)
123
- for length in sorted(lengths, reverse=True):
124
- lens = np.fromiter(
125
- (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
- np.int,
127
- )
128
- l_sum = np.sum(lens)
129
- if l_sum == 0:
130
- break
131
- probs = lens / np.sum(lens)
132
- c = np.random.choice(len(parts), p=probs)
133
- s, e = parts.pop(c)
134
- parts.extend(arrange(s, e, length, min_length))
135
- mask_idc = np.asarray(mask_idc)
136
- else:
137
- min_len = min(lengths)
138
- if sz - min_len <= num_mask:
139
- min_len = sz - num_mask - 1
140
-
141
- mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
-
143
- mask_idc = np.asarray(
144
- [
145
- mask_idc[j] + offset
146
- for j in range(len(mask_idc))
147
- for offset in range(lengths[j])
148
- ]
149
- )
150
-
151
- mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
-
153
- min_len = min([len(m) for m in mask_idcs])
154
- for i, mask_idc in enumerate(mask_idcs):
155
- if len(mask_idc) > min_len:
156
- mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
- mask[i, mask_idc] = True
158
-
159
- return mask
160
-
161
-
162
- class WavLMConfig:
163
- def __init__(self, cfg=None):
164
- self.extractor_mode: str = (
165
- "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
166
- )
167
- self.encoder_layers: int = 12 # num encoder layers in the transformer
168
-
169
- self.encoder_embed_dim: int = 768 # encoder embedding dimension
170
- self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
171
- self.encoder_attention_heads: int = 12 # num encoder attention heads
172
- self.activation_fn: str = "gelu" # activation function to use
173
-
174
- self.layer_norm_first: bool = False # apply layernorm first in the transformer
175
- self.conv_feature_layers: str = (
176
- "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
177
- )
178
- self.conv_bias: bool = False # include bias in conv encoder
179
- self.feature_grad_mult: float = (
180
- 1.0 # multiply feature extractor var grads by this
181
- )
182
-
183
- self.normalize: bool = (
184
- False # normalize input to have 0 mean and unit variance during training
185
- )
186
-
187
- # dropouts
188
- self.dropout: float = 0.1 # dropout probability for the transformer
189
- self.attention_dropout: float = 0.1 # dropout probability for attention weights
190
- self.activation_dropout: float = (
191
- 0.0 # dropout probability after activation in FFN
192
- )
193
- self.encoder_layerdrop: float = (
194
- 0.0 # probability of dropping a tarnsformer layer
195
- )
196
- self.dropout_input: float = (
197
- 0.0 # dropout to apply to the input (after feat extr)
198
- )
199
- self.dropout_features: float = (
200
- 0.0 # dropout to apply to the features (after feat extr)
201
- )
202
-
203
- # masking
204
- self.mask_length: int = 10 # mask length
205
- self.mask_prob: float = 0.65 # probability of replacing a token with mask
206
- self.mask_selection: str = "static" # how to choose mask length
207
- self.mask_other: float = (
208
- 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
209
- )
210
- self.no_mask_overlap: bool = False # whether to allow masks to overlap
211
- self.mask_min_space: int = (
212
- 1 # min space between spans (if no overlap is enabled)
213
- )
214
-
215
- # channel masking
216
- self.mask_channel_length: int = 10 # length of the mask for features (channels)
217
- self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
218
- self.mask_channel_selection: str = (
219
- "static" # how to choose mask length for channel masking
220
- )
221
- self.mask_channel_other: float = (
222
- 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
223
- )
224
- self.no_mask_channel_overlap: bool = (
225
- False # whether to allow channel masks to overlap
226
- )
227
- self.mask_channel_min_space: int = (
228
- 1 # min space between spans (if no overlap is enabled)
229
- )
230
-
231
- # positional embeddings
232
- self.conv_pos: int = (
233
- 128 # number of filters for convolutional positional embeddings
234
- )
235
- self.conv_pos_groups: int = (
236
- 16 # number of groups for convolutional positional embedding
237
- )
238
-
239
- # relative position embedding
240
- self.relative_position_embedding: bool = (
241
- False # apply relative position embedding
242
- )
243
- self.num_buckets: int = 320 # number of buckets for relative position embedding
244
- self.max_distance: int = (
245
- 1280 # maximum distance for relative position embedding
246
- )
247
- self.gru_rel_pos: bool = False # apply gated relative position embedding
248
-
249
- if cfg is not None:
250
- self.update(cfg)
251
-
252
- def update(self, cfg: dict):
253
- self.__dict__.update(cfg)
254
-
255
-
256
- class WavLM(nn.Module):
257
- def __init__(
258
- self,
259
- cfg: WavLMConfig,
260
- ) -> None:
261
- super().__init__()
262
- logger.info(f"WavLM Config: {cfg.__dict__}")
263
-
264
- self.cfg = cfg
265
- feature_enc_layers = eval(cfg.conv_feature_layers)
266
- self.embed = feature_enc_layers[-1][0]
267
-
268
- self.feature_extractor = ConvFeatureExtractionModel(
269
- conv_layers=feature_enc_layers,
270
- dropout=0.0,
271
- mode=cfg.extractor_mode,
272
- conv_bias=cfg.conv_bias,
273
- )
274
-
275
- self.post_extract_proj = (
276
- nn.Linear(self.embed, cfg.encoder_embed_dim)
277
- if self.embed != cfg.encoder_embed_dim
278
- else None
279
- )
280
-
281
- self.mask_prob = cfg.mask_prob
282
- self.mask_selection = cfg.mask_selection
283
- self.mask_other = cfg.mask_other
284
- self.mask_length = cfg.mask_length
285
- self.no_mask_overlap = cfg.no_mask_overlap
286
- self.mask_min_space = cfg.mask_min_space
287
-
288
- self.mask_channel_prob = cfg.mask_channel_prob
289
- self.mask_channel_selection = cfg.mask_channel_selection
290
- self.mask_channel_other = cfg.mask_channel_other
291
- self.mask_channel_length = cfg.mask_channel_length
292
- self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
293
- self.mask_channel_min_space = cfg.mask_channel_min_space
294
-
295
- self.dropout_input = nn.Dropout(cfg.dropout_input)
296
- self.dropout_features = nn.Dropout(cfg.dropout_features)
297
-
298
- self.feature_grad_mult = cfg.feature_grad_mult
299
-
300
- self.mask_emb = nn.Parameter(
301
- torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
302
- )
303
-
304
- self.encoder = TransformerEncoder(cfg)
305
- self.layer_norm = LayerNorm(self.embed)
306
-
307
- def apply_mask(self, x, padding_mask):
308
- B, T, C = x.shape
309
- if self.mask_prob > 0:
310
- mask_indices = compute_mask_indices(
311
- (B, T),
312
- padding_mask,
313
- self.mask_prob,
314
- self.mask_length,
315
- self.mask_selection,
316
- self.mask_other,
317
- min_masks=2,
318
- no_overlap=self.no_mask_overlap,
319
- min_space=self.mask_min_space,
320
- )
321
- mask_indices = torch.from_numpy(mask_indices).to(x.device)
322
- x[mask_indices] = self.mask_emb
323
- else:
324
- mask_indices = None
325
-
326
- if self.mask_channel_prob > 0:
327
- mask_channel_indices = compute_mask_indices(
328
- (B, C),
329
- None,
330
- self.mask_channel_prob,
331
- self.mask_channel_length,
332
- self.mask_channel_selection,
333
- self.mask_channel_other,
334
- no_overlap=self.no_mask_channel_overlap,
335
- min_space=self.mask_channel_min_space,
336
- )
337
- mask_channel_indices = (
338
- torch.from_numpy(mask_channel_indices)
339
- .to(x.device)
340
- .unsqueeze(1)
341
- .expand(-1, T, -1)
342
- )
343
- x[mask_channel_indices] = 0
344
-
345
- return x, mask_indices
346
-
347
- def forward_padding_mask(
348
- self,
349
- features: torch.Tensor,
350
- padding_mask: torch.Tensor,
351
- ) -> torch.Tensor:
352
- extra = padding_mask.size(1) % features.size(1)
353
- if extra > 0:
354
- padding_mask = padding_mask[:, :-extra]
355
- padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
356
- padding_mask = padding_mask.all(-1)
357
- return padding_mask
358
-
359
- def extract_features(
360
- self,
361
- source: torch.Tensor,
362
- padding_mask: Optional[torch.Tensor] = None,
363
- mask: bool = False,
364
- ret_conv: bool = False,
365
- output_layer: Optional[int] = None,
366
- ret_layer_results: bool = False,
367
- ):
368
-
369
- if self.feature_grad_mult > 0:
370
- features = self.feature_extractor(source)
371
- if self.feature_grad_mult != 1.0:
372
- features = GradMultiply.apply(features, self.feature_grad_mult)
373
- else:
374
- with torch.no_grad():
375
- features = self.feature_extractor(source)
376
-
377
- features = features.transpose(1, 2)
378
- features = self.layer_norm(features)
379
-
380
- if padding_mask is not None:
381
- padding_mask = self.forward_padding_mask(features, padding_mask)
382
-
383
- if self.post_extract_proj is not None:
384
- features = self.post_extract_proj(features)
385
-
386
- features = self.dropout_input(features)
387
-
388
- if mask:
389
- x, mask_indices = self.apply_mask(features, padding_mask)
390
- else:
391
- x = features
392
-
393
- # feature: (B, T, D), float
394
- # target: (B, T), long
395
- # x: (B, T, D), float
396
- # padding_mask: (B, T), bool
397
- # mask_indices: (B, T), bool
398
- x, layer_results = self.encoder(
399
- x,
400
- padding_mask=padding_mask,
401
- layer=None if output_layer is None else output_layer - 1,
402
- )
403
-
404
- res = {
405
- "x": x,
406
- "padding_mask": padding_mask,
407
- "features": features,
408
- "layer_results": layer_results,
409
- }
410
-
411
- feature = res["features"] if ret_conv else res["x"]
412
- if ret_layer_results:
413
- feature = (feature, res["layer_results"])
414
- return feature, res["padding_mask"]
415
-
416
- def forward(self, x, output_layer=None):
417
- return self.extract_features(x, output_layer=output_layer)[0]
418
-
419
- def get_encoding_dim(self):
420
- return self.cfg.encoder_embed_dim
421
-
422
- def get_temporal_dim(self, window_size):
423
- return 2 * window_size - 1
424
-
425
-
426
- class ConvFeatureExtractionModel(nn.Module):
427
- def __init__(
428
- self,
429
- conv_layers: List[Tuple[int, int, int]],
430
- dropout: float = 0.0,
431
- mode: str = "default",
432
- conv_bias: bool = False,
433
- conv_type: str = "default",
434
- ):
435
- super().__init__()
436
-
437
- assert mode in {"default", "layer_norm"}
438
-
439
- def block(
440
- n_in,
441
- n_out,
442
- k,
443
- stride,
444
- is_layer_norm=False,
445
- is_group_norm=False,
446
- conv_bias=False,
447
- ):
448
- def make_conv():
449
- conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
450
- nn.init.kaiming_normal_(conv.weight)
451
- return conv
452
-
453
- assert (
454
- is_layer_norm and is_group_norm
455
- ) == False, "layer norm and group norm are exclusive"
456
-
457
- if is_layer_norm:
458
- return nn.Sequential(
459
- make_conv(),
460
- nn.Dropout(p=dropout),
461
- nn.Sequential(
462
- TransposeLast(),
463
- Fp32LayerNorm(dim, elementwise_affine=True),
464
- TransposeLast(),
465
- ),
466
- nn.GELU(),
467
- )
468
- elif is_group_norm:
469
- return nn.Sequential(
470
- make_conv(),
471
- nn.Dropout(p=dropout),
472
- Fp32GroupNorm(dim, dim, affine=True),
473
- nn.GELU(),
474
- )
475
- else:
476
- return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
477
-
478
- self.conv_type = conv_type
479
- if self.conv_type == "default":
480
- in_d = 1
481
- self.conv_layers = nn.ModuleList()
482
- for i, cl in enumerate(conv_layers):
483
- assert len(cl) == 3, "invalid conv definition: " + str(cl)
484
- (dim, k, stride) = cl
485
-
486
- self.conv_layers.append(
487
- block(
488
- in_d,
489
- dim,
490
- k,
491
- stride,
492
- is_layer_norm=mode == "layer_norm",
493
- is_group_norm=mode == "default" and i == 0,
494
- conv_bias=conv_bias,
495
- )
496
- )
497
- in_d = dim
498
- elif self.conv_type == "conv2d":
499
- in_d = 1
500
- self.conv_layers = nn.ModuleList()
501
- for i, cl in enumerate(conv_layers):
502
- assert len(cl) == 3
503
- (dim, k, stride) = cl
504
-
505
- self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
506
- self.conv_layers.append(torch.nn.ReLU())
507
- in_d = dim
508
- elif self.conv_type == "custom":
509
- in_d = 1
510
- idim = 80
511
- self.conv_layers = nn.ModuleList()
512
- for i, cl in enumerate(conv_layers):
513
- assert len(cl) == 3
514
- (dim, k, stride) = cl
515
- self.conv_layers.append(
516
- torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
517
- )
518
- self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
519
- self.conv_layers.append(torch.nn.ReLU())
520
- in_d = dim
521
- if (i + 1) % 2 == 0:
522
- self.conv_layers.append(
523
- torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
524
- )
525
- idim = int(math.ceil(idim / 2))
526
- else:
527
- pass
528
-
529
- def forward(self, x, mask=None):
530
-
531
- # BxT -> BxCxT
532
- x = x.unsqueeze(1)
533
- if self.conv_type == "custom":
534
- for conv in self.conv_layers:
535
- if isinstance(conv, nn.LayerNorm):
536
- x = x.transpose(1, 2)
537
- x = conv(x).transpose(1, 2)
538
- else:
539
- x = conv(x)
540
- x = x.transpose(2, 3).contiguous()
541
- x = x.view(x.size(0), -1, x.size(-1))
542
- else:
543
- for conv in self.conv_layers:
544
- x = conv(x)
545
- if self.conv_type == "conv2d":
546
- b, c, t, f = x.size()
547
- x = x.transpose(2, 3).contiguous().view(b, c * f, t)
548
- return x
549
-
550
-
551
- class TransformerEncoder(nn.Module):
552
- def __init__(self, args):
553
- super().__init__()
554
-
555
- self.dropout = args.dropout
556
- self.embedding_dim = args.encoder_embed_dim
557
-
558
- self.pos_conv = nn.Conv1d(
559
- self.embedding_dim,
560
- self.embedding_dim,
561
- kernel_size=args.conv_pos,
562
- padding=args.conv_pos // 2,
563
- groups=args.conv_pos_groups,
564
- )
565
- dropout = 0
566
- std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
567
- nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
568
- nn.init.constant_(self.pos_conv.bias, 0)
569
-
570
- self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
571
- self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
572
-
573
- if hasattr(args, "relative_position_embedding"):
574
- self.relative_position_embedding = args.relative_position_embedding
575
- self.num_buckets = args.num_buckets
576
- self.max_distance = args.max_distance
577
- else:
578
- self.relative_position_embedding = False
579
- self.num_buckets = 0
580
- self.max_distance = 0
581
-
582
- self.layers = nn.ModuleList(
583
- [
584
- TransformerSentenceEncoderLayer(
585
- embedding_dim=self.embedding_dim,
586
- ffn_embedding_dim=args.encoder_ffn_embed_dim,
587
- num_attention_heads=args.encoder_attention_heads,
588
- dropout=self.dropout,
589
- attention_dropout=args.attention_dropout,
590
- activation_dropout=args.activation_dropout,
591
- activation_fn=args.activation_fn,
592
- layer_norm_first=args.layer_norm_first,
593
- has_relative_attention_bias=(
594
- self.relative_position_embedding and i == 0
595
- ),
596
- num_buckets=self.num_buckets,
597
- max_distance=self.max_distance,
598
- gru_rel_pos=args.gru_rel_pos,
599
- )
600
- for i in range(args.encoder_layers)
601
- ]
602
- )
603
-
604
- self.layer_norm_first = args.layer_norm_first
605
- self.layer_norm = LayerNorm(self.embedding_dim)
606
- self.layerdrop = args.encoder_layerdrop
607
-
608
- self.apply(init_bert_params)
609
-
610
- def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
611
- x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
612
-
613
- if self.layer_norm_first and layer is None:
614
- x = self.layer_norm(x)
615
-
616
- return x, layer_results
617
-
618
- def extract_features(
619
- self, x, padding_mask=None, streaming_mask=None, tgt_layer=None
620
- ):
621
-
622
- if padding_mask is not None:
623
- x[padding_mask] = 0
624
-
625
- x_conv = self.pos_conv(x.transpose(1, 2))
626
- x_conv = x_conv.transpose(1, 2)
627
- x = x + x_conv
628
-
629
- if not self.layer_norm_first:
630
- x = self.layer_norm(x)
631
-
632
- x = F.dropout(x, p=self.dropout, training=self.training)
633
-
634
- # B x T x C -> T x B x C
635
- x = x.transpose(0, 1)
636
-
637
- layer_results = []
638
- z = None
639
- if tgt_layer is not None:
640
- layer_results.append((x, z))
641
- r = None
642
- pos_bias = None
643
- for i, layer in enumerate(self.layers):
644
- dropout_probability = np.random.random()
645
- if not self.training or (dropout_probability > self.layerdrop):
646
- x, z, pos_bias = layer(
647
- x,
648
- self_attn_padding_mask=padding_mask,
649
- need_weights=False,
650
- self_attn_mask=streaming_mask,
651
- pos_bias=pos_bias,
652
- )
653
- if tgt_layer is not None:
654
- layer_results.append((x, z))
655
- if i == tgt_layer:
656
- r = x
657
- break
658
-
659
- if r is not None:
660
- x = r
661
-
662
- # T x B x C -> B x T x C
663
- x = x.transpose(0, 1)
664
-
665
- return x, layer_results
666
-
667
-
668
- class TransformerSentenceEncoderLayer(nn.Module):
669
- """
670
- Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
671
- models.
672
- """
673
-
674
- def __init__(
675
- self,
676
- embedding_dim: float = 768,
677
- ffn_embedding_dim: float = 3072,
678
- num_attention_heads: float = 8,
679
- dropout: float = 0.1,
680
- attention_dropout: float = 0.1,
681
- activation_dropout: float = 0.1,
682
- activation_fn: str = "relu",
683
- layer_norm_first: bool = False,
684
- has_relative_attention_bias: bool = False,
685
- num_buckets: int = 0,
686
- max_distance: int = 0,
687
- rescale_init: bool = False,
688
- gru_rel_pos: bool = False,
689
- ) -> None:
690
-
691
- super().__init__()
692
- # Initialize parameters
693
- self.embedding_dim = embedding_dim
694
- self.dropout = dropout
695
- self.activation_dropout = activation_dropout
696
-
697
- # Initialize blocks
698
- self.activation_name = activation_fn
699
- self.activation_fn = get_activation_fn(activation_fn)
700
- self.self_attn = MultiheadAttention(
701
- self.embedding_dim,
702
- num_attention_heads,
703
- dropout=attention_dropout,
704
- self_attention=True,
705
- has_relative_attention_bias=has_relative_attention_bias,
706
- num_buckets=num_buckets,
707
- max_distance=max_distance,
708
- rescale_init=rescale_init,
709
- gru_rel_pos=gru_rel_pos,
710
- )
711
-
712
- self.dropout1 = nn.Dropout(dropout)
713
- self.dropout2 = nn.Dropout(self.activation_dropout)
714
- self.dropout3 = nn.Dropout(dropout)
715
-
716
- self.layer_norm_first = layer_norm_first
717
-
718
- # layer norm associated with the self attention layer
719
- self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
720
-
721
- if self.activation_name == "glu":
722
- self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
723
- else:
724
- self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
725
- self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
726
-
727
- # layer norm associated with the position wise feed-forward NN
728
- self.final_layer_norm = LayerNorm(self.embedding_dim)
729
-
730
- def forward(
731
- self,
732
- x: torch.Tensor,
733
- self_attn_mask: torch.Tensor = None,
734
- self_attn_padding_mask: torch.Tensor = None,
735
- need_weights: bool = False,
736
- pos_bias=None,
737
- ):
738
- """
739
- LayerNorm is applied either before or after the self-attention/ffn
740
- modules similar to the original Transformer imlementation.
741
- """
742
- residual = x
743
-
744
- if self.layer_norm_first:
745
- x = self.self_attn_layer_norm(x)
746
- x, attn, pos_bias = self.self_attn(
747
- query=x,
748
- key=x,
749
- value=x,
750
- key_padding_mask=self_attn_padding_mask,
751
- need_weights=False,
752
- attn_mask=self_attn_mask,
753
- position_bias=pos_bias,
754
- )
755
- x = self.dropout1(x)
756
- x = residual + x
757
-
758
- residual = x
759
- x = self.final_layer_norm(x)
760
- if self.activation_name == "glu":
761
- x = self.fc1(x)
762
- else:
763
- x = self.activation_fn(self.fc1(x))
764
- x = self.dropout2(x)
765
- x = self.fc2(x)
766
- x = self.dropout3(x)
767
- x = residual + x
768
- else:
769
- x, attn, pos_bias = self.self_attn(
770
- query=x,
771
- key=x,
772
- value=x,
773
- key_padding_mask=self_attn_padding_mask,
774
- need_weights=need_weights,
775
- attn_mask=self_attn_mask,
776
- position_bias=pos_bias,
777
- )
778
-
779
- x = self.dropout1(x)
780
- x = residual + x
781
-
782
- x = self.self_attn_layer_norm(x)
783
-
784
- residual = x
785
- if self.activation_name == "glu":
786
- x = self.fc1(x)
787
- else:
788
- x = self.activation_fn(self.fc1(x))
789
- x = self.dropout2(x)
790
- x = self.fc2(x)
791
- x = self.dropout3(x)
792
- x = residual + x
793
- x = self.final_layer_norm(x)
794
-
795
- return x, attn, pos_bias
manipulate_model/encoder/wavlm/modules.py DELETED
@@ -1,827 +0,0 @@
1
- # --------------------------------------------------------
2
- # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
- # Copyright (c) 2021 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
-
10
- import math
11
- import warnings
12
- from typing import Dict, Optional, Tuple
13
- import torch
14
- from torch import Tensor, nn
15
- from torch.nn import Parameter
16
- import torch.nn.functional as F
17
-
18
-
19
- class TransposeLast(nn.Module):
20
- def __init__(self, deconstruct_idx=None):
21
- super().__init__()
22
- self.deconstruct_idx = deconstruct_idx
23
-
24
- def forward(self, x):
25
- if self.deconstruct_idx is not None:
26
- x = x[self.deconstruct_idx]
27
- return x.transpose(-2, -1)
28
-
29
-
30
- class Fp32LayerNorm(nn.LayerNorm):
31
- def __init__(self, *args, **kwargs):
32
- super().__init__(*args, **kwargs)
33
-
34
- def forward(self, input):
35
- output = F.layer_norm(
36
- input.float(),
37
- self.normalized_shape,
38
- self.weight.float() if self.weight is not None else None,
39
- self.bias.float() if self.bias is not None else None,
40
- self.eps,
41
- )
42
- return output.type_as(input)
43
-
44
-
45
- class Fp32GroupNorm(nn.GroupNorm):
46
- def __init__(self, *args, **kwargs):
47
- super().__init__(*args, **kwargs)
48
-
49
- def forward(self, input):
50
- output = F.group_norm(
51
- input.float(),
52
- self.num_groups,
53
- self.weight.float() if self.weight is not None else None,
54
- self.bias.float() if self.bias is not None else None,
55
- self.eps,
56
- )
57
- return output.type_as(input)
58
-
59
-
60
- class GradMultiply(torch.autograd.Function):
61
- @staticmethod
62
- def forward(ctx, x, scale):
63
- ctx.scale = scale
64
- res = x.new(x)
65
- return res
66
-
67
- @staticmethod
68
- def backward(ctx, grad):
69
- return grad * ctx.scale, None
70
-
71
-
72
- class SamePad(nn.Module):
73
- def __init__(self, kernel_size, causal=False):
74
- super().__init__()
75
- if causal:
76
- self.remove = kernel_size - 1
77
- else:
78
- self.remove = 1 if kernel_size % 2 == 0 else 0
79
-
80
- def forward(self, x):
81
- if self.remove > 0:
82
- x = x[:, :, : -self.remove]
83
- return x
84
-
85
-
86
- class Swish(nn.Module):
87
- """Swish function
88
- """
89
-
90
- def __init__(self):
91
- """Construct a Swish activation object."""
92
- super(Swish, self).__init__()
93
- self.act = torch.nn.Sigmoid()
94
-
95
- def forward(self, x):
96
- return x * self.act(x)
97
-
98
-
99
- class GLU_Linear(nn.Module):
100
- def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
- super(GLU_Linear, self).__init__()
102
-
103
- self.glu_type = glu_type
104
- self.output_dim = output_dim
105
-
106
- if glu_type == "sigmoid":
107
- self.glu_act = torch.nn.Sigmoid()
108
- elif glu_type == "swish":
109
- self.glu_act = Swish()
110
- elif glu_type == "relu":
111
- self.glu_act = torch.nn.ReLU()
112
- elif glu_type == "gelu":
113
- self.glu_act = torch.nn.GELU()
114
-
115
- if bias_in_glu:
116
- self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
- else:
118
- self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
-
120
- def forward(self, x):
121
- # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
- x = self.linear(x)
123
-
124
- if self.glu_type == "bilinear":
125
- x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
- else:
127
- x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
-
129
- return x
130
-
131
-
132
- def gelu_accurate(x):
133
- if not hasattr(gelu_accurate, "_a"):
134
- gelu_accurate._a = math.sqrt(2 / math.pi)
135
- return (
136
- 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
- )
138
-
139
-
140
- def gelu(x: torch.Tensor) -> torch.Tensor:
141
- return torch.nn.functional.gelu(x.float()).type_as(x)
142
-
143
-
144
- def get_activation_fn(activation: str):
145
- """Returns the activation function corresponding to `activation`"""
146
-
147
- if activation == "relu":
148
- return F.relu
149
- elif activation == "gelu":
150
- return gelu
151
- elif activation == "gelu_fast":
152
- warnings.warn(
153
- "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
- )
155
- return gelu_accurate
156
- elif activation == "gelu_accurate":
157
- return gelu_accurate
158
- elif activation == "tanh":
159
- return torch.tanh
160
- elif activation == "linear":
161
- return lambda x: x
162
- elif activation == "glu":
163
- return lambda x: x
164
- else:
165
- raise RuntimeError("--activation-fn {} not supported".format(activation))
166
-
167
-
168
- def init_bert_params(module):
169
- """
170
- Initialize the weights specific to the BERT Model.
171
- This overrides the default initializations depending on the specified arguments.
172
- 1. If normal_init_linear_weights is set then weights of linear
173
- layer will be initialized using the normal distribution and
174
- bias will be set to the specified value.
175
- 2. If normal_init_embed_weights is set then weights of embedding
176
- layer will be initialized using the normal distribution.
177
- 3. If normal_init_proj_weights is set then weights of
178
- in_project_weight for MultiHeadAttention initialized using
179
- the normal distribution (to be validated).
180
- """
181
-
182
- def normal_(data):
183
- # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
- # so that the RNG is consistent with and without FSDP
185
- data.copy_(
186
- data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
- )
188
-
189
- if isinstance(module, nn.Linear):
190
- normal_(module.weight.data)
191
- if module.bias is not None:
192
- module.bias.data.zero_()
193
- if isinstance(module, nn.Embedding):
194
- normal_(module.weight.data)
195
- if module.padding_idx is not None:
196
- module.weight.data[module.padding_idx].zero_()
197
- if isinstance(module, MultiheadAttention):
198
- normal_(module.q_proj.weight.data)
199
- normal_(module.k_proj.weight.data)
200
- normal_(module.v_proj.weight.data)
201
-
202
-
203
- def quant_noise(module, p, block_size):
204
- """
205
- Wraps modules and applies quantization noise to the weights for
206
- subsequent quantization with Iterative Product Quantization as
207
- described in "Training with Quantization Noise for Extreme Model Compression"
208
-
209
- Args:
210
- - module: nn.Module
211
- - p: amount of Quantization Noise
212
- - block_size: size of the blocks for subsequent quantization with iPQ
213
-
214
- Remarks:
215
- - Module weights must have the right sizes wrt the block size
216
- - Only Linear, Embedding and Conv2d modules are supported for the moment
217
- - For more detail on how to quantize by blocks with convolutional weights,
218
- see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
- - We implement the simplest form of noise here as stated in the paper
220
- which consists in randomly dropping blocks
221
- """
222
-
223
- # if no quantization noise, don't register hook
224
- if p <= 0:
225
- return module
226
-
227
- # supported modules
228
- assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
-
230
- # test whether module.weight has the right sizes wrt block_size
231
- is_conv = module.weight.ndim == 4
232
-
233
- # 2D matrix
234
- if not is_conv:
235
- assert (
236
- module.weight.size(1) % block_size == 0
237
- ), "Input features must be a multiple of block sizes"
238
-
239
- # 4D matrix
240
- else:
241
- # 1x1 convolutions
242
- if module.kernel_size == (1, 1):
243
- assert (
244
- module.in_channels % block_size == 0
245
- ), "Input channels must be a multiple of block sizes"
246
- # regular convolutions
247
- else:
248
- k = module.kernel_size[0] * module.kernel_size[1]
249
- assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
-
251
- def _forward_pre_hook(mod, input):
252
- # no noise for evaluation
253
- if mod.training:
254
- if not is_conv:
255
- # gather weight and sizes
256
- weight = mod.weight
257
- in_features = weight.size(1)
258
- out_features = weight.size(0)
259
-
260
- # split weight matrix into blocks and randomly drop selected blocks
261
- mask = torch.zeros(
262
- in_features // block_size * out_features, device=weight.device
263
- )
264
- mask.bernoulli_(p)
265
- mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
-
267
- else:
268
- # gather weight and sizes
269
- weight = mod.weight
270
- in_channels = mod.in_channels
271
- out_channels = mod.out_channels
272
-
273
- # split weight matrix into blocks and randomly drop selected blocks
274
- if mod.kernel_size == (1, 1):
275
- mask = torch.zeros(
276
- int(in_channels // block_size * out_channels),
277
- device=weight.device,
278
- )
279
- mask.bernoulli_(p)
280
- mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
- else:
282
- mask = torch.zeros(
283
- weight.size(0), weight.size(1), device=weight.device
284
- )
285
- mask.bernoulli_(p)
286
- mask = (
287
- mask.unsqueeze(2)
288
- .unsqueeze(3)
289
- .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
- )
291
-
292
- # scale weights and apply mask
293
- mask = mask.to(
294
- torch.bool
295
- ) # x.bool() is not currently supported in TorchScript
296
- s = 1 / (1 - p)
297
- mod.weight.data = s * weight.masked_fill(mask, 0)
298
-
299
- module.register_forward_pre_hook(_forward_pre_hook)
300
- return module
301
-
302
-
303
- class MultiheadAttention(nn.Module):
304
- """Multi-headed attention.
305
-
306
- See "Attention Is All You Need" for more details.
307
- """
308
-
309
- def __init__(
310
- self,
311
- embed_dim,
312
- num_heads,
313
- kdim=None,
314
- vdim=None,
315
- dropout=0.0,
316
- bias=True,
317
- add_bias_kv=False,
318
- add_zero_attn=False,
319
- self_attention=False,
320
- encoder_decoder_attention=False,
321
- q_noise=0.0,
322
- qn_block_size=8,
323
- has_relative_attention_bias=False,
324
- num_buckets=32,
325
- max_distance=128,
326
- gru_rel_pos=False,
327
- rescale_init=False,
328
- ):
329
- super().__init__()
330
- self.embed_dim = embed_dim
331
- self.kdim = kdim if kdim is not None else embed_dim
332
- self.vdim = vdim if vdim is not None else embed_dim
333
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
-
335
- self.num_heads = num_heads
336
- self.dropout_module = nn.Dropout(dropout)
337
-
338
- self.has_relative_attention_bias = has_relative_attention_bias
339
- self.num_buckets = num_buckets
340
- self.max_distance = max_distance
341
- if self.has_relative_attention_bias:
342
- self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
-
344
- self.head_dim = embed_dim // num_heads
345
- self.q_head_dim = self.head_dim
346
- self.k_head_dim = self.head_dim
347
- assert (
348
- self.head_dim * num_heads == self.embed_dim
349
- ), "embed_dim must be divisible by num_heads"
350
- self.scaling = self.head_dim ** -0.5
351
-
352
- self.self_attention = self_attention
353
- self.encoder_decoder_attention = encoder_decoder_attention
354
-
355
- assert not self.self_attention or self.qkv_same_dim, (
356
- "Self-attention requires query, key and " "value to be of the same size"
357
- )
358
-
359
- k_bias = True
360
- if rescale_init:
361
- k_bias = False
362
-
363
- k_embed_dim = embed_dim
364
- q_embed_dim = embed_dim
365
-
366
- self.k_proj = quant_noise(
367
- nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
- )
369
- self.v_proj = quant_noise(
370
- nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
- )
372
- self.q_proj = quant_noise(
373
- nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
- )
375
-
376
- self.out_proj = quant_noise(
377
- nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
- )
379
-
380
- if add_bias_kv:
381
- self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
- self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
- else:
384
- self.bias_k = self.bias_v = None
385
-
386
- self.add_zero_attn = add_zero_attn
387
-
388
- self.gru_rel_pos = gru_rel_pos
389
- if self.gru_rel_pos:
390
- self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
- self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
-
393
- self.reset_parameters()
394
-
395
- def reset_parameters(self):
396
- if self.qkv_same_dim:
397
- # Empirically observed the convergence to be much better with
398
- # the scaled initialization
399
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
- else:
403
- nn.init.xavier_uniform_(self.k_proj.weight)
404
- nn.init.xavier_uniform_(self.v_proj.weight)
405
- nn.init.xavier_uniform_(self.q_proj.weight)
406
-
407
- nn.init.xavier_uniform_(self.out_proj.weight)
408
- if self.out_proj.bias is not None:
409
- nn.init.constant_(self.out_proj.bias, 0.0)
410
- if self.bias_k is not None:
411
- nn.init.xavier_normal_(self.bias_k)
412
- if self.bias_v is not None:
413
- nn.init.xavier_normal_(self.bias_v)
414
- if self.has_relative_attention_bias:
415
- nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
-
417
- def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
- num_buckets = self.num_buckets
419
- max_distance = self.max_distance
420
- relative_buckets = 0
421
-
422
- if bidirectional:
423
- num_buckets = num_buckets // 2
424
- relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
- relative_positions = torch.abs(relative_positions)
426
- else:
427
- relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
-
429
- max_exact = num_buckets // 2
430
- is_small = relative_positions < max_exact
431
-
432
- relative_postion_if_large = max_exact + (
433
- torch.log(relative_positions.float() / max_exact)
434
- / math.log(max_distance / max_exact)
435
- * (num_buckets - max_exact)
436
- ).to(torch.long)
437
- relative_postion_if_large = torch.min(
438
- relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
- )
440
-
441
- relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
- return relative_buckets
443
-
444
- def compute_bias(self, query_length, key_length):
445
- context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
- memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
- relative_position = memory_position - context_position
448
- relative_position_bucket = self._relative_positions_bucket(
449
- relative_position,
450
- bidirectional=True
451
- )
452
- relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
- values = self.relative_attention_bias(relative_position_bucket)
454
- values = values.permute([2, 0, 1])
455
- return values
456
-
457
- def forward(
458
- self,
459
- query,
460
- key: Optional[Tensor],
461
- value: Optional[Tensor],
462
- key_padding_mask: Optional[Tensor] = None,
463
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
- need_weights: bool = True,
465
- static_kv: bool = False,
466
- attn_mask: Optional[Tensor] = None,
467
- before_softmax: bool = False,
468
- need_head_weights: bool = False,
469
- position_bias: Optional[Tensor] = None
470
- ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
- """Input shape: Time x Batch x Channel
472
-
473
- Args:
474
- key_padding_mask (ByteTensor, optional): mask to exclude
475
- keys that are pads, of shape `(batch, src_len)`, where
476
- padding elements are indicated by 1s.
477
- need_weights (bool, optional): return the attention weights,
478
- averaged over heads (default: False).
479
- attn_mask (ByteTensor, optional): typically used to
480
- implement causal attention, where the mask prevents the
481
- attention from looking forward in time (default: None).
482
- before_softmax (bool, optional): return the raw attention
483
- weights and values before the attention softmax.
484
- need_head_weights (bool, optional): return the attention
485
- weights for each head. Implies *need_weights*. Default:
486
- return the average attention weights over all heads.
487
- """
488
- if need_head_weights:
489
- need_weights = True
490
-
491
- is_tpu = query.device.type == "xla"
492
-
493
- tgt_len, bsz, embed_dim = query.size()
494
- src_len = tgt_len
495
- assert embed_dim == self.embed_dim
496
- assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
- if key is not None:
498
- src_len, key_bsz, _ = key.size()
499
- if not torch.jit.is_scripting():
500
- assert key_bsz == bsz
501
- assert value is not None
502
- assert src_len, bsz == value.shape[:2]
503
-
504
- if self.has_relative_attention_bias and position_bias is None:
505
- position_bias = self.compute_bias(tgt_len, src_len)
506
- position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
-
508
- if (
509
- not is_tpu # don't use PyTorch version on TPUs
510
- and incremental_state is None
511
- and not static_kv
512
- # A workaround for quantization to work. Otherwise JIT compilation
513
- # treats bias in linear module as method.
514
- and not torch.jit.is_scripting()
515
- and self.q_head_dim == self.head_dim
516
- ):
517
- assert key is not None and value is not None
518
- assert attn_mask is None
519
-
520
- attn_mask_rel_pos = None
521
- if position_bias is not None:
522
- attn_mask_rel_pos = position_bias
523
- if self.gru_rel_pos:
524
- query_layer = query.transpose(0, 1)
525
- new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
- query_layer = query_layer.view(*new_x_shape)
527
- query_layer = query_layer.permute(0, 2, 1, 3)
528
- _B, _H, _L, __ = query_layer.size()
529
-
530
- gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
- _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
- gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
- attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
-
535
- attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
- k_proj_bias = self.k_proj.bias
537
- if k_proj_bias is None:
538
- k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
-
540
- x, attn = F.multi_head_attention_forward(
541
- query,
542
- key,
543
- value,
544
- self.embed_dim,
545
- self.num_heads,
546
- torch.empty([0]),
547
- torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
- self.bias_k,
549
- self.bias_v,
550
- self.add_zero_attn,
551
- self.dropout_module.p,
552
- self.out_proj.weight,
553
- self.out_proj.bias,
554
- self.training,
555
- # self.training or self.dropout_module.apply_during_inference,
556
- key_padding_mask,
557
- need_weights,
558
- attn_mask_rel_pos,
559
- use_separate_proj_weight=True,
560
- q_proj_weight=self.q_proj.weight,
561
- k_proj_weight=self.k_proj.weight,
562
- v_proj_weight=self.v_proj.weight,
563
- )
564
- return x, attn, position_bias
565
-
566
- if incremental_state is not None:
567
- saved_state = self._get_input_buffer(incremental_state)
568
- if saved_state is not None and "prev_key" in saved_state:
569
- # previous time steps are cached - no need to recompute
570
- # key and value if they are static
571
- if static_kv:
572
- assert self.encoder_decoder_attention and not self.self_attention
573
- key = value = None
574
- else:
575
- saved_state = None
576
-
577
- if self.self_attention:
578
- q = self.q_proj(query)
579
- k = self.k_proj(query)
580
- v = self.v_proj(query)
581
- elif self.encoder_decoder_attention:
582
- # encoder-decoder attention
583
- q = self.q_proj(query)
584
- if key is None:
585
- assert value is None
586
- k = v = None
587
- else:
588
- k = self.k_proj(key)
589
- v = self.v_proj(key)
590
-
591
- else:
592
- assert key is not None and value is not None
593
- q = self.q_proj(query)
594
- k = self.k_proj(key)
595
- v = self.v_proj(value)
596
- q *= self.scaling
597
-
598
- if self.bias_k is not None:
599
- assert self.bias_v is not None
600
- k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
- v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
- if attn_mask is not None:
603
- attn_mask = torch.cat(
604
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
- )
606
- if key_padding_mask is not None:
607
- key_padding_mask = torch.cat(
608
- [
609
- key_padding_mask,
610
- key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
- ],
612
- dim=1,
613
- )
614
-
615
- q = (
616
- q.contiguous()
617
- .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
- .transpose(0, 1)
619
- )
620
- if k is not None:
621
- k = (
622
- k.contiguous()
623
- .view(-1, bsz * self.num_heads, self.k_head_dim)
624
- .transpose(0, 1)
625
- )
626
- if v is not None:
627
- v = (
628
- v.contiguous()
629
- .view(-1, bsz * self.num_heads, self.head_dim)
630
- .transpose(0, 1)
631
- )
632
-
633
- if saved_state is not None:
634
- # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
- if "prev_key" in saved_state:
636
- _prev_key = saved_state["prev_key"]
637
- assert _prev_key is not None
638
- prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
- if static_kv:
640
- k = prev_key
641
- else:
642
- assert k is not None
643
- k = torch.cat([prev_key, k], dim=1)
644
- src_len = k.size(1)
645
- if "prev_value" in saved_state:
646
- _prev_value = saved_state["prev_value"]
647
- assert _prev_value is not None
648
- prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
- if static_kv:
650
- v = prev_value
651
- else:
652
- assert v is not None
653
- v = torch.cat([prev_value, v], dim=1)
654
- prev_key_padding_mask: Optional[Tensor] = None
655
- if "prev_key_padding_mask" in saved_state:
656
- prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
- assert k is not None and v is not None
658
- key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
- key_padding_mask=key_padding_mask,
660
- prev_key_padding_mask=prev_key_padding_mask,
661
- batch_size=bsz,
662
- src_len=k.size(1),
663
- static_kv=static_kv,
664
- )
665
-
666
- saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
- saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
- saved_state["prev_key_padding_mask"] = key_padding_mask
669
- # In this branch incremental_state is never None
670
- assert incremental_state is not None
671
- incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
- assert k is not None
673
- assert k.size(1) == src_len
674
-
675
- # This is part of a workaround to get around fork/join parallelism
676
- # not supporting Optional types.
677
- if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
- key_padding_mask = None
679
-
680
- if key_padding_mask is not None:
681
- assert key_padding_mask.size(0) == bsz
682
- assert key_padding_mask.size(1) == src_len
683
-
684
- if self.add_zero_attn:
685
- assert v is not None
686
- src_len += 1
687
- k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
- v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
- if attn_mask is not None:
690
- attn_mask = torch.cat(
691
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
- )
693
- if key_padding_mask is not None:
694
- key_padding_mask = torch.cat(
695
- [
696
- key_padding_mask,
697
- torch.zeros(key_padding_mask.size(0), 1).type_as(
698
- key_padding_mask
699
- ),
700
- ],
701
- dim=1,
702
- )
703
-
704
- attn_weights = torch.bmm(q, k.transpose(1, 2))
705
- attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
-
707
- assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
-
709
- if attn_mask is not None:
710
- attn_mask = attn_mask.unsqueeze(0)
711
- attn_weights += attn_mask
712
-
713
- if key_padding_mask is not None:
714
- # don't attend to padding symbols
715
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
- if not is_tpu:
717
- attn_weights = attn_weights.masked_fill(
718
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
- float("-inf"),
720
- )
721
- else:
722
- attn_weights = attn_weights.transpose(0, 2)
723
- attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
- attn_weights = attn_weights.transpose(0, 2)
725
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
-
727
- if before_softmax:
728
- return attn_weights, v, position_bias
729
-
730
- if position_bias is not None:
731
- if self.gru_rel_pos == 1:
732
- query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
- _B, _H, _L, __ = query_layer.size()
734
- gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
- _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
- gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
- position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
-
739
- position_bias = position_bias.view(attn_weights.size())
740
-
741
- attn_weights = attn_weights + position_bias
742
-
743
- attn_weights_float = F.softmax(
744
- attn_weights, dim=-1
745
- )
746
- attn_weights = attn_weights_float.type_as(attn_weights)
747
- attn_probs = self.dropout_module(attn_weights)
748
-
749
- assert v is not None
750
- attn = torch.bmm(attn_probs, v)
751
- assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
- attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
- attn = self.out_proj(attn)
754
- attn_weights: Optional[Tensor] = None
755
- if need_weights:
756
- attn_weights = attn_weights_float.view(
757
- bsz, self.num_heads, tgt_len, src_len
758
- ).transpose(1, 0)
759
- if not need_head_weights:
760
- # average attention weights over heads
761
- attn_weights = attn_weights.mean(dim=0)
762
-
763
- return attn, attn_weights, position_bias
764
-
765
- @staticmethod
766
- def _append_prev_key_padding_mask(
767
- key_padding_mask: Optional[Tensor],
768
- prev_key_padding_mask: Optional[Tensor],
769
- batch_size: int,
770
- src_len: int,
771
- static_kv: bool,
772
- ) -> Optional[Tensor]:
773
- # saved key padding masks have shape (bsz, seq_len)
774
- if prev_key_padding_mask is not None and static_kv:
775
- new_key_padding_mask = prev_key_padding_mask
776
- elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
- new_key_padding_mask = torch.cat(
778
- [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
- )
780
- # During incremental decoding, as the padding token enters and
781
- # leaves the frame, there will be a time when prev or current
782
- # is None
783
- elif prev_key_padding_mask is not None:
784
- if src_len > prev_key_padding_mask.size(1):
785
- filler = torch.zeros(
786
- (batch_size, src_len - prev_key_padding_mask.size(1)),
787
- device=prev_key_padding_mask.device,
788
- )
789
- new_key_padding_mask = torch.cat(
790
- [prev_key_padding_mask.float(), filler.float()], dim=1
791
- )
792
- else:
793
- new_key_padding_mask = prev_key_padding_mask.float()
794
- elif key_padding_mask is not None:
795
- if src_len > key_padding_mask.size(1):
796
- filler = torch.zeros(
797
- (batch_size, src_len - key_padding_mask.size(1)),
798
- device=key_padding_mask.device,
799
- )
800
- new_key_padding_mask = torch.cat(
801
- [filler.float(), key_padding_mask.float()], dim=1
802
- )
803
- else:
804
- new_key_padding_mask = key_padding_mask.float()
805
- else:
806
- new_key_padding_mask = prev_key_padding_mask
807
- return new_key_padding_mask
808
-
809
- def _get_input_buffer(
810
- self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
- ) -> Dict[str, Optional[Tensor]]:
812
- result = self.get_incremental_state(incremental_state, "attn_state")
813
- if result is not None:
814
- return result
815
- else:
816
- empty_result: Dict[str, Optional[Tensor]] = {}
817
- return empty_result
818
-
819
- def _set_input_buffer(
820
- self,
821
- incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
- buffer: Dict[str, Optional[Tensor]],
823
- ):
824
- return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
-
826
- def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
- return attn_weights
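For reference, a hedged usage sketch of the quant_noise wrapper defined in this file. The import assumes a checkout in which the module has not yet been removed, and the layer sizes are illustrative only:

import torch
import torch.nn as nn

from manipulate_model.encoder.wavlm.modules import quant_noise  # removed by this commit

# Wrap a Linear layer; in_features must be a multiple of block_size.
layer = quant_noise(nn.Linear(768, 768), p=0.1, block_size=8)
layer.train()                    # the forward pre-hook only drops weight blocks in training mode
y = layer(torch.randn(4, 768))   # remaining weights are rescaled by 1/(1-p), masked blocks are zeroed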
 
manipulate_model/encoder_checkpoints/wavlm/WavLM-Base+.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fcbcf2a94def92e90e086bb0727275d53b75a9c0e483e2abfa560ac951986b6d
3
- size 377604817
 
manipulate_model/model.py DELETED
@@ -1,25 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from manipulate_model.encoder.encoder import Encoder
5
- from manipulate_model.decoder.decoder import Decoder
6
-
7
-
8
- class Model(nn.Module):
9
- def __init__(self, config):
10
- super(Model, self).__init__()
11
- self.config = config
12
-
13
- self.encoder = Encoder(self.config)
14
- self.config.model.decoder.temporal_dim = self.encoder.get_temporal_dim()
15
- self.config.model.decoder.encoding_dim = self.encoder.get_encoding_dim()
16
- self.decoder = Decoder(self.config)
17
-
18
- def forward(self, x):
19
- if self.config.model.encoder_freeze:
20
- with torch.no_grad():
21
- x = self.encoder(x)
22
- else:
23
- x = self.encoder(x)
24
- x = self.decoder(x)
25
- return x
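The conditional torch.no_grad() above is the usual way to keep a frozen encoder out of the backward graph while the decoder keeps training. A generic sketch of the pattern (toy modules, not the project's Encoder/Decoder):

import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)   # stand-in for the feature extractor
decoder = nn.Linear(8, 2)   # stand-in for the trainable head
x = torch.randn(1, 8)
encoder_freeze = True

if encoder_freeze:
    with torch.no_grad():   # no gradients are tracked through the encoder
        h = encoder(x)
else:
    h = encoder(x)
out = decoder(h)            # gradients still flow through the decoder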
 
manipulate_model/utils.py DELETED
@@ -1,138 +0,0 @@
1
- import os
2
- import torch
3
- import torchaudio
4
- import numpy as np
5
- from omegaconf import OmegaConf
6
- from torchvision.io import read_video
7
- from torch.nn.functional import pad, normalize, softmax
8
-
9
- from manipulate_model.model import Model
10
-
11
-
12
-
13
- def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
14
- config_path = os.path.join(model_root, "config.yaml")
15
- config = OmegaConf.load(config_path)
16
- if isinstance(config.model.encoder, str):
17
- config.model.encoder = OmegaConf.load(config.model.encoder)
18
- if isinstance(config.model.decoder, str):
19
- config.model.decoder = OmegaConf.load(config.model.decoder)
20
-
21
- model = Model(config)
22
- weights = torch.load(os.path.join(model_root, "weights.pt"))
23
- model.load_state_dict(weights["model_state_dict"])
24
-
25
- return config, model
26
-
27
-
28
- def load_audio(file_path, config):
29
- # Load audio
30
- # Parameters
31
- # ----------
32
- # file_path : str
33
- # Path to audio file
34
- # Returns
35
- # -------
36
- # torch.Tensor
37
-
38
- audio = None
39
-
40
- if file_path.endswith(".wav") or file_path.endswith(".flac"):
41
- audio, sample_rate = torchaudio.load(file_path)
42
- elif file_path.endswith(".mp3"):
43
- pass
44
- elif file_path.endswith(".mp4"):
45
- _, audio, _ = read_video(file_path)
46
-
47
- return preprocess_audio(audio, config)
48
-
49
-
50
- def preprocess_audio(audio, config, step_size=1):
51
- # Preprocess audio
52
- # Parameters
53
- # ----------
54
- # audio : torch.Tensor
55
- # Audio signal
56
- # config : OmegaConf
57
- # Configuration object
58
- # Returns
59
- # -------
60
- # torch.Tensor : Normalized audio signal
61
-
62
- window_size = config.data.window_size
63
- sr = config.data.sr
64
- fps = config.data.fps
65
-
66
- audio_len = audio.shape[1]
67
- step_size = step_size * (sr // fps)
68
- window_size = window_size * (sr // fps)
69
- audio = pad(audio, (window_size, window_size), "constant", 0)
70
-
71
- sliced_audio = []
72
-
73
- for i in range(0, audio_len + window_size, step_size):
74
- audio_slice = audio[:, i : i + window_size]
75
-
76
- if audio_slice.shape[1] < window_size:
77
- audio_slice = pad(
78
- audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
79
- )
80
-
81
- audio_slice = normalize(audio_slice, dim=1)
82
- sliced_audio.append(audio_slice)
83
-
84
- sliced_audio = torch.stack(sliced_audio).squeeze()
85
-
86
- return sliced_audio
87
-
88
-
89
- def infere(model, x, config, device="cpu", bs=8):
90
- print(x)
91
- model.eval()
92
-
93
- x = load_audio(x, config)
94
-
95
- # Inference (x is a stack of windows)
96
- frame_predictions = []
97
-
98
- with torch.no_grad():
99
- n_iter = x.shape[0]
100
-
101
- for i in range(0, n_iter, bs):
102
- input_batch = x[i: i + bs]
103
- input_batch = input_batch.to(device)
104
-
105
- output = softmax(model(input_batch), dim=1)
106
- frame_predictions.append(output.cpu().numpy())
107
-
108
- frame_predictions = np.concatenate(frame_predictions, axis=0)[:,0]
109
-
110
-
111
- return frame_predictions
112
-
113
- def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
114
- # Convert frame predictions to timestamps
115
- # Parameters
116
- # ----------
117
- # frame_predictions : np.ndarray
118
- # Frame predictions
119
- # fps : int
120
- # Frames per second
121
- # Returns
122
- # -------
123
- # np.ndarray : Timestamps
124
-
125
- frame_predictions = (
126
- frame_predictions[
127
- int(window_size / 2) : -int(window_size / 2), 0
128
- ] # removes the padding, does not consider step size as of now
129
- .round()
130
- .astype(int)
131
- )
132
- timestamps = []
133
-
134
- for i, frame_prediction in enumerate(frame_predictions):
135
- if frame_prediction == 1:
136
- timestamps.append(i / fps)
137
-
138
- return timestamps
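Taken together, these helpers form a small inference pipeline: load config and weights, window and normalize the audio, then score each window with a softmax over the decoder output. A hedged sketch of calling them (the import assumes a revision where this file still exists; sample.wav is a placeholder path):

from manipulate_model.utils import get_config_and_model, infere

config, model = get_config_and_model("manipulate_model/demo-model/audio")  # default model root used above
scores = infere(model, "sample.wav", config, device="cpu", bs=8)
print(scores)   # 1-D array of per-window scores (softmax column 0)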
 
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  gradio==3.50.2
2
  torch==2.3.0
3
- torchaudio ==2.3.0
4
  fairseq @ git+https://github.com/facebookresearch/fairseq.git
5
  librosa==0.10.1
6
  numpy==1.24.4
 
1
  gradio==3.50.2
2
  torch==2.3.0
 
3
  fairseq @ git+https://github.com/facebookresearch/fairseq.git
4
  librosa==0.10.1
5
  numpy==1.24.4