Sin2pi committed
Commit b86695c (verified)
1 Parent(s): 6bcc57f

Update model3.py


Fixed an issue with the Hugging Face Trainer / evaluate integration.

Files changed (1)
  1. model3.py +938 -1055
model3.py CHANGED
@@ -1,1055 +1,938 @@
1
-
2
- import os
3
- import evaluate
4
- import json
5
- import logging
6
- import random
7
- import sys
8
- import time
9
- import torch
10
- import transformers
11
- import warnings
12
- import math
13
- import neologdn
14
- import gzip
15
- import base64
16
- import numpy as np
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
- from torch import amp, Tensor, optim
20
- from torch.utils.checkpoint import checkpoint
21
- from torch.optim import Adamax
22
- from torch.utils.tensorboard import SummaryWriter
23
- from typing import Optional, Tuple, Dict, List, Any, Union
24
- from dataclasses import dataclass
25
- from transformers import (
26
- WhisperPreTrainedModel, WhisperConfig, Trainer,
27
- TrainingArguments, WhisperTokenizer, WhisperFeatureExtractor,
28
- WhisperProcessor, TrainerCallback, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoTokenizer
29
- )
30
- from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel
31
- from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
32
- from transformers.optimization import Adafactor, AdafactorSchedule
33
- from huggingface_hub import PyTorchModelHubMixin
34
- from datasets import load_from_disk, load_dataset
35
- from tqdm import tqdm
36
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
37
- from sklearn.model_selection import train_test_split
38
- from whisper.decoding import decode as decode_function
39
- from whisper.decoding import detect_language as detect_language_function
40
- from whisper.transcribe import transcribe as transcribe_function
41
-
42
- try:
43
- from torch.nn.functional import scaled_dot_product_attention
44
- SDPA_AVAILABLE = True
45
- except (ImportError, RuntimeError, OSError):
46
- scaled_dot_product_attention = None
47
- SDPA_AVAILABLE = False
48
-
49
- transformers.utils.logging.set_verbosity_error()
50
- warnings.filterwarnings(action="ignore")
51
- warnings.warn = lambda *args,**kwargs: None
52
- device = "cuda"
53
-
54
- class LayerNorm(nn.Module):
55
- def __init__(self, num_features, eps=1e-6):
56
- super(LayerNorm, self).__init__()
57
- self.gamma = nn.Parameter(torch.ones(num_features))
58
- self.beta = nn.Parameter(torch.zeros(num_features))
59
- self.eps = eps
60
-
61
- def forward(self, x):
62
- mean = x.mean(dim=-1, keepdim=True)
63
- std = x.std(dim=-1, keepdim=True)
64
- x = (x - mean) / (std + self.eps)
65
- return self.gamma * x + self.beta
66
-
67
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
-
69
- class Linear(nn.Module):
70
- def __init__(self, in_features: int, out_features: int, dropout_rate = 0.01, use_batchnorm: bool = True, activation: str = 'relu'):
71
- super(Linear, self).__init__()
72
- self.linear = nn.Linear(in_features, out_features)
73
- self.dropout = nn.Dropout(dropout_rate)
74
- self.use_batchnorm = use_batchnorm
75
- self.activation = activation
76
-
77
- if self.use_batchnorm:
78
- self.batchnorm = nn.BatchNorm1d(out_features)
79
- self.reset_parameters()
80
-
81
- def reset_parameters(self):
82
- nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)
83
- if self.linear.bias is not None:
84
- nn.init.zeros_(self.linear.bias)
85
-
86
- def forward(self, x):
87
- batch_size, seq_len, _ = x.size()
88
- x = x.view(-1, x.size(-1))
89
- x = self.linear(x)
90
-
91
- if self.use_batchnorm:
92
- x = self.batchnorm(x)
93
-
94
- x = self.apply_activation(x)
95
- x = self.dropout(x)
96
- x = x.view(batch_size, seq_len, -1)
97
-
98
- return x
99
-
100
- def apply_activation(self, x):
101
- if self.activation == 'relu':
102
- return F.relu(x)
103
- elif self.activation == 'tanh':
104
- return torch.tanh(x)
105
- elif self.activation == 'sigmoid':
106
- return torch.sigmoid(x)
107
- else:
108
- raise ValueError(f'Unsupported activation function: {self.activation}')
109
-
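Note that this Linear block flattens (batch, seq, features) to (batch*seq, features) before BatchNorm1d, so normalization statistics are shared across positions and batch items. A minimal usage sketch, with illustrative shapes and reusing the module's imports:

layer = Linear(in_features=64, out_features=128, dropout_rate=0.01, use_batchnorm=True, activation='relu')
x = torch.randn(2, 10, 64)   # (batch, seq, features)
y = layer(x)                 # BatchNorm1d runs over the flattened (batch*seq, features) view
print(y.shape)               # torch.Size([2, 10, 128])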
110
- class Conv1d(nn.Conv1d):
111
- def __init__(self, *args, **kwargs):
112
- super().__init__(*args, **kwargs)
113
- self.reset_parameters()
114
-
115
- def reset_parameters(self):
116
- nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
117
- if self.bias is not None:
118
- nn.init.zeros_(self.bias)
119
-
120
- def _conv_forward(self, x, weight, bias) -> Tensor:
121
- weight = self.weight.to(x.dtype)
122
- bias = None if self.bias is None else self.bias.to(x.dtype)
123
- return super()._conv_forward(x, weight, bias)
124
-
125
- def givens_rotation_matrix(n_state, i, j, theta):
126
- G = torch.eye(n_state)
127
- G[i, i] = math.cos(theta)
128
- G[i, j] = -math.sin(theta)
129
- G[j, i] = math.sin(theta)
130
- G[j, j] = math.cos(theta)
131
- return G
132
-
133
- class GivensRotations(nn.Module):
134
- def __init__(self, h_dim, num_rotations):
135
- super().__init__()
136
- self.h_dim = h_dim
137
- self.num_rotations = num_rotations
138
- self.thetas = nn.Parameter(torch.zeros(num_rotations))
139
-
140
- def forward(self, x):
141
- if x.dim() != 4:
142
- raise ValueError(f"Expected input tensor to be 4D, but got {x.dim()}D")
143
-
144
- batch_size, seq_len, n_head, h_dim = x.size()
145
-
146
- if h_dim != self.h_dim:
147
- raise ValueError(f"Expected h_dim of {self.h_dim}, but got {h_dim}")
148
-
149
- x = x.view(-1, h_dim)
150
- for k in range(self.num_rotations):
151
- i, j = k % self.h_dim, (k + 1) % self.h_dim
152
- G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])
153
- x = torch.matmul(x, G.to(x.device))
154
-
155
- x = x.view(batch_size, seq_len, n_head, h_dim)
156
- return x
157
-
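Each Givens matrix is orthogonal, so the composed rotation preserves vector norms. A quick sanity-check sketch (dimensions are illustrative):

rot = GivensRotations(h_dim=4, num_rotations=2)
x = torch.randn(1, 3, 2, 4)                 # (batch, seq, heads, head_dim)
with torch.no_grad():
    rot.thetas.uniform_(-1.0, 1.0)          # non-trivial rotation angles
    y = rot(x)
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # True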
158
- class BiasedCrossAttention(nn.Module):
159
- def __init__(self, n_state, n_head, dropout_rate=0.1):
160
- super().__init__()
161
- self.n_head = n_head
162
- self.n_state = n_state
163
- self.head_dim = n_state // n_head
164
-
165
- self.query = nn.Linear(n_state, n_state)
166
- self.key = nn.Linear(n_state, n_state, bias=False)
167
- self.value = nn.Linear(n_state, n_state)
168
- self.out = nn.Linear(n_state, n_state)
169
-
170
- self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))
171
- self.dropout = nn.Dropout(dropout_rate)
172
- self.norm = LayerNorm(n_state)
173
-
174
- def forward(self, q, k, v, mask=None):
175
- batch_size, seq_length, _ = q.size()
176
-
177
- q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)
178
- k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)
179
- v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)
180
-
181
- qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias
182
- if mask is not None:
183
- qk = qk.masked_fill(mask == 0, float('-inf'))
184
-
185
- w = F.softmax(qk, dim=-1)
186
- w = self.dropout(w)
187
-
188
- out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
189
- out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))
190
- return out
191
-
192
- class DynamicConvAttention(nn.Module):
193
- def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1):
194
- super().__init__()
195
- self.n_state = n_state
196
- self.n_head = n_head
197
- self.kernel_size = kernel_size
198
-
199
- self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)
200
- self.dropout = nn.Dropout(dropout_rate)
201
-
202
- self.query = nn.Linear(n_state, n_state)
203
- self.key = nn.Linear(n_state, n_state, bias=False)
204
- self.value = nn.Linear(n_state, n_state)
205
- self.out_proj = nn.Linear(n_state, n_state)
206
-
207
- self.norm = LayerNorm(n_state)
208
-
209
- def forward(self, x):
210
- batch_size, seq_len, embed_dim = x.size()
211
- if embed_dim != self.n_state:
212
- raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")
213
-
214
- q = self.query(x)
215
- k = self.key(x)
216
- v = self.value(x)
217
-
218
- x = x.permute(0, 2, 1)
219
- conv_out = self.conv(x)
220
- conv_out = conv_out.permute(0, 2, 1)
221
- conv_out = self.norm(conv_out)
222
- conv_out = self.dropout(conv_out)
223
-
224
- attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)
225
- attention_out = torch.matmul(attention_out, v)
226
-
227
- combined_out = conv_out + attention_out
228
- combined_out = self.norm(combined_out)
229
-
230
- return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)
231
-
232
- class HybridAttention(nn.Module):
233
- def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1):
234
- super().__init__()
235
- self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
236
- self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
237
- self.ln_local = LayerNorm(n_state)
238
- self.ln_global = LayerNorm(n_state)
239
-
240
- self.dropout = nn.Dropout(dropout_rate)
241
- self.window_size = window_size
242
-
243
- def forward(self, x):
244
- x_local = self.ln_local(x)
245
- x_global = self.ln_global(x)
246
- x_local = x_local.permute(1, 0, 2)
247
- x_global = x_global.permute(1, 0, 2)
248
- local_out = self.sliding_window_attention(x_local)
249
- global_out, _ = self.global_attn(x_global, x_global, x_global)
250
- combined_out = local_out + global_out
251
- combined_out = combined_out.permute(1, 0, 2)
252
- return self.dropout(combined_out)
253
-
254
- def sliding_window_attention(self, x):
255
- seq_len, batch_size, n_state = x.size()
256
- window_size = min(self.window_size, max(1, seq_len // 4))
257
- output = torch.zeros_like(x, device=x.device, dtype=x.dtype)
258
-
259
- for i in range(0, seq_len, window_size):
260
- end = min(i + window_size, seq_len)
261
- query = x[i:end, :, :]
262
- start = max(0, i - window_size)
263
- key = x[start:end, :, :]
264
- value = x[start:end, :, :]
265
- attn_output, _ = self.local_attn(query, key, value)
266
- output[i:end, :, :] = attn_output[:end - i, :, :]
267
-
268
- return output
269
-
270
- class RotaryEmbeddingWithRotation(nn.Module):
271
- def __init__(self, n_state, n_head, base=10000, checkpointing=False):
272
- super().__init__()
273
- self.n_state = n_state
274
- self.n_head = n_head
275
- self.h_dim = n_state // n_head
276
- self.base = base # Initialize base
277
- self.checkpointing = checkpointing
278
-
279
- self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
280
- inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
281
- self.register_buffer('inv_freq', inv_freq)
282
-
283
- def update_base(self, new_base):
284
- self.base = new_base
285
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
286
- self.register_buffer('inv_freq', inv_freq)
287
-
288
- def reset_parameters(self):
289
- nn.init.orthogonal_(self.rotation_matrix)
290
-
291
- def forward(self, x):
292
- if self.checkpointing:
293
- return checkpoint(self._forward, x)
294
- else:
295
- return self._forward(x)
296
-
297
- def _forward(self, x):
298
- if x.dim() == 3:
299
- batch_size, seq_len, n_state = x.size()
300
- elif x.dim() == 4:
301
- batch_size, seq_len, n_head, h_dim = x.size()
302
- n_state = n_head * h_dim
303
- x = x.view(batch_size, seq_len, n_state)
304
- else:
305
- raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
306
-
307
- if n_state != self.n_state:
308
- raise ValueError(f"Expected n_state of {self.n_state}, but got {n_state}")
309
-
310
- x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
311
- x = x.reshape(-1, self.h_dim)
312
- rotated_x = torch.matmul(x, self.rotation_matrix)
313
- rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
314
-
315
- sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))
316
- sin = sinusoid_inp.sin()[None, :, None, :]
317
- cos = sinusoid_inp.cos()[None, :, None, :]
318
- x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]
319
- rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
320
-
321
- rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)
322
- return rotated_x
323
-
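With rotation_matrix at its identity initialization, the sinusoidal part behaves like a rotary position embedding: dot products between embedded copies of a vector depend only on their relative offset. A small sketch of that property (sizes are illustrative):

rope = RotaryEmbeddingWithRotation(n_state=8, n_head=2)
u = torch.randn(8)
x = u.expand(1, 16, 8).clone()              # the same vector at every position
y = rope(x)[0]                              # (seq, n_state)
print(torch.allclose(torch.dot(y[0], y[1]),
                     torch.dot(y[5], y[6]), atol=1e-5))  # True: both pairs are offset by 1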
324
- class LearnedSinusoidalEmbeddings(nn.Module):
325
- def __init__(self, n_ctx, n_state, checkpointing=False):
326
- super().__init__()
327
- self.n_ctx = n_ctx
328
- self.n_state = n_state
329
- self.checkpointing = checkpointing
330
-
331
- position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
332
- div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
333
- features = torch.zeros(n_ctx, n_state)
334
- features[:, 0::2] = torch.sin(position * div_term)
335
- features[:, 1::2] = torch.cos(position * div_term)
336
- self.register_buffer('sinusoidal_features', features)
337
-
338
- self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())
339
-
340
- def forward(self, positions):
341
- if self.checkpointing:
342
- position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)
343
- else:
344
- position_embeddings = self.positional_embeddings[positions]
345
-
346
- position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)
347
- return position_embeddings
348
-
349
- class MultiHeadAttention(nn.Module):
350
- use_sdpa = True
351
-
352
- def __init__(self, n_state: int, n_head: int, base: int = 10000, max_rel_dist: int = 1):
353
- super().__init__()
354
- assert n_state % n_head == 0, "n_state must be divisible by n_head"
355
- self.n_head = n_head
356
- self.h_dim = n_state // n_head
357
- assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"
358
-
359
- self.positional_scaling = nn.Parameter(torch.ones(1))
360
-
361
- self.query = nn.Linear(n_state, n_state)
362
- self.key = nn.Linear(n_state, n_state, bias=False)
363
- self.value = nn.Linear(n_state, n_state)
364
- self.out = nn.Linear(n_state, n_state)
365
-
366
- self.max_rel_dist = max_rel_dist
367
- self.base = base
368
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
369
- self.register_buffer('inv_freq', inv_freq)
370
-
371
- self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
372
-
373
- self.rotation_matrix = nn.Parameter(torch.empty(self.h_dim, self.h_dim))
374
- nn.init.orthogonal_(self.rotation_matrix)
375
-
376
- self.givens_rotations = GivensRotations(self.h_dim, num_rotations=self.h_dim // 2)
377
-
378
- self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)
379
- self.rel_pos_bias.weight.data.fill_(0)
380
-
381
- if device:
382
- self.to(device)
383
-
384
- def update_base(self, new_base):
385
- self.base = new_base
386
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
387
- self.register_buffer('inv_freq', inv_freq)
388
- self.rotary_embedding.update_base(new_base)
389
-
390
- def apply_rotary_embedding(self, x: torch.Tensor) -> torch.Tensor:
391
- seq_len = x.shape[1]
392
- positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
393
- scaled_positions = self.positional_scaling * positions
394
- sinusoid_inp = torch.outer(scaled_positions, self.inv_freq.to(x.device))
395
- sin = sinusoid_inp.sin()[None, :, None, :]
396
- cos = sinusoid_inp.cos()[None, :, None, :]
397
-
398
- x1, x2 = x[..., ::2], x[..., 1::2]
399
- x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
400
- return x_rotated
401
-
402
- def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
403
- q = self.query(x)
404
-
405
- if kv_cache is None or xa is None or 'k' not in kv_cache:
406
- k_input = x if xa is None else xa
407
- k = self.key(k_input)
408
- v = self.value(k_input)
409
- if kv_cache is not None:
410
- kv_cache['k'] = k
411
- kv_cache['v'] = v
412
- else:
413
- k = kv_cache['k']
414
- v = kv_cache['v']
415
-
416
- q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
417
- k = k.view(k.shape[0], k.shape[1], self.n_head, -1)
418
- v = v.view(v.shape[0], v.shape[1], self.n_head, -1)
419
-
420
- q = self.apply_rotary_embedding(q)
421
- k = self.apply_rotary_embedding(k)
422
-
423
- q = torch.matmul(q, self.rotation_matrix)
424
- k = torch.matmul(k, self.rotation_matrix)
425
-
426
- q = self.givens_rotations(q)
427
- k = self.givens_rotations(k)
428
-
429
- q = q.view(q.shape[0], q.shape[1], -1)
430
- k = k.view(k.shape[0], k.shape[1], -1)
431
-
432
- wv, qk = self.qkv_attention(q, k, v, mask)
433
- return self.out(wv), qk
434
-
435
- def qkv_attention(self, q, k, v, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
436
- n_batch, n_ctx, n_state = q.shape
437
-
438
- scale = (n_state // self.n_head) ** -0.25
439
- q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
440
- k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
441
- v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
442
-
443
- qk = (q * scale) @ (k * scale).transpose(-1, -2)
444
-
445
- seq_len_q = q.size(2)
446
- seq_len_k = k.size(2)
447
-
448
- positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)
449
- positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1
450
- rel_bias = self.rel_pos_bias(positions)
451
- rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0)
452
-
453
- qk = qk + rel_bias
454
-
455
- if mask is not None:
456
- qk = qk + mask[:n_ctx, :n_ctx]
457
- qk = qk.float()
458
-
459
- w = F.softmax(qk, dim=-1).to(q.dtype)
460
- out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
461
- qk = qk.detach()
462
-
463
- return out, qk
464
-
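The clamp-and-shift in qkv_attention maps each query-key offset into [0, 2*max_rel_dist - 2] for the bias embedding lookup. A standalone sketch of the index matrix, assuming max_rel_dist=3:

max_rel_dist = 3
q_pos = torch.arange(4).unsqueeze(1)        # query positions as a column
k_pos = torch.arange(4).unsqueeze(0)        # key positions as a row
idx = (q_pos - k_pos).clamp(-max_rel_dist + 1, max_rel_dist - 1) + max_rel_dist - 1
print(idx)
# tensor([[2, 1, 0, 0],
#         [3, 2, 1, 0],
#         [4, 3, 2, 1],
#         [4, 4, 3, 2]])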
465
- class ResidualAttentionBlock(nn.Module):
466
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, max_rel_dist = 1, checkpointing=False):
467
- super().__init__()
468
-
469
- self.attn = MultiHeadAttention(n_state, n_head)
470
- self.attn_ln = LayerNorm(n_state)
471
- self.checkpointing = checkpointing
472
- self.max_rel_dist = max_rel_dist
473
-
474
- self.cross_attn = (
475
- MultiHeadAttention(n_state, n_head) if cross_attention else None
476
- )
477
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
478
-
479
- n_mlp = n_state * 4
480
- self.mlp = nn.Sequential(
481
- Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
482
- )
483
- self.mlp_ln = LayerNorm(n_state)
484
-
485
- def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
486
- if self.checkpointing:
487
- x = checkpoint(self._attn_forward, x, mask, kv_cache)
488
- else:
489
- x = self._attn_forward(x, mask, kv_cache)
490
-
491
- if self.cross_attn:
492
- if self.checkpointing:
493
- x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)
494
- else:
495
- x = self._cross_attn_forward(x, xa, kv_cache)
496
-
497
- if self.checkpointing:
498
- x = checkpoint(self._mlp_forward, x)
499
- else:
500
- x = self._mlp_forward(x)
501
-
502
- return x
503
-
504
- def _attn_forward(self, x, mask, kv_cache):
505
- residual = x
506
- x = self.attn_ln(x)
507
- x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]
508
- return x
509
-
510
- def _cross_attn_forward(self, x, xa, kv_cache):
511
- residual = x
512
- x = self.cross_attn_ln(x)
513
- x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]
514
- return x
515
-
516
- def _mlp_forward(self, x):
517
- residual = x
518
- x = self.mlp_ln(x)
519
- x = residual + self.mlp(x)
520
- return x
521
-
522
- class AudioEncoder(nn.Module):
523
- def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist, checkpointing=False):
524
- super().__init__()
525
- self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
526
- self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
527
- self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
528
- self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
529
- self.checkpointing = checkpointing
530
-
531
- self.blocks = nn.ModuleList(
532
- [ResidualAttentionBlock(n_state, n_head, max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]
533
- )
534
- self.ln_post = LayerNorm(n_state)
535
-
536
- def update_base(self, new_base):
537
- self.rotary_embedding.update_base(new_base)
538
- for block in self.blocks:
539
- if isinstance(block.attn, MultiHeadAttention):
540
- block.attn.update_base(new_base)
541
- if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
542
- block.cross_attn.update_base(new_base)
543
-
544
- def forward(self, x):
545
- if self.checkpointing:
546
- x = checkpoint(self._conv_forward, x)
547
- else:
548
- x = self._conv_forward(x)
549
-
550
- for block in self.blocks:
551
- if self.checkpointing:
552
- x = checkpoint(block, x)
553
- else:
554
- x = block(x)
555
-
556
- x = self.ln_post(x)
557
- return x
558
-
559
- def _conv_forward(self, x):
560
- x = F.gelu(self.conv1(x))
561
- x = F.gelu(self.conv2(x))
562
- x = x.permute(0, 2, 1)
563
- x = self.rotary_embedding(x)
564
-
565
- pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)
566
- x = x + pos_emb
567
- return x
568
-
569
- class TextDecoder(nn.Module):
570
- def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist, cross_attention, checkpointing=False):
571
- super().__init__()
572
- self.token_embedding = nn.Embedding(vocab_size, n_state)
573
- self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
574
- self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
575
- self.checkpointing = checkpointing
576
- self.n_head = n_head
577
-
578
- self.blocks = nn.ModuleList([
579
- ResidualAttentionBlock(n_state, n_head, max_rel_dist, cross_attention, checkpointing=checkpointing)
580
- for _ in range(n_layer)
581
- ])
582
- self.ln = LayerNorm(n_state)
583
- mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
584
- self.register_buffer("mask", mask, persistent=False)
585
-
586
- def update_base(self, new_base):
587
- self.rotary_embedding.update_base(new_base)
588
- for block in self.blocks:
589
- if isinstance(block.attn, MultiHeadAttention):
590
- block.attn.update_base(new_base)
591
- if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
592
- block.cross_attn.update_base(new_base)
593
-
594
- def forward(self, x, xa, kv_cache: Optional[dict] = None):
595
- if self.checkpointing:
596
- x = checkpoint(self._embedding_forward, x, xa, kv_cache)
597
- else:
598
- x = self._embedding_forward(x, xa, kv_cache)
599
-
600
- for block in self.blocks:
601
- if self.checkpointing:
602
- x = checkpoint(block, x, xa, self.mask, kv_cache)
603
- else:
604
- x = block(x, xa, self.mask, kv_cache)
605
-
606
- x = self.ln(x)
607
- logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
608
-
609
- return logits
610
-
611
- def _embedding_forward(self, x, xa, kv_cache):
612
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
613
- positions = torch.arange(x.shape[1], device=x.device) + offset
614
- pos_emb = self.positional_embedding(positions).unsqueeze(0)
615
-
616
- x = self.token_embedding(x) + pos_emb
617
- x = x.to(xa.dtype)
618
-
619
- batch_size, seq_length, embedding_dim = x.shape
620
- num_heads = self.n_head
621
- head_dim = embedding_dim // num_heads
622
- x = x.view(batch_size, seq_length, num_heads, head_dim)
623
-
624
- x = self.rotary_embedding(x)
625
- x = x.view(batch_size, seq_length, embedding_dim)
626
- return x
627
-
628
- class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):
629
- config_class = WhisperConfig
630
-
631
- def __init__(self, config: WhisperConfig):
632
- super().__init__(config)
633
- self.config = config
634
-
635
- self.n_mels = self.config.num_mel_bins
636
- self.n_audio_ctx = self.config.max_source_positions
637
- self.n_audio_state = self.config.d_model
638
- self.n_audio_head = self.config.encoder_attention_heads
639
- self.n_audio_layer = self.config.encoder_layers
640
- self.vocab_size = self.config.vocab_size
641
- self.n_text_ctx = self.config.max_target_positions
642
- self.n_text_state = self.config.d_model
643
- self.n_text_head = self.config.decoder_attention_heads
644
- self.n_text_layer = self.config.decoder_layers
645
- self.max_rel_dist = self.config.max_rel_dist
646
- self.checkpointing = self.config.checkpointing
647
- self.base = self.config.base
648
-
649
- self.encoder = AudioEncoder(
650
- self.config.n_mels,
651
- self.config.n_audio_ctx,
652
- self.config.n_audio_state,
653
- self.config.n_audio_head,
654
- self.config.n_audio_layer,
655
- self.config.checkpointing,
656
- self.config.max_rel_dist
657
- )
658
- self.decoder = TextDecoder(
659
- self.config.vocab_size,
660
- self.config.n_text_ctx,
661
- self.config.n_text_state,
662
- self.config.n_text_head,
663
- self.config.n_text_layer,
664
- self.config.checkpointing,
665
- self.config.max_rel_dist
666
- )
667
-
668
- all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)
669
- all_heads[self.config.n_text_layer // 2:] = True
670
- self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
671
-
672
- self.best_loss = float('inf')
673
- self.base = 10000
674
-
675
- def update_base(self, new_base):
676
- self.encoder.rotary_embedding.update_base(new_base)
677
- self.decoder.rotary_embedding.update_base(new_base)
678
- for name, module in self.encoder.named_modules():
679
- if isinstance(module, MultiHeadAttention):
680
- module.update_base(new_base)
681
- for name, module in self.decoder.named_modules():
682
- if isinstance(module, MultiHeadAttention):
683
- module.update_base(new_base)
684
-
685
- def adjust_base(self, loss, factor=1.05):
686
- if loss < self.best_loss:
687
- new_base = self.base * factor
688
- else:
689
- new_base = self.base / factor
690
-
691
- self.update_base(new_base)
692
- self.best_loss = loss
693
- #print(f"Adjusted base: {new_base}")
694
-
695
-
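adjust_base nudges the rotary base up when the loss improved and down otherwise; because best_loss is overwritten unconditionally, the comparison is against the previous step's loss rather than the running best. A standalone re-implementation for illustration:

base, prev_loss = 10000.0, float('inf')
for loss in [2.0, 1.5, 1.7, 1.2]:
    base = base * 1.05 if loss < prev_loss else base / 1.05
    prev_loss = loss
print(round(base, 1))  # 11025.0: up, up, down, up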
696
- @staticmethod
697
- def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:
698
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
699
- shifted_input_ids[:, 1:] = input_ids[:, :-1]
700
- shifted_input_ids[:, 0] = decoder_start_token_id
701
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
702
- return shifted_input_ids
703
-
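A worked example of the teacher-forcing shift (token ids are illustrative; -100 is the ignore index):

ids = torch.tensor([[7, 8, 9, -100]])
print(Echo.shift_tokens_right(ids, pad_token_id=0, decoder_start_token_id=50258))
# tensor([[50258, 7, 8, 9]])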
704
- def forward(self, input_features, labels=None, dec_input_ids=None):
705
- if labels is not None:
706
- if dec_input_ids is None:
707
- dec_input_ids = self.shift_tokens_right(
708
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
709
- )
710
-
711
- encoded_features = self.encoder(input_features).to(device)
712
- logits = self.decoder(dec_input_ids, encoded_features)
713
-
714
- loss = None
715
- if labels is not None:
716
- loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
717
- labels = labels.to(logits.device).long()
718
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
719
-
720
- self.adjust_base(loss.item())
721
-
722
- return {
723
- "loss": loss,
724
- "logits": logits,
725
- "input_features": encoded_features,
726
- "labels": labels,
727
- "decoder_input_ids": dec_input_ids
728
- }
729
-
730
- def _initialize_weights(self):
731
- nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)
732
- if hasattr(self.decoder.positional_embedding, 'weight'):
733
- nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)
734
- for block in self.decoder.blocks:
735
- for layer in block.children():
736
- if isinstance(layer, nn.Linear):
737
- nn.init.xavier_normal_(layer.weight)
738
- if layer.bias is not None:
739
- nn.init.zeros_(layer.bias)
740
-
741
- nn.init.constant_(self.decoder.ln.gamma, 1)
742
- if self.decoder.ln.beta is not None:
743
- nn.init.constant_(self.decoder.ln.beta, 0)
744
-
745
- nn.init.xavier_normal_(self.encoder.conv1.weight)
746
- if self.encoder.conv1.bias is not None:
747
- nn.init.zeros_(self.encoder.conv1.bias)
748
-
749
- nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
750
- if self.encoder.conv2.bias is not None:
751
- nn.init.zeros_(self.encoder.conv2.bias)
752
-
753
- nn.init.constant_(self.encoder.ln_post.gamma, 1)
754
- if self.encoder.ln_post.beta is not None:
755
- nn.init.constant_(self.encoder.ln_post.beta, 0)
756
-
757
- def apply_initialization(self):
758
- self._initialize_weights()
759
-
760
- def set_alignment_heads(self, dump: bytes):
761
- array = np.frombuffer(
762
- gzip.decompress(base64.b85decode(dump)), dtype=bool
763
- ).copy()
764
- mask = torch.from_numpy(array).reshape(
765
- self.config.n_text_layer, self.config.n_text_head
766
- )
767
- self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
768
-
769
- def embed_audio(self, mel):
770
- return self.encoder(mel)
771
-
772
- def logits(self, labels, input_features):
773
- return self.decoder(labels, input_features)
774
-
775
- @property
776
- def device(self):
777
- return next(self.parameters()).device
778
-
779
- @property
780
- def is_multilingual(self):
781
- return self.config.vocab_size >= len(tokenizer)
782
-
783
- @property
784
- def num_languages(self):
785
- return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)
786
-
787
- def install_kv_cache_hooks(self, cache: Optional[dict] = None):
788
- cache = {**cache} if cache is not None else {}
789
- hooks = []
790
-
791
- def save_to_cache(module, _, output):
792
- if module not in cache or output.shape[1] > self.config.n_text_ctx:
793
- cache[module] = output
794
- else:
795
- cache[module] = torch.cat([cache[module], output], dim=1).detach()
796
- return cache[module]
797
-
798
- def install_hooks(layer: nn.Module):
799
- if isinstance(layer, MultiHeadAttention):
800
- hooks.append(layer.key.register_forward_hook(save_to_cache))
801
- hooks.append(layer.value.register_forward_hook(save_to_cache))
802
-
803
- self.decoder.apply(install_hooks)
804
- return cache, hooks
805
-
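A usage sketch for the cache hooks during incremental decoding (tensor names here are hypothetical; the hooks should be removed once decoding finishes):

# cache, hooks = model.install_kv_cache_hooks()
# logits = model.decoder(tokens, audio_features, kv_cache=cache)
# for hook in hooks:
#     hook.remove()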
806
- detect_language = detect_language_function
807
- transcribe = transcribe_function
808
- decode = decode_function
809
-
810
- def get_encoder(self):
811
- return self.encoder
812
-
813
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
814
- return {'input_features': input_ids}
815
-
816
- def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
817
- return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id
818
-
819
- def can_generate(self):
820
- return True
821
-
822
- def generate(self, inputs, **kwargs):
823
- encoder_outputs = self.encoder(inputs)
824
- decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)
825
- outputs = self.decoder(decoder_input_ids, encoder_outputs)
826
- return outputs.argmax(dim=-1)
827
-
828
- #rasa
829
-
830
-
831
- feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small", sampling_rate=16000, n_fft=1024, hop_length=256, feature_size=128, do_normalize=True)
832
- tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language='ja', task='transcribe')#, pad_token="[PAD]", unk_token="[UNK]", model_max_length=1024)
833
- processor = WhisperProcessor.from_pretrained("openai/whisper-small", tokenizer=tokenizer, feature_extractor=feature_extractor)
834
-
835
-
836
- config = WhisperConfig(
837
- n_mels=128,
838
- n_audio_ctx=1500,
839
- n_audio_state=1024,
840
- n_audio_head=16,
841
- n_audio_layer=24,
842
- vocab_size=(len(tokenizer)),
843
- n_text_ctx=448,
844
- n_text_state=1024,
845
- n_text_head=16,
846
- n_text_layer=16,
847
- max_rel_dist=10,
848
- cross_attention=True,
849
- checkpointing=True,
850
- base=10000
851
- )
852
-
853
- model = Echo(config).to(device)
854
- model.apply_initialization()
855
- model.save_pretrained("./models/echo2")
856
-
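The checkpoint written above can be restored later; a sketch, assuming the Echo class and its custom config fields are importable at load time:

# reloaded = Echo.from_pretrained("./models/echo2").to(device)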
857
-
858
-
859
- from datetime import datetime
860
- log_dir = os.path.join('./output/', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
861
- os.makedirs(log_dir, exist_ok=True)
862
-
863
- optimizer = transformers.Adafactor(model.parameters(),
864
- clip_threshold=0.99,
865
- weight_decay=0.005,
866
- scale_parameter=True,
867
- relative_step=True,
868
- warmup_init=True,
869
- lr=None)
870
-
871
- scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=2.25e-5)
872
- loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
873
-
874
- ds_a = load_from_disk("D:/proj/datasets/gvjas")["train"].to_iterable_dataset(num_shards=200).filter(lambda sample: bool(sample["sentence"])).map(lambda sample: {"sentence": neologdn.normalize(sample['sentence'], repeat=1)}).shuffle(buffer_size=10000)
875
- ds_b = load_from_disk("D:/proj/datasets/gvjas")["test"].to_iterable_dataset(num_shards=20).filter(lambda sample: bool(sample["sentence"])).map(lambda sample: {"sentence": neologdn.normalize(sample['sentence'], repeat=1)}).shuffle(buffer_size=100)
876
-
877
- def prepare_dataset(batch):
878
- audio = batch["audio"]
879
- batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
880
- batch["labels"] = tokenizer(batch["sentence"]).input_ids
881
- return batch
882
-
883
- train = ds_a.map(prepare_dataset).select_columns(["input_features", "labels"])
884
- test = ds_b.map(prepare_dataset).select_columns(["input_features", "labels"])
885
-
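A quick shape check on one prepared example (a sketch; the iterable dataset materializes the first shard lazily):

sample = next(iter(train))
print(np.asarray(sample["input_features"]).shape)  # (feature_size, frames), here (128, 3000)
print(len(sample["labels"]))                        # token ids for the normalized sentence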
886
- @dataclass
887
- class DataCollatorSpeechSeq2SeqWithPadding:
888
- processor: Any
889
- tokenizer: Any
890
- feature_extractor: Any
891
- decoder_start_token_id: Any
892
-
893
- def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
894
- input_features = [{"input_features": feature["input_features"]} for feature in features]
895
- batch = self.feature_extractor.pad(input_features, return_tensors="pt")
896
- label_features = [{"input_ids": feature["labels"]} for feature in features]
897
- labels_batch = self.tokenizer.pad(label_features, return_tensors="pt")
898
- labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
899
- if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
900
- labels = labels[:, 1:]
901
- batch["labels"] = labels
902
- return batch
903
-
904
- data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, tokenizer=tokenizer, feature_extractor=feature_extractor, decoder_start_token_id=model.config.decoder_start_token_id)
905
-
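A collator sketch with two dummy examples (shapes and token ids are illustrative):

fake = [{"input_features": np.zeros((128, 3000), dtype=np.float32), "labels": [7, 8]},
        {"input_features": np.zeros((128, 3000), dtype=np.float32), "labels": [9]}]
batch = data_collator(fake)
print(batch["input_features"].shape)  # torch.Size([2, 128, 3000])
print(batch["labels"])                # tensor([[7, 8], [9, -100]]) -- padding masked to -100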
906
- class GradientClippingCallback(TrainerCallback):
907
- def on_step_end(self, args, state, control, **kwargs):
908
- torch.nn.utils.clip_grad_norm_(kwargs["model"].parameters(), max_norm=0.95)
909
-
910
- class MetricsCallback(TrainerCallback):
911
- def __init__(self, tb_writer, tokenizer, metric, log_every_n_steps=30):
912
- super().__init__()
913
- self.tb_writer = tb_writer
914
- self.tokenizer = tokenizer
915
- self.metric = metric
916
- self.log_every_n_steps = log_every_n_steps
917
- self.predictions = None
918
- self.label_ids = None
919
-
920
- def compute_cer(self, pred_str, label_str):
921
- cer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
922
- return cer
923
-
924
- def on_evaluate(self, args, state, control, metrics=None, **kwargs):
925
- if metrics is not None:
926
- for key, value in metrics.items():
927
- if key.startswith("eval_"):
928
- self.tb_writer.add_scalar(key, value, state.global_step)
929
- print(f"Step {state.global_step} - {key}: {value}")
930
-
931
- if self.predictions is not None and self.label_ids is not None:
932
- pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)
933
- label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)
934
-
935
- sample_index = 1
936
- self.tb_writer.add_text("Prediction", pred_str[sample_index], state.global_step)
937
- self.tb_writer.add_text("Label", label_str[sample_index], state.global_step)
938
-
939
- print(f"Step {state.global_step} - Sample Prediction: {pred_str[sample_index]}")
940
- print(f"Step {state.global_step} - Sample Label: {label_str[sample_index]}")
941
-
942
- self.predictions = None
943
- self.label_ids = None
944
-
945
- def create_compute_metrics(callback_instance):
946
- def compute_metrics(eval_pred):
947
- pred_logits = eval_pred.predictions
948
- label_ids = eval_pred.label_ids
949
-
950
- if isinstance(pred_logits, tuple):
951
- pred_ids = pred_logits[0]
952
- else:
953
- pred_ids = pred_logits
954
- if pred_ids.ndim == 3:
955
- pred_ids = np.argmax(pred_ids, axis=-1)
956
-
957
- label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id
958
- callback_instance.predictions = pred_ids
959
- callback_instance.label_ids = label_ids
960
-
961
- pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
962
- label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
963
- cer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)
964
-
965
- pred_flat = pred_ids.flatten()
966
- labels_flat = label_ids.flatten()
967
- mask = labels_flat != callback_instance.tokenizer.pad_token_id
968
-
969
- accuracy = accuracy_score(labels_flat[mask], pred_flat[mask])
970
- precision = precision_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
971
- recall = recall_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
972
- f1 = f1_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
973
-
974
- return {
975
- "cer": cer,
976
- "accuracy": accuracy,
977
- "precision": precision,
978
- "recall": recall,
979
- "f1": f1
980
- }
981
- return compute_metrics
982
-
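A tiny sketch of the masked token-level scoring used above (pad_token_id assumed to be 0 here):

preds = np.array([[5, 6, 0]]).flatten()
labels = np.array([[5, 7, 0]]).flatten()
mask = labels != 0                                  # drop padding positions
print(accuracy_score(labels[mask], preds[mask]))    # 0.5: one of the two real tokens matches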
983
- training_args = Seq2SeqTrainingArguments(
984
- output_dir=log_dir,
985
- logging_dir=log_dir,
986
- overwrite_output_dir=True,
987
- per_device_train_batch_size=1,
988
- gradient_accumulation_steps=1,
989
- eval_accumulation_steps=1,
990
- num_train_epochs=1,
991
- tf32=True,
992
- bf16=True,
993
- max_steps=10000,
994
- save_steps=1000,
995
- eval_steps=20,
996
- eval_strategy="steps",
997
- eval_on_start=False,
998
- warmup_steps=100,
999
- logging_steps=10,
1000
- logging_strategy="steps",
1001
- save_strategy="steps",
1002
- report_to=["tensorboard"],
1003
- push_to_hub=False,
1004
- remove_unused_columns=False,
1005
- label_names=["labels"],
1006
- hub_private_repo=True,
1007
- metric_for_best_model="cer",
1008
- greater_is_better=False,
1009
- load_best_model_at_end=True,
1010
- optim="adafactor",
1011
- weight_decay=0.00025,
1012
- disable_tqdm=False,
1013
- save_total_limit=2,
1014
- use_cpu=False,
1015
- torch_empty_cache_steps=10
1016
-
1017
- )
1018
-
1019
- torch.backends.cuda.matmul.allow_tf32 = True
1020
- torch.backends.cudnn.allow_tf32 = True
1021
- torch.cuda.empty_cache()
1022
- torch.cuda.set_device(0)
1023
-
1024
- cer_metric = evaluate.load("cer")
1025
- tb_writer = SummaryWriter(log_dir)
1026
-
1027
- metrics_callback = MetricsCallback(tb_writer, tokenizer, cer_metric, log_every_n_steps=30)
1028
- compute_metrics = create_compute_metrics(metrics_callback)
1029
-
1030
- trainer = Seq2SeqTrainer(
1031
- args=training_args,
1032
- model=model,
1033
- train_dataset=train,
1034
- eval_dataset=test,
1035
- data_collator=data_collator,
1036
- tokenizer=processor.feature_extractor,
1037
- compute_metrics=compute_metrics,
1038
- callbacks=[metrics_callback]
1039
- )
1040
-
1041
-
1042
-
1043
-
1044
- trainer.train(resume_from_checkpoint=True)
1045
- tb_writer.close()
1046
- from torch.utils.tensorboard import SummaryWriter
1047
-
1048
-
1049
- path = "./models/echo2_4k"
1050
- model.save_pretrained(path)
1051
- processor.save_pretrained(path)
1052
- tokenizer.save_pretrained(path)
1053
- feature_extractor.save_pretrained(path)
1054
-
1055
-
 
1
+
2
+ import base64, gzip, torch, evaluate, math, os, sys, time
3
+
4
+ from torch import amp, Tensor, optim
5
+ from torch.utils.checkpoint import checkpoint
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel
9
+ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
10
+ from transformers.optimization import Adafactor, AdafactorSchedule
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+ from datasets import IterableDatasetDict, Audio, load_dataset
13
+ import numpy as np
14
+ import transformers, warnings
15
+ from typing import Dict, Iterable, Optional, Tuple, Union, List, Any, Type
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ import torchaudio, torchaudio.transforms as T
19
+ from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, WhisperTokenizer, WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor
20
+ from whisper.decoding import decode as decode_function
21
+ from whisper.decoding import detect_language as detect_language_function
22
+ from whisper.transcribe import transcribe as transcribe_function
23
+
24
+ try:
25
+ from torch.nn.functional import scaled_dot_product_attention
26
+
27
+ SDPA_AVAILABLE = True
28
+ except (ImportError, RuntimeError, OSError):
29
+ scaled_dot_product_attention = None
30
+ SDPA_AVAILABLE = False
31
+
32
+ transformers.utils.logging.set_verbosity_error()
33
+ warnings.filterwarnings(action="ignore")
34
+ warnings.warn = lambda *args,**kwargs: None
35
+ device = "cuda"
36
+
37
+
38
+
39
+ class LayerNorm(nn.Module):
40
+ def __init__(self, num_features, eps=1e-6):
41
+ super(LayerNorm, self).__init__()
42
+ self.gamma = nn.Parameter(torch.ones(num_features))
43
+ self.beta = nn.Parameter(torch.zeros(num_features))
44
+ self.eps = eps
45
+
46
+ def forward(self, x):
47
+ mean = x.mean(dim=-1, keepdim=True)
48
+ std = x.std(dim=-1, keepdim=True)
49
+ x = (x - mean) / (std + self.eps)
50
+ return self.gamma * x + self.beta
51
+
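This LayerNorm divides by std + eps, with std using Bessel's correction, whereas torch.nn.LayerNorm uses sqrt(biased_var + eps), so outputs differ slightly. A quick comparison sketch:

ln, ref = LayerNorm(8), nn.LayerNorm(8)
x = torch.randn(2, 4, 8)
print((ln(x) - ref(x)).abs().max())   # small but non-zero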
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+
54
+ class Linear(nn.Module):
55
+ def __init__(self, in_features: int, out_features: int, dropout_rate = 0.01, use_batchnorm: bool = True, activation: str = 'relu'):
56
+ super(Linear, self).__init__()
57
+ self.linear = nn.Linear(in_features, out_features)
58
+ self.dropout = nn.Dropout(dropout_rate)
59
+ self.use_batchnorm = use_batchnorm
60
+ self.activation = activation
61
+
62
+ if self.use_batchnorm:
63
+ self.batchnorm = nn.BatchNorm1d(out_features)
64
+ self.reset_parameters()
65
+
66
+ def reset_parameters(self):
67
+ nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)
68
+ if self.linear.bias is not None:
69
+ nn.init.zeros_(self.linear.bias)
70
+
71
+ def forward(self, x):
72
+ batch_size, seq_len, _ = x.size()
73
+ x = x.view(-1, x.size(-1))
74
+ x = self.linear(x)
75
+
76
+ if self.use_batchnorm:
77
+ x = self.batchnorm(x)
78
+
79
+ x = self.apply_activation(x)
80
+ x = self.dropout(x)
81
+ x = x.view(batch_size, seq_len, -1)
82
+
83
+ return x
84
+
85
+ def apply_activation(self, x):
86
+ if self.activation == 'relu':
87
+ return F.relu(x)
88
+ elif self.activation == 'tanh':
89
+ return torch.tanh(x)
90
+ elif self.activation == 'sigmoid':
91
+ return torch.sigmoid(x)
92
+ else:
93
+ raise ValueError(f'Unsupported activation function: {self.activation}')
94
+
95
+ class Conv1d(nn.Conv1d):
96
+ def __init__(self, *args, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self.reset_parameters()
99
+
100
+ def reset_parameters(self):
101
+ nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
102
+ if self.bias is not None:
103
+ nn.init.zeros_(self.bias)
104
+
105
+ def _conv_forward(self, x, weight, bias) -> Tensor:
106
+ weight = self.weight.to(x.dtype)
107
+ bias = None if self.bias is None else self.bias.to(x.dtype)
108
+ return super()._conv_forward(x, weight, bias)
109
+
110
+ class BiasedCrossAttention(nn.Module):
111
+ def __init__(self, n_state, n_head, dropout_rate=0.1):
112
+ super().__init__()
113
+ self.n_head = n_head
114
+ self.n_state = n_state
115
+ self.head_dim = n_state // n_head
116
+
117
+ self.query = nn.Linear(n_state, n_state)
118
+ self.key = nn.Linear(n_state, n_state, bias=False)
119
+ self.value = nn.Linear(n_state, n_state)
120
+ self.out = nn.Linear(n_state, n_state)
121
+
122
+ self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))
123
+ self.dropout = nn.Dropout(dropout_rate)
124
+ self.norm = LayerNorm(n_state)
125
+
126
+ def forward(self, q, k, v, mask=None):
127
+ batch_size, seq_length, _ = q.size()
128
+
129
+ q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)
130
+ k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)
131
+ v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)
132
+
133
+ qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias
134
+ if mask is not None:
135
+ qk = qk.masked_fill(mask == 0, float('-inf'))
136
+
137
+ w = F.softmax(qk, dim=-1)
138
+ w = self.dropout(w)
139
+
140
+ out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
141
+ out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))
142
+ return out
143
+
144
+ class DynamicConvAttention(nn.Module):
145
+ def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1):
146
+ super().__init__()
147
+ self.n_state = n_state
148
+ self.n_head = n_head
149
+ self.kernel_size = kernel_size
150
+
151
+ self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)
152
+ self.dropout = nn.Dropout(dropout_rate)
153
+
154
+ self.query = nn.Linear(n_state, n_state)
155
+ self.key = nn.Linear(n_state, n_state, bias=False)
156
+ self.value = nn.Linear(n_state, n_state)
157
+ self.out_proj = nn.Linear(n_state, n_state)
158
+
159
+ self.norm = LayerNorm(n_state)
160
+
161
+ def forward(self, x):
162
+ batch_size, seq_len, embed_dim = x.size()
163
+ if embed_dim != self.n_state:
164
+ raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")
165
+
166
+ q = self.query(x)
167
+ k = self.key(x)
168
+ v = self.value(x)
169
+
170
+ x = x.permute(0, 2, 1)
171
+ conv_out = self.conv(x)
172
+ conv_out = conv_out.permute(0, 2, 1)
173
+ conv_out = self.norm(conv_out)
174
+ conv_out = self.dropout(conv_out)
175
+
176
+ attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)
177
+ attention_out = torch.matmul(attention_out, v)
178
+
179
+ combined_out = conv_out + attention_out
180
+ combined_out = self.norm(combined_out)
181
+
182
+ return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)
183
+
184
+ class HybridAttention(nn.Module):
185
+ def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1):
186
+ super().__init__()
187
+ self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
188
+ self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
189
+ self.ln_local = LayerNorm(n_state)
190
+ self.ln_global = LayerNorm(n_state)
191
+
192
+ self.dropout = nn.Dropout(dropout_rate)
193
+ self.window_size = window_size
194
+
195
+ def forward(self, x):
196
+ x_local = self.ln_local(x)
197
+ x_global = self.ln_global(x)
198
+ x_local = x_local.permute(1, 0, 2)
199
+ x_global = x_global.permute(1, 0, 2)
200
+ local_out = self.sliding_window_attention(x_local)
201
+ global_out, _ = self.global_attn(x_global, x_global, x_global)
202
+ combined_out = local_out + global_out
203
+ combined_out = combined_out.permute(1, 0, 2)
204
+ return self.dropout(combined_out)
205
+
206
+ def sliding_window_attention(self, x):
207
+ seq_len, batch_size, n_state = x.size()
208
+ window_size = min(self.window_size, max(1, seq_len // 4))
209
+ output = torch.zeros_like(x, device=x.device, dtype=x.dtype)
210
+
211
+ for i in range(0, seq_len, window_size):
212
+ end = min(i + window_size, seq_len)
213
+ query = x[i:end, :, :]
214
+ start = max(0, i - window_size)
215
+ key = x[start:end, :, :]
216
+ value = x[start:end, :, :]
217
+ attn_output, _ = self.local_attn(query, key, value)
218
+ output[i:end, :, :] = attn_output[:end - i, :, :]
219
+
220
+ return output
221
+
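A shape-level usage sketch for HybridAttention, which adds windowed local attention to full global attention (sizes are illustrative; input is (batch, seq, n_state)):

attn = HybridAttention(n_state=64, n_head=4, window_size=8)
x = torch.randn(2, 32, 64)
print(attn(x).shape)                  # torch.Size([2, 32, 64])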
222
+ def givens_rotation_matrix(n_state, i, j, theta):
223
+ G = torch.eye(n_state)
224
+ G[i, i] = math.cos(theta)
225
+ G[i, j] = -math.sin(theta)
226
+ G[j, i] = math.sin(theta)
227
+ G[j, j] = math.cos(theta)
228
+ return G
229
+
230
+ class GivensRotations(nn.Module):
231
+ def __init__(self, h_dim, num_rotations):
232
+ super().__init__()
233
+ self.h_dim = h_dim
234
+ self.num_rotations = num_rotations
235
+ self.thetas = nn.Parameter(torch.zeros(num_rotations))
236
+
237
+ def forward(self, x):
238
+ if x.dim() != 4:
239
+ raise ValueError(f"Expected input tensor to be 4D, but got {x.dim()}D")
240
+
241
+ batch_size, seq_len, n_head, h_dim = x.size()
242
+
243
+ if h_dim != self.h_dim:
244
+ raise ValueError(f"Expected h_dim of {self.h_dim}, but got {h_dim}")
245
+
246
+ x = x.view(-1, h_dim)
247
+ for k in range(self.num_rotations):
248
+ i, j = k % self.h_dim, (k + 1) % self.h_dim
249
+ G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])
250
+ x = torch.matmul(x, G.to(x.device))
251
+
252
+ x = x.view(batch_size, seq_len, n_head, h_dim)
253
+ return x
254
+
255
+ class RotaryEmbeddingWithRotation(nn.Module):
256
+ def __init__(self, n_state, n_head, base=10000, checkpointing=False):
257
+ super().__init__()
258
+ self.n_state = n_state
259
+ self.n_head = n_head
260
+ self.h_dim = n_state // n_head
261
+ self.base = base # Initialize base
262
+ self.checkpointing = checkpointing
263
+
264
+ self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
265
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
266
+ self.register_buffer('inv_freq', inv_freq)
267
+
268
+ def update_base(self, new_base):
269
+ self.base = new_base
270
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
271
+ self.register_buffer('inv_freq', inv_freq)
272
+
273
+ def reset_parameters(self):
274
+ nn.init.orthogonal_(self.rotation_matrix)
275
+
276
+ def forward(self, x):
277
+ if self.checkpointing:
278
+ return checkpoint(self._forward, x)
279
+ else:
280
+ return self._forward(x)
281
+
282
+ def _forward(self, x):
283
+ if x.dim() == 3:
284
+ batch_size, seq_len, n_state = x.size()
285
+ elif x.dim() == 4:
286
+ batch_size, seq_len, n_head, h_dim = x.size()
287
+ n_state = n_head * h_dim
288
+ x = x.view(batch_size, seq_len, n_state)
289
+ else:
290
+ raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
291
+
292
+ if n_state != self.n_state:
293
+ raise ValueError(f"Expected n_state of {self.n_state}, but got {n_state}")
294
+
295
+ x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
296
+ x = x.reshape(-1, self.h_dim)
297
+ rotated_x = torch.matmul(x, self.rotation_matrix)
298
+ rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
299
+
300
+ sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))
301
+ sin = sinusoid_inp.sin()[None, :, None, :]
302
+ cos = sinusoid_inp.cos()[None, :, None, :]
303
+ x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]
304
+ rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
305
+
306
+ rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)
307
+ return rotated_x
308
+
309
+ class LearnedSinusoidalEmbeddings(nn.Module):
310
+ def __init__(self, n_ctx, n_state, checkpointing=False):
311
+ super().__init__()
312
+ self.n_ctx = n_ctx
313
+ self.n_state = n_state
314
+ self.checkpointing = checkpointing
315
+
316
+ position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
317
+ div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
318
+ features = torch.zeros(n_ctx, n_state)
319
+ features[:, 0::2] = torch.sin(position * div_term)
320
+ features[:, 1::2] = torch.cos(position * div_term)
321
+ self.register_buffer('sinusoidal_features', features)
322
+
323
+ self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())
324
+
325
+ def forward(self, positions):
326
+ if self.checkpointing:
327
+ position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)
328
+ else:
329
+ position_embeddings = self.positional_embeddings[positions]
330
+
331
+ position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)
332
+ return position_embeddings
333
+
334
+ class MultiHeadAttention(nn.Module):
335
+ use_sdpa = True
336
+
337
+ def __init__(self, n_state: int, n_head: int, base: int = 10000, max_rel_dist: int = 1):
338
+ super().__init__()
339
        assert n_state % n_head == 0, "n_state must be divisible by n_head"
        self.n_head = n_head
        self.h_dim = n_state // n_head
        assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"

        self.positional_scaling = nn.Parameter(torch.ones(1))

        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.max_rel_dist = max_rel_dist
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
        self.register_buffer('inv_freq', inv_freq)

        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=base)

        self.rotation_matrix = nn.Parameter(torch.empty(self.h_dim, self.h_dim))
        nn.init.orthogonal_(self.rotation_matrix)

        self.givens_rotations = GivensRotations(self.h_dim, num_rotations=self.h_dim // 2)

        self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)
        self.rel_pos_bias.weight.data.fill_(0)

        if device:
            self.to(device)

    def update_base(self, new_base):
        self.base = new_base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
        self.register_buffer('inv_freq', inv_freq)
        self.rotary_embedding.update_base(new_base)

    def apply_rotary_embedding(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, n_head, h_dim); rotate even/odd channel pairs.
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        scaled_positions = self.positional_scaling * positions
        sinusoid_inp = torch.outer(scaled_positions, self.inv_freq.to(x.device))
        sin = sinusoid_inp.sin()[None, :, None, :]
        cos = sinusoid_inp.cos()[None, :, None, :]

        x1, x2 = x[..., ::2], x[..., 1::2]
        x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x_rotated

    def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
        q = self.query(x)

        if kv_cache is None or xa is None or 'k' not in kv_cache:
            # Self-attention, or the first cross-attention pass: (re)compute k/v.
            k_input = x if xa is None else xa
            k = self.key(k_input)
            v = self.value(k_input)
            if kv_cache is not None:
                kv_cache['k'] = k
                kv_cache['v'] = v
        else:
            # Reuse cached cross-attention keys/values from a previous step.
            k = kv_cache['k']
            v = kv_cache['v']

        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)
        v = v.view(v.shape[0], v.shape[1], self.n_head, -1)

        q = self.apply_rotary_embedding(q)
        k = self.apply_rotary_embedding(k)

        q = torch.matmul(q, self.rotation_matrix)
        k = torch.matmul(k, self.rotation_matrix)

        q = self.givens_rotations(q)
        k = self.givens_rotations(k)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q, k, v, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        n_batch, n_ctx, n_state = q.shape

        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = (q * scale) @ (k * scale).transpose(-1, -2)

        # Bucketed relative-position bias, clamped to +/- (max_rel_dist - 1).
        seq_len_q = q.size(2)
        seq_len_k = k.size(2)
        positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)
        positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1
        rel_bias = self.rel_pos_bias(positions)
        rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0)
        qk = qk + rel_bias

        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
        qk = qk.detach()

        return out, qk

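# ---------------------------------------------------------------------------
# Sketch (added for illustration, never called): apply_rotary_embedding is a
# pure rotation of even/odd channel pairs, so it preserves per-position vector
# norms and keeps attention logits well-scaled. This is a hypothetical sanity
# check, assuming the MultiHeadAttention constructor signature
# (n_state, n_head, max_rel_dist, base) implied by the assignments above.
# ---------------------------------------------------------------------------
def _sketch_rotary_preserves_norms():
    mha = MultiHeadAttention(64, 4, max_rel_dist=10, base=10000)
    dev = next(mha.parameters()).device
    x = torch.randn(2, 8, mha.n_head, mha.h_dim, device=dev)
    x_rot = mha.apply_rotary_embedding(x)
    # Per-position norms should match up to floating-point error.
    assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-4)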
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, max_rel_dist=1, checkpointing=False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head, max_rel_dist=max_rel_dist)
        self.attn_ln = LayerNorm(n_state)
        self.checkpointing = checkpointing
        self.max_rel_dist = max_rel_dist

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, max_rel_dist=max_rel_dist) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
        if self.checkpointing:
            x = checkpoint(self._attn_forward, x, mask, kv_cache)
        else:
            x = self._attn_forward(x, mask, kv_cache)

        if self.cross_attn:
            if self.checkpointing:
                x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)
            else:
                x = self._cross_attn_forward(x, xa, kv_cache)

        if self.checkpointing:
            x = checkpoint(self._mlp_forward, x)
        else:
            x = self._mlp_forward(x)

        return x

    def _attn_forward(self, x, mask, kv_cache):
        # Pre-norm residual: normalize, attend, then add back to the stream.
        residual = x
        x = self.attn_ln(x)
        x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]
        return x

    def _cross_attn_forward(self, x, xa, kv_cache):
        residual = x
        x = self.cross_attn_ln(x)
        x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]
        return x

    def _mlp_forward(self, x):
        residual = x
        x = self.mlp_ln(x)
        x = residual + self.mlp(x)
        return x

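# Sketch (illustrative only, never called): the residual block keeps shapes
# fixed, so layers stack freely; cross-attention mixes in encoder states of a
# different length. Assumes the GivensRotations and Linear modules defined
# earlier are shape-preserving, as their use inside the block implies.
def _sketch_residual_block_shapes():
    block = ResidualAttentionBlock(64, 4, cross_attention=True, max_rel_dist=10)
    dev = next(block.parameters()).device
    x = torch.randn(2, 8, 64, device=dev)    # decoder-side stream
    xa = torch.randn(2, 12, 64, device=dev)  # encoder states
    assert block(x, xa).shape == x.shape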
class AudioEncoder(nn.Module):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist, checkpointing=False):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
        self.checkpointing = checkpointing

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=False, max_rel_dist=max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def update_base(self, new_base):
        self.rotary_embedding.update_base(new_base)
        for block in self.blocks:
            if isinstance(block.attn, MultiHeadAttention):
                block.attn.update_base(new_base)
            if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
                block.cross_attn.update_base(new_base)

    def forward(self, x):
        if self.checkpointing:
            x = checkpoint(self._conv_forward, x)
        else:
            x = self._conv_forward(x)

        for block in self.blocks:
            if self.checkpointing:
                x = checkpoint(block, x)
            else:
                x = block(x)

        x = self.ln_post(x)
        return x

    def _conv_forward(self, x):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
        x = self.rotary_embedding(x)

        pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)
        x = x + pos_emb
        return x

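# Shape flow through AudioEncoder (for reference): a log-mel input of shape
# (batch, n_mels, T) passes conv1 (stride 1) and conv2 (stride 2), so the
# encoder emits (batch, T // 2, n_state) hidden states; with Whisper's 30 s
# windows T = 3000, which yields the 1500 audio positions configured below.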
class TextDecoder(nn.Module):
    def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist, cross_attention, checkpointing=False):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
        self.checkpointing = checkpointing
        self.n_head = n_head

        self.blocks = nn.ModuleList([
            ResidualAttentionBlock(n_state, n_head, cross_attention=cross_attention, max_rel_dist=max_rel_dist, checkpointing=checkpointing)
            for _ in range(n_layer)
        ])
        self.ln = LayerNorm(n_state)
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def update_base(self, new_base):
        self.rotary_embedding.update_base(new_base)
        for block in self.blocks:
            if isinstance(block.attn, MultiHeadAttention):
                block.attn.update_base(new_base)
            if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
                block.cross_attn.update_base(new_base)

    def forward(self, x, xa, kv_cache: Optional[dict] = None):
        if self.checkpointing:
            x = checkpoint(self._embedding_forward, x, xa, kv_cache)
        else:
            x = self._embedding_forward(x, xa, kv_cache)

        for block in self.blocks:
            if self.checkpointing:
                x = checkpoint(block, x, xa, self.mask, kv_cache)
            else:
                x = block(x, xa, self.mask, kv_cache)

        x = self.ln(x)
        # Output projection is tied to the token embedding (logits = x @ E^T).
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return logits

    def _embedding_forward(self, x, xa, kv_cache):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = torch.arange(x.shape[1], device=x.device) + offset
        pos_emb = self.positional_embedding(positions).unsqueeze(0)

        x = self.token_embedding(x) + pos_emb
        x = x.to(xa.dtype)

        batch_size, seq_length, embedding_dim = x.shape
        num_heads = self.n_head
        head_dim = embedding_dim // num_heads
        x = x.view(batch_size, seq_length, num_heads, head_dim)

        x = self.rotary_embedding(x)
        x = x.view(batch_size, seq_length, embedding_dim)
        return x

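# Worked example of the causal mask built in TextDecoder.__init__: with
# n_ctx = 4 the buffer is
#   [[0, -inf, -inf, -inf],
#    [0,    0, -inf, -inf],
#    [0,    0,    0, -inf],
#    [0,    0,    0,    0]]
# so adding it to the attention logits lets position i attend only to
# positions <= i.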
class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):
    config_class = WhisperConfig

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.config = config

        self.n_mels = self.config.num_mel_bins
        self.n_audio_ctx = self.config.max_source_positions
        self.n_audio_state = self.config.d_model
        self.n_audio_head = self.config.encoder_attention_heads
        self.n_audio_layer = self.config.encoder_layers
        self.vocab_size = self.config.vocab_size
        self.n_text_ctx = self.config.max_target_positions
        self.n_text_state = self.config.d_model
        self.n_text_head = self.config.decoder_attention_heads
        self.n_text_layer = self.config.decoder_layers
        self.max_rel_dist = self.config.max_rel_dist
        self.checkpointing = self.config.checkpointing
        self.base = self.config.base

        self.encoder = AudioEncoder(
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            max_rel_dist=self.config.max_rel_dist,
            checkpointing=self.config.checkpointing,
        )
        self.decoder = TextDecoder(
            self.config.vocab_size,
            self.config.n_text_ctx,
            self.config.n_text_state,
            self.config.n_text_head,
            self.config.n_text_layer,
            max_rel_dist=self.config.max_rel_dist,
            cross_attention=self.config.cross_attention,
            checkpointing=self.config.checkpointing,
        )

        all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)
        all_heads[self.config.n_text_layer // 2:] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

        self.best_loss = float('inf')

    def update_base(self, new_base):
        self.base = new_base
        self.encoder.rotary_embedding.update_base(new_base)
        self.decoder.rotary_embedding.update_base(new_base)
        for name, module in self.encoder.named_modules():
            if isinstance(module, MultiHeadAttention):
                module.update_base(new_base)
        for name, module in self.decoder.named_modules():
            if isinstance(module, MultiHeadAttention):
                module.update_base(new_base)

    def adjust_base(self, loss, factor=1.05):
        # Nudge the rotary base up when the loss improves, down otherwise.
        if loss < self.best_loss:
            new_base = self.base * factor
        else:
            new_base = self.base / factor

        self.update_base(new_base)
        self.best_loss = loss
        # print(f"Adjusted base: {new_base}")

    @staticmethod
    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1]
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

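    # Worked example for shift_tokens_right: with decoder_start_token_id=50258
    # and pad_token_id=0, labels [[7, 8, 9]] become decoder inputs
    # [[50258, 7, 8]]; any -100 masking values that get shifted in are
    # replaced by pad_token_id.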
    def forward(self, input_features, labels=None, dec_input_ids=None):
        if labels is not None:
            if dec_input_ids is None:
                dec_input_ids = self.shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        encoded_features = self.encoder(input_features).to(device)
        logits = self.decoder(dec_input_ids, encoded_features)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(logits.device).long()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
            self.adjust_base(loss.item())

        return {
            "loss": loss,
            "logits": logits,
            "input_features": encoded_features,
            "labels": labels,
            "decoder_input_ids": dec_input_ids
        }

    def _initialize_weights(self):
        nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)
        if hasattr(self.decoder.positional_embedding, 'weight'):
            nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)
        for block in self.decoder.blocks:
            for layer in block.children():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)

        nn.init.constant_(self.decoder.ln.gamma, 1)
        if self.decoder.ln.beta is not None:
            nn.init.constant_(self.decoder.ln.beta, 0)

        nn.init.xavier_normal_(self.encoder.conv1.weight)
        if self.encoder.conv1.bias is not None:
            nn.init.zeros_(self.encoder.conv1.bias)

        nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
        if self.encoder.conv2.bias is not None:
            nn.init.zeros_(self.encoder.conv2.bias)

        nn.init.constant_(self.encoder.ln_post.gamma, 1)
        if self.encoder.ln_post.beta is not None:
            nn.init.constant_(self.encoder.ln_post.beta, 0)

    def apply_initialization(self):
        self._initialize_weights()

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.config.n_text_layer, self.config.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, labels, input_features):
        return self.decoder(labels, input_features)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.config.vocab_size >= len(tokenizer)

    @property
    def num_languages(self):
        return self.config.vocab_size - (len(tokenizer) - 100) - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

    def get_encoder(self):
        return self.encoder

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_features': input_ids}

    def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id

    def can_generate(self):
        return True

    def generate(self, inputs, **kwargs):
        # Single greedy step: encode the audio, run one decoder pass from a
        # zero start token, and return the argmax token ids.
        encoder_outputs = self.encoder(inputs)
        decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)
        outputs = self.decoder(decoder_input_ids, encoder_outputs)
        return outputs.argmax(dim=-1)

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")

config = WhisperConfig(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=1024,
    n_audio_head=16,
    n_audio_layer=20,
    vocab_size=len(tokenizer),
    n_text_ctx=448,
    n_text_state=1024,
    n_text_head=16,
    n_text_layer=16,
    max_rel_dist=10,
    cross_attention=True,
    checkpointing=True,
    base=10000
)

model = Echo(config).to(device)
model.apply_initialization()

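# Quick size report (added for convenience, safe to remove): makes runs on
# different configs easy to compare in the logs.
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.1f}M")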
class CustomCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        print(f"Evaluation metrics at step {state.global_step}: {metrics}")

# Streaming datasets need the iterable dict and Audio casting utilities.
from datasets import Audio, IterableDatasetDict

raw_datasets = IterableDatasetDict()

raw_datasets["train"] = load_dataset("mozilla-foundation/common_voice_17_0", "ja", split="train", trust_remote_code=True, streaming=True)
raw_datasets["test"] = load_dataset("mozilla-foundation/common_voice_17_0", "ja", split="test", trust_remote_code=True, streaming=True).take(100)

raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    transcription = batch["sentence"]
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch

vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)).with_format("torch")

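# After prepare_dataset, each streamed example is a dict with:
#   input_features: (80, 3000) log-mel tensor (WhisperFeatureExtractor pads
#                   every clip to a 30 s window by default)
#   labels:         list of token ids for the reference sentence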
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Mask padding with -100 so CrossEntropyLoss ignores those positions.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # Drop a leading BOS if every row starts with it; the model re-adds it.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

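# Collator sketch (hypothetical inputs, never called): features are padded to
# a common shape and label padding is masked to -100 so the loss skips it.
def _sketch_collate():
    fake = [
        {"input_features": torch.zeros(80, 3000), "labels": [50258, 7, 8]},
        {"input_features": torch.zeros(80, 3000), "labels": [50258, 9]},
    ]
    batch = data_collator(fake)
    assert batch["input_features"].shape == (2, 80, 3000)
    assert (batch["labels"] == -100).any()  # the shorter row gets masked padding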
metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    label_ids = pred.label_ids

    # The trainer may hand back a tuple of model outputs; the logits come first.
    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if pred_ids.ndim == 3:
        # Raw logits (batch, seq, vocab): reduce to token ids.
        pred_ids = np.argmax(pred_ids, axis=-1)

    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

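# CER example: prediction "こんにちわ" against reference "こんにちは" is one
# substitution over five characters, so compute_metrics above reports
# 100 * 1/5 = 20.0.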
training_args = Seq2SeqTrainingArguments(
    output_dir="./test",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    num_train_epochs=1,
    tf32=True,
    bf16=True,
    learning_rate=1e-5,
    # warmup_steps=500,
    evaluation_strategy="steps",
    # predict_with_generate=True,
    # generation_max_length=225,
    max_steps=100,
    save_steps=100,
    eval_steps=10,
    logging_steps=5,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",  # must match the key returned by compute_metrics
    greater_is_better=False,
    push_to_hub=False,
    optim="adafactor",
    weight_decay=0.0025,
    disable_tqdm=False,
    save_total_limit=2,
    torch_empty_cache_steps=10,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

trainer.add_callback(CustomCallback)

trainer.train()