{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import evaluate\n", "import json\n", "import logging\n", "import random\n", "import sys\n", "import time\n", "import torch\n", "import transformers\n", "import warnings\n", "import math\n", "import neologdn\n", "import gzip\n", "import base64\n", "import numpy as np\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch import amp, Tensor, optim\n", "from torch.utils.checkpoint import checkpoint\n", "from torch.optim import Adamax\n", "from torch.utils.tensorboard import SummaryWriter\n", "from typing import Optional, Tuple, Dict, List, Any, Union\n", "from dataclasses import dataclass\n", "from transformers import (\n", " WhisperPreTrainedModel, WhisperConfig, Trainer, \n", " TrainingArguments, WhisperTokenizer, WhisperFeatureExtractor, \n", " WhisperProcessor, TrainerCallback, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoTokenizer\n", ")\n", "from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel\n", "from transformers.models.whisper.generation_whisper import WhisperGenerationMixin\n", "from transformers.optimization import Adafactor, AdafactorSchedule\n", "from huggingface_hub import PyTorchModelHubMixin\n", "from datasets import load_from_disk, load_dataset\n", "from tqdm import tqdm\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n", "from sklearn.model_selection import train_test_split\n", "from whisper.decoding import decode as decode_function\n", "from whisper.decoding import detect_language as detect_language_function\n", "from whisper.transcribe import transcribe as transcribe_function\n", "\n", "try:\n", " from torch.nn.functional import scaled_dot_product_attention\n", " SDPA_AVAILABLE = True\n", "except (ImportError, RuntimeError, OSError):\n", " scaled_dot_product_attention = None\n", " SDPA_AVAILABLE = False\n", "\n", "transformers.utils.logging.set_verbosity_error()\n", "warnings.filterwarnings(action=\"ignore\")\n", "warnings.warn = lambda *args,**kwargs: None\n", "device = \"cuda\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LayerNorm(nn.Module):\n", " def __init__(self, num_features, eps=1e-6):\n", " super(LayerNorm, self).__init__()\n", " self.gamma = nn.Parameter(torch.ones(num_features))\n", " self.beta = nn.Parameter(torch.zeros(num_features))\n", " self.eps = eps\n", "\n", " def forward(self, x):\n", " mean = x.mean(dim=-1, keepdim=True)\n", " std = x.std(dim=-1, keepdim=True)\n", " x = (x - mean) / (std + self.eps)\n", " return self.gamma * x + self.beta\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "class Linear(nn.Module):\n", " def __init__(self, in_features: int, out_features: int, dropout_rate = 0.01, use_batchnorm: bool = True, activation: str = 'relu'):\n", " super(Linear, self).__init__()\n", " self.linear = nn.Linear(in_features, out_features)\n", " self.dropout = nn.Dropout(dropout_rate)\n", " self.use_batchnorm = use_batchnorm\n", " self.activation = activation\n", "\n", " if self.use_batchnorm:\n", " self.batchnorm = nn.BatchNorm1d(out_features)\n", " self.reset_parameters()\n", "\n", " def reset_parameters(self):\n", " nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)\n", " if self.linear.bias is not None:\n", " nn.init.zeros_(self.linear.bias)\n", "\n", " def forward(self, x):\n", " batch_size, seq_len, _ = 
x.size()\n", " x = x.view(-1, x.size(-1)) \n", " x = self.linear(x)\n", "\n", " if self.use_batchnorm:\n", " x = self.batchnorm(x)\n", "\n", " x = self.apply_activation(x)\n", " x = self.dropout(x)\n", " x = x.view(batch_size, seq_len, -1) \n", " \n", " return x\n", "\n", " def apply_activation(self, x):\n", " if self.activation == 'relu':\n", " return F.relu(x)\n", " elif self.activation == 'tanh':\n", " return torch.tanh(x)\n", " elif self.activation == 'sigmoid':\n", " return torch.sigmoid(x)\n", " else:\n", " raise ValueError(f'Unsupported activation function: {self.activation}')\n", "\n", "class Conv1d(nn.Conv1d):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " self.reset_parameters()\n", "\n", " def reset_parameters(self):\n", " nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')\n", " if self.bias is not None:\n", " nn.init.zeros_(self.bias)\n", "\n", " def _conv_forward(self, x, weight, bias) -> Tensor:\n", " weight = self.weight.to(x.dtype)\n", " bias = None if self.bias is None else self.bias.to(x.dtype)\n", " return super()._conv_forward(x, weight, bias)\n", "\n", "def givens_rotation_matrix(n_state, i, j, theta):\n", " G = torch.eye(n_state)\n", " G[i, i] = math.cos(theta)\n", " G[i, j] = -math.sin(theta)\n", " G[j, i] = math.sin(theta)\n", " G[j, j] = math.cos(theta)\n", " return G\n", "\n", "class GivensRotations(nn.Module):\n", " def __init__(self, h_dim, num_rotations):\n", " super().__init__()\n", " self.h_dim = h_dim\n", " self.num_rotations = num_rotations\n", " self.thetas = nn.Parameter(torch.zeros(num_rotations))\n", "\n", " def forward(self, x):\n", " if x.dim() != 4:\n", " raise ValueError(f\"Expected input tensor to be 4D, but got {x.dim()}D\")\n", " \n", " batch_size, seq_len, n_head, h_dim = x.size()\n", " \n", " if h_dim != self.h_dim:\n", " raise ValueError(f\"Expected h_dim of {self.h_dim}, but got {h_dim}\")\n", " \n", " x = x.view(-1, h_dim) \n", " for k in range(self.num_rotations):\n", " i, j = k % self.h_dim, (k + 1) % self.h_dim\n", " G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])\n", " x = torch.matmul(x, G.to(x.device))\n", " \n", " x = x.view(batch_size, seq_len, n_head, h_dim) \n", " return x\n", "\n", "class BiasedCrossAttention(nn.Module):\n", " def __init__(self, n_state, n_head, dropout_rate=0.1):\n", " super().__init__()\n", " self.n_head = n_head\n", " self.n_state = n_state\n", " self.head_dim = n_state // n_head\n", "\n", " self.query = nn.Linear(n_state, n_state)\n", " self.key = nn.Linear(n_state, n_state, bias=False)\n", " self.value = nn.Linear(n_state, n_state)\n", " self.out = nn.Linear(n_state, n_state)\n", "\n", " self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))\n", " self.dropout = nn.Dropout(dropout_rate)\n", " self.norm = LayerNorm(n_state)\n", " \n", " def forward(self, q, k, v, mask=None):\n", " batch_size, seq_length, _ = q.size()\n", "\n", " q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)\n", " k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)\n", " v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)\n", "\n", " qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias\n", " if mask is not None:\n", " qk = qk.masked_fill(mask == 0, float('-inf'))\n", "\n", " w = F.softmax(qk, dim=-1)\n", " w = self.dropout(w)\n", "\n", " out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)\n", " out = self.norm(self.out(out) + q.view(batch_size, seq_length, 
-1))\n", " return out\n", "\n", "class DynamicConvAttention(nn.Module):\n", " def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1):\n", " super().__init__()\n", " self.n_state = n_state\n", " self.n_head = n_head\n", " self.kernel_size = kernel_size\n", "\n", " self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)\n", " self.dropout = nn.Dropout(dropout_rate)\n", "\n", " self.query = nn.Linear(n_state, n_state)\n", " self.key = nn.Linear(n_state, n_state, bias=False)\n", " self.value = nn.Linear(n_state, n_state)\n", " self.out_proj = nn.Linear(n_state, n_state)\n", "\n", " self.norm = LayerNorm(n_state)\n", "\n", " def forward(self, x):\n", " batch_size, seq_len, embed_dim = x.size()\n", " if embed_dim != self.n_state:\n", " raise ValueError(f\"Expected embed_dim of {self.n_state}, but got {embed_dim}\")\n", "\n", " q = self.query(x)\n", " k = self.key(x)\n", " v = self.value(x)\n", "\n", " x = x.permute(0, 2, 1)\n", " conv_out = self.conv(x)\n", " conv_out = conv_out.permute(0, 2, 1)\n", " conv_out = self.norm(conv_out)\n", " conv_out = self.dropout(conv_out)\n", "\n", " attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)\n", " attention_out = torch.matmul(attention_out, v)\n", " \n", " combined_out = conv_out + attention_out\n", " combined_out = self.norm(combined_out)\n", " \n", " return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)\n", "\n", "class HybridAttention(nn.Module):\n", " def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1):\n", " super().__init__()\n", " self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n", " self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n", " self.ln_local = LayerNorm(n_state)\n", " self.ln_global = LayerNorm(n_state)\n", "\n", " self.dropout = nn.Dropout(dropout_rate)\n", " self.window_size = window_size\n", "\n", " def forward(self, x):\n", " x_local = self.ln_local(x)\n", " x_global = self.ln_global(x)\n", " x_local = x_local.permute(1, 0, 2)\n", " x_global = x_global.permute(1, 0, 2)\n", " local_out = self.sliding_window_attention(x_local)\n", " global_out, _ = self.global_attn(x_global, x_global, x_global)\n", " combined_out = local_out + global_out\n", " combined_out = combined_out.permute(1, 0, 2)\n", " return self.dropout(combined_out)\n", "\n", " def sliding_window_attention(self, x):\n", " seq_len, batch_size, n_state = x.size()\n", " window_size = min(self.window_size, max(1, seq_len // 4))\n", " output = torch.zeros_like(x, device=x.device, dtype=x.dtype)\n", "\n", " for i in range(0, seq_len, window_size):\n", " end = min(i + window_size, seq_len)\n", " query = x[i:end, :, :]\n", " start = max(0, i - window_size)\n", " key = x[start:end, :, :]\n", " value = x[start:end, :, :]\n", " attn_output, _ = self.local_attn(query, key, value)\n", " output[i:end, :, :] = attn_output[:end - i, :, :]\n", "\n", " return output\n", " \n", "class RotaryEmbeddingWithRotation(nn.Module):\n", " def __init__(self, n_state, n_head, base=10000, checkpointing=False):\n", " super().__init__()\n", " self.n_state = n_state\n", " self.n_head = n_head\n", " self.h_dim = n_state // n_head\n", " self.base = base # Initialize base\n", " self.checkpointing = checkpointing\n", "\n", " self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n", " inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n", " self.register_buffer('inv_freq', inv_freq)\n", 
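 "        # note: 'inv_freq' registered above follows the rotary (RoPE-style) schedule 1 / base**(2i / h_dim); update_base() below rebuilds this buffer whenever the base is changed\n",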
"\n", " def update_base(self, new_base):\n", " self.base = new_base\n", " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n", " self.register_buffer('inv_freq', inv_freq)\n", "\n", " def reset_parameters(self):\n", " nn.init.orthogonal_(self.rotation_matrix)\n", "\n", " def forward(self, x):\n", " if self.checkpointing:\n", " return checkpoint(self._forward, x)\n", " else:\n", " return self._forward(x)\n", "\n", " def _forward(self, x):\n", " if x.dim() == 3:\n", " batch_size, seq_len, n_state = x.size()\n", " elif x.dim() == 4:\n", " batch_size, seq_len, n_head, h_dim = x.size()\n", " n_state = n_head * h_dim\n", " x = x.view(batch_size, seq_len, n_state)\n", " else:\n", " raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n", "\n", " if n_state != self.n_state:\n", " raise ValueError(f\"Expected n_state of {self.n_state}, but got {n_state}\")\n", "\n", " x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n", " x = x.reshape(-1, self.h_dim)\n", " rotated_x = torch.matmul(x, self.rotation_matrix)\n", " rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)\n", "\n", " sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n", " sin = sinusoid_inp.sin()[None, :, None, :]\n", " cos = sinusoid_inp.cos()[None, :, None, :]\n", " x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]\n", " rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n", " \n", " rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)\n", " return rotated_x\n", "\n", "class LearnedSinusoidalEmbeddings(nn.Module):\n", " def __init__(self, n_ctx, n_state, checkpointing=False):\n", " super().__init__()\n", " self.n_ctx = n_ctx\n", " self.n_state = n_state\n", " self.checkpointing = checkpointing\n", "\n", " position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))\n", " features = torch.zeros(n_ctx, n_state)\n", " features[:, 0::2] = torch.sin(position * div_term)\n", " features[:, 1::2] = torch.cos(position * div_term)\n", " self.register_buffer('sinusoidal_features', features)\n", "\n", " self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())\n", "\n", " def forward(self, positions):\n", " if self.checkpointing:\n", " position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)\n", " else:\n", " position_embeddings = self.positional_embeddings[positions]\n", "\n", " position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)\n", " return position_embeddings\n", "\n", "class MultiHeadAttention(nn.Module):\n", " use_sdpa = True\n", "\n", " def __init__(self, n_state: int, n_head: int, base: int = 10000, max_rel_dist: int = 1):\n", " super().__init__()\n", " assert n_state % n_head == 0, \"n_state must be divisible by n_head\"\n", " self.n_head = n_head\n", " self.h_dim = n_state // n_head\n", " assert self.h_dim % 2 == 0, \"Head dimension must be even for rotary embeddings\"\n", "\n", " self.positional_scaling = nn.Parameter(torch.ones(1))\n", "\n", " self.query = nn.Linear(n_state, n_state)\n", " self.key = nn.Linear(n_state, n_state, bias=False)\n", " self.value = nn.Linear(n_state, n_state)\n", " self.out = nn.Linear(n_state, n_state)\n", "\n", " self.max_rel_dist = max_rel_dist\n", " self.base = base\n", " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 
2).float() / self.h_dim))\n", " self.register_buffer('inv_freq', inv_freq)\n", "\n", " self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)\n", "\n", " self.rotation_matrix = nn.Parameter(torch.empty(self.h_dim, self.h_dim))\n", " nn.init.orthogonal_(self.rotation_matrix)\n", "\n", " self.givens_rotations = GivensRotations(self.h_dim, num_rotations=self.h_dim // 2) \n", "\n", " self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)\n", " self.rel_pos_bias.weight.data.fill_(0)\n", "\n", " if device:\n", " self.to(device)\n", "\n", " def update_base(self, new_base): \n", " self.base = new_base \n", " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)) \n", " self.register_buffer('inv_freq', inv_freq) \n", " self.rotary_embedding.update_base(new_base)\n", "\n", " def apply_rotary_embedding(self, x: torch.Tensor) -> torch.Tensor:\n", " seq_len = x.shape[1]\n", " positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)\n", " scaled_positions = self.positional_scaling * positions\n", " sinusoid_inp = torch.outer(scaled_positions, self.inv_freq.to(x.device)) \n", " sin = sinusoid_inp.sin()[None, :, None, :]\n", " cos = sinusoid_inp.cos()[None, :, None, :]\n", "\n", " x1, x2 = x[..., ::2], x[..., 1::2]\n", " x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n", " return x_rotated\n", "\n", " def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):\n", " q = self.query(x)\n", "\n", " if kv_cache is None or xa is None or 'k' not in kv_cache:\n", " k_input = x if xa is None else xa\n", " k = self.key(k_input)\n", " v = self.value(k_input)\n", " if kv_cache is not None:\n", " kv_cache['k'] = k\n", " kv_cache['v'] = v\n", " else:\n", " k = kv_cache['k']\n", " v = kv_cache['v']\n", "\n", " q = q.view(q.shape[0], q.shape[1], self.n_head, -1)\n", " k = k.view(k.shape[0], k.shape[1], self.n_head, -1)\n", " v = v.view(v.shape[0], v.shape[1], self.n_head, -1)\n", "\n", " q = self.apply_rotary_embedding(q)\n", " k = self.apply_rotary_embedding(k)\n", "\n", " q = torch.matmul(q, self.rotation_matrix)\n", " k = torch.matmul(k, self.rotation_matrix)\n", "\n", " q = self.givens_rotations(q) \n", " k = self.givens_rotations(k)\n", "\n", " q = q.view(q.shape[0], q.shape[1], -1)\n", " k = k.view(k.shape[0], k.shape[1], -1)\n", "\n", " wv, qk = self.qkv_attention(q, k, v, mask)\n", " return self.out(wv), qk\n", " \n", " def qkv_attention(self, q, k, v, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n", " n_batch, n_ctx, n_state = q.shape\n", "\n", " scale = (n_state // self.n_head) ** -0.25\n", " q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n", " k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n", " v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n", "\n", " qk = (q * scale) @ (k * scale).transpose(-1, -2)\n", "\n", " seq_len_q = q.size(2)\n", " seq_len_k = k.size(2)\n", "\n", " positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)\n", " positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1\n", " rel_bias = self.rel_pos_bias(positions) \n", " rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0) \n", "\n", " qk = qk + rel_bias\n", "\n", " if mask is not None:\n", " qk = qk + mask[:n_ctx, :n_ctx]\n", " qk = qk.float()\n", "\n", " w = 
F.softmax(qk, dim=-1).to(q.dtype)\n", " out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n", " qk = qk.detach()\n", "\n", " return out, qk\n", " \n", "class ResidualAttentionBlock(nn.Module):\n", " def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, max_rel_dist = 1, checkpointing=False):\n", " super().__init__()\n", "\n", " self.attn = MultiHeadAttention(n_state, n_head)\n", " self.attn_ln = LayerNorm(n_state)\n", " self.checkpointing = checkpointing\n", " self.max_rel_dist = max_rel_dist\n", "\n", " self.cross_attn = (\n", " MultiHeadAttention(n_state, n_head) if cross_attention else None\n", " )\n", " self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None\n", "\n", " n_mlp = n_state * 4\n", " self.mlp = nn.Sequential(\n", " Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)\n", " )\n", " self.mlp_ln = LayerNorm(n_state)\n", "\n", " def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):\n", " if self.checkpointing:\n", " x = checkpoint(self._attn_forward, x, mask, kv_cache)\n", " else:\n", " x = self._attn_forward(x, mask, kv_cache)\n", "\n", " if self.cross_attn:\n", " if self.checkpointing:\n", " x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)\n", " else:\n", " x = self._cross_attn_forward(x, xa, kv_cache)\n", "\n", " if self.checkpointing:\n", " x = checkpoint(self._mlp_forward, x)\n", " else:\n", " x = self._mlp_forward(x)\n", "\n", " return x\n", "\n", " def _attn_forward(self, x, mask, kv_cache):\n", " residual = x\n", " x = self.attn_ln(x)\n", " x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]\n", " return x\n", "\n", " def _cross_attn_forward(self, x, xa, kv_cache):\n", " residual = x\n", " x = self.cross_attn_ln(x)\n", " x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]\n", " return x\n", "\n", " def _mlp_forward(self, x):\n", " residual = x\n", " x = self.mlp_ln(x)\n", " x = residual + self.mlp(x)\n", " return x\n", "\n", "class AudioEncoder(nn.Module):\n", " def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist, checkpointing=False):\n", " super().__init__()\n", " self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)\n", " self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)\n", " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n", " self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)\n", " self.checkpointing = checkpointing\n", "\n", " self.blocks = nn.ModuleList(\n", " [ResidualAttentionBlock(n_state, n_head, max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]\n", " )\n", " self.ln_post = LayerNorm(n_state)\n", "\n", " def update_base(self, new_base):\n", " self.rotary_embedding.update_base(new_base)\n", " for block in self.blocks:\n", " if isinstance(block.attn, MultiHeadAttention):\n", " block.attn.update_base(new_base)\n", " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):\n", " block.cross_attn.update_base(new_base)\n", "\n", " def forward(self, x):\n", " if self.checkpointing:\n", " x = checkpoint(self._conv_forward, x)\n", " else:\n", " x = self._conv_forward(x)\n", "\n", " for block in self.blocks:\n", " if self.checkpointing:\n", " x = checkpoint(block, x)\n", " else:\n", " x = block(x)\n", "\n", " x = self.ln_post(x)\n", " return x\n", "\n", " def _conv_forward(self, x):\n", " x = F.gelu(self.conv1(x))\n", " x = 
F.gelu(self.conv2(x))\n", " x = x.permute(0, 2, 1)\n", " x = self.rotary_embedding(x)\n", " \n", " pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)\n", " x = x + pos_emb\n", " return x\n", "\n", "class TextDecoder(nn.Module):\n", " def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist, cross_attention, checkpointing=False):\n", " super().__init__()\n", " self.token_embedding = nn.Embedding(vocab_size, n_state)\n", " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)\n", " self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)\n", " self.checkpointing = checkpointing\n", " self.n_head = n_head\n", "\n", " self.blocks = nn.ModuleList([\n", " ResidualAttentionBlock(n_state, n_head, max_rel_dist, cross_attention, checkpointing=checkpointing)\n", " for _ in range(n_layer)\n", " ])\n", " self.ln = LayerNorm(n_state)\n", " mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)\n", " self.register_buffer(\"mask\", mask, persistent=False)\n", "\n", " def update_base(self, new_base):\n", " self.rotary_embedding.update_base(new_base)\n", " for block in self.blocks:\n", " if isinstance(block.attn, MultiHeadAttention):\n", " block.attn.update_base(new_base)\n", " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):\n", " block.cross_attn.update_base(new_base)\n", "\n", " def forward(self, x, xa, kv_cache: Optional[dict] = None):\n", " if self.checkpointing:\n", " x = checkpoint(self._embedding_forward, x, xa, kv_cache)\n", " else:\n", " x = self._embedding_forward(x, xa, kv_cache)\n", "\n", " for block in self.blocks:\n", " if self.checkpointing:\n", " x = checkpoint(block, x, xa, self.mask, kv_cache)\n", " else:\n", " x = block(x, xa, self.mask, kv_cache)\n", "\n", " x = self.ln(x)\n", " logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()\n", "\n", " return logits\n", "\n", " def _embedding_forward(self, x, xa, kv_cache):\n", " offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n", " positions = torch.arange(x.shape[1], device=x.device) + offset\n", " pos_emb = self.positional_embedding(positions).unsqueeze(0)\n", "\n", " x = self.token_embedding(x) + pos_emb\n", " x = x.to(xa.dtype)\n", "\n", " batch_size, seq_length, embedding_dim = x.shape\n", " num_heads = self.n_head\n", " head_dim = embedding_dim // num_heads\n", " x = x.view(batch_size, seq_length, num_heads, head_dim)\n", "\n", " x = self.rotary_embedding(x)\n", " x = x.view(batch_size, seq_length, embedding_dim)\n", " return x\n", " \n", "class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):\n", " config_class = WhisperConfig\n", "\n", " def __init__(self, config: WhisperConfig):\n", " super().__init__(config)\n", " self.config = config\n", "\n", " self.n_mels = self.config.num_mel_bins\n", " self.n_audio_ctx = self.config.max_source_positions\n", " self.n_audio_state = self.config.d_model\n", " self.n_audio_head = self.config.encoder_attention_heads\n", " self.n_audio_layer = self.config.encoder_layers\n", " self.vocab_size = self.config.vocab_size\n", " self.n_text_ctx = self.config.max_target_positions\n", " self.n_text_state = self.config.d_model\n", " self.n_text_head = self.config.decoder_attention_heads\n", " self.n_text_layer = self.config.decoder_layers\n", " self.max_rel_dist = self.config.max_rel_dist \n", " self.checkpointing = self.config.checkpointing\n", " self.base = self.config.base\n", "\n", " self.encoder = 
AudioEncoder(\n", " self.config.n_mels,\n", " self.config.n_audio_ctx,\n", " self.config.n_audio_state,\n", " self.config.n_audio_head,\n", " self.config.n_audio_layer,\n", " max_rel_dist=self.config.max_rel_dist,\n", " checkpointing=self.config.checkpointing\n", " )\n", " self.decoder = TextDecoder(\n", " self.config.vocab_size,\n", " self.config.n_text_ctx,\n", " self.config.n_text_state,\n", " self.config.n_text_head,\n", " self.config.n_text_layer,\n", " max_rel_dist=self.config.max_rel_dist,\n", " cross_attention=self.config.cross_attention,\n", " checkpointing=self.config.checkpointing\n", " )\n", "\n", " all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)\n", " all_heads[self.config.n_text_layer // 2:] = True\n", " self.register_buffer(\"alignment_heads\", all_heads.to_sparse(), persistent=False)\n", "\n", " self.best_loss = float('inf')\n", " self.base = 10000 \n", "\n", " def update_base(self, new_base):\n", " self.encoder.rotary_embedding.update_base(new_base)\n", " self.decoder.rotary_embedding.update_base(new_base)\n", " for name, module in self.encoder.named_modules():\n", " if isinstance(module, MultiHeadAttention):\n", " module.update_base(new_base)\n", " for name, module in self.decoder.named_modules():\n", " if isinstance(module, MultiHeadAttention):\n", " module.update_base(new_base)\n", "\n", " def adjust_base(self, loss, factor=1.05):\n", " if loss < self.best_loss:\n", " new_base = self.base * factor\n", " else:\n", " new_base = self.base / factor\n", "\n", " self.update_base(new_base)\n", " self.best_loss = loss\n", " #print(f\"Adjusted base: {new_base}\")\n", "\n", "\n", " @staticmethod\n", " def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:\n", " shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n", " shifted_input_ids[:, 1:] = input_ids[:, :-1]\n", " shifted_input_ids[:, 0] = decoder_start_token_id\n", " shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n", " return shifted_input_ids\n", "\n", " def forward(self, input_features, labels=None, dec_input_ids=None):\n", " if labels is not None:\n", " if dec_input_ids is None:\n", " dec_input_ids = self.shift_tokens_right(\n", " labels, self.config.pad_token_id, self.config.decoder_start_token_id\n", " )\n", "\n", " encoded_features = self.encoder(input_features).to(device)\n", " logits = self.decoder(dec_input_ids, encoded_features)\n", "\n", " loss = None\n", " if labels is not None:\n", " loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) \n", " labels = labels.to(logits.device).long()\n", " loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))\n", "\n", " self.adjust_base(loss.item())\n", "\n", " return {\n", " \"loss\": loss,\n", " \"logits\": logits,\n", " \"input_features\": encoded_features,\n", " \"labels\": labels,\n", " \"decoder_input_ids\": dec_input_ids\n", " }\n", "\n", " def _initialize_weights(self):\n", " nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)\n", " if hasattr(self.decoder.positional_embedding, 'weight'):\n", " nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)\n", " for block in self.decoder.blocks:\n", " for layer in block.children():\n", " if isinstance(layer, nn.Linear):\n", " nn.init.xavier_normal_(layer.weight)\n", " if layer.bias is not None:\n", " nn.init.zeros_(layer.bias)\n", "\n", " nn.init.constant_(self.decoder.ln.gamma, 1)\n", " if self.decoder.ln.beta is not None:\n", " nn.init.constant_(self.decoder.ln.beta, 0)\n", "\n", " 
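# encoder convolutions: Xavier init for conv1, Kaiming (fan_out, relu) for conv2\n", " 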
nn.init.xavier_normal_(self.encoder.conv1.weight)\n", " if self.encoder.conv1.bias is not None:\n", " nn.init.zeros_(self.encoder.conv1.bias)\n", "\n", " nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')\n", " if self.encoder.conv2.bias is not None:\n", " nn.init.zeros_(self.encoder.conv2.bias)\n", "\n", " nn.init.constant_(self.encoder.ln_post.gamma, 1)\n", " if self.encoder.ln_post.beta is not None:\n", " nn.init.constant_(self.encoder.ln_post.beta, 0)\n", " \n", " def apply_initialization(self):\n", " self._initialize_weights()\n", "\n", " def set_alignment_heads(self, dump: bytes):\n", " array = np.frombuffer(\n", " gzip.decompress(base64.b85decode(dump)), dtype=bool\n", " ).copy()\n", " mask = torch.from_numpy(array).reshape(\n", " self.config.n_text_layer, self.config.n_text_head\n", " )\n", " self.register_buffer(\"alignment_heads\", mask.to_sparse(), persistent=False)\n", "\n", " def embed_audio(self, mel):\n", " return self.encoder(mel)\n", "\n", " def logits(self, labels, input_features):\n", " return self.decoder(labels, input_features)\n", "\n", " @property\n", " def device(self):\n", " return next(self.parameters()).device\n", "\n", " @property\n", " def is_multilingual(self):\n", " return self.config.vocab_size >= len(tokenizer)\n", "\n", " @property\n", " def num_languages(self):\n", " return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)\n", "\n", " def install_kv_cache_hooks(self, cache: Optional[dict] = None):\n", " cache = {**cache} if cache is not None else {}\n", " hooks = []\n", "\n", " def save_to_cache(module, _, output):\n", " if module not in cache or output.shape[1] > self.config.n_text_ctx:\n", " cache[module] = output\n", " else:\n", " cache[module] = torch.cat([cache[module], output], dim=1).detach()\n", " return cache[module]\n", "\n", " def install_hooks(layer: nn.Module):\n", " if isinstance(layer, MultiHeadAttention):\n", " hooks.append(layer.key.register_forward_hook(save_to_cache))\n", " hooks.append(layer.value.register_forward_hook(save_to_cache))\n", "\n", " self.decoder.apply(install_hooks)\n", " return cache, hooks\n", "\n", " detect_language = detect_language_function\n", " transcribe = transcribe_function\n", " decode = decode_function\n", "\n", " def get_encoder(self):\n", " return self.encoder\n", "\n", " def prepare_inputs_for_generation(self, input_ids, **kwargs):\n", " return {'input_features': input_ids}\n", "\n", " def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):\n", " return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id\n", "\n", " def can_generate(self):\n", " return True\n", " \n", " def generate(self, inputs, **kwargs):\n", " encoder_outputs = self.encoder(inputs)\n", " decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)\n", " outputs = self.decoder(decoder_input_ids, encoder_outputs)\n", " return outputs.argmax(dim=-1)\n", "\n", "#rasa" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\", sampling_rate=16000, n_fft=1024, hop_length=256, feature_size=128, do_normalize=True)\n", "tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language='ja', task='transcribe')#, pad_token=\"[PAD]\", unk_token=\"[UNK]\", model_max_length=1024)\n", "processor = 
WhisperProcessor.from_pretrained(\"openai/whisper-small\", tokenizer=tokenizer, feature_extractor=feature_extractor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = WhisperConfig(\n", " n_mels=128,\n", " n_audio_ctx=1500,\n", " n_audio_state=1024,\n", " n_audio_head=16,\n", " n_audio_layer=24,\n", " vocab_size=(len(tokenizer)),\n", " n_text_ctx=448,\n", " n_text_state=1024,\n", " n_text_head=16,\n", " n_text_layer=16,\n", " max_rel_dist=10,\n", " cross_attention=True,\n", " checkpointing=True,\n", " base=10000\n", " )\n", "\n", "model = Echo(config).to(device)\n", "model.apply_initialization()\n", "model.save_pretrained(\"./models/echo2\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datetime import datetime\n", "log_dir = os.path.join('./output/', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))\n", "os.makedirs(log_dir, exist_ok=True)\n", "\n", "optimizer = transformers.Adafactor(model.parameters(), \n", " clip_threshold=0.99, \n", " weight_decay=0.005, \n", " scale_parameter=True, \n", " relative_step=True, \n", " warmup_init=True, \n", " lr=None)\n", "\n", "scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=2.25e-5)\n", "loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)\n", "\n", "ds_a = load_from_disk(\"D:/proj/datasets/gvjas\")[\"train\"].to_iterable_dataset(num_shards=200).filter(lambda sample: bool(sample[\"sentence\"])).map(lambda sample: {\"sentence\": neologdn.normalize(sample['sentence'], repeat=1)}).shuffle(buffer_size=10000)\n", "ds_b = load_from_disk(\"D:/proj/datasets/gvjas\")[\"test\"].to_iterable_dataset(num_shards=20).filter(lambda sample: bool(sample[\"sentence\"])).map(lambda sample: {\"sentence\": neologdn.normalize(sample['sentence'], repeat=1)}).shuffle(buffer_size=100)\n", "\n", "def prepare_dataset(batch):\n", " audio = batch[\"audio\"]\n", " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n", " return batch\n", "\n", "train = ds_a.map(prepare_dataset).select_columns([\"input_features\", \"labels\"])\n", "test = ds_b.map(prepare_dataset).select_columns([\"input_features\", \"labels\"])\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", " tokenizer: Any\n", " feature_extractor: Any\n", " decoder_start_token_id: Any\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " labels_batch = self.tokenizer.pad(label_features, return_tensors=\"pt\")\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", " if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", " batch[\"labels\"] = labels\n", " return batch\n", "\n", "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, tokenizer=tokenizer, feature_extractor=feature_extractor, decoder_start_token_id=model.config.decoder_start_token_id)\n", "\n", "class GradientClippingCallback(TrainerCallback):\n", " def on_step_end(self, args, state, control, **kwargs):\n", 
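 "        # clip the total gradient norm of the model parameters to 0.95 at the end of every training step\n",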
" torch.nn.utils.clip_grad_norm_(kwargs[\"model\"].parameters(), max_norm=0.95)\n", "\n", "class MetricsCallback(TrainerCallback):\n", " def __init__(self, tb_writer, tokenizer, metric, log_every_n_steps=30):\n", " super().__init__()\n", " self.tb_writer = tb_writer\n", " self.tokenizer = tokenizer\n", " self.metric = metric\n", " self.log_every_n_steps = log_every_n_steps\n", " self.predictions = None\n", " self.label_ids = None\n", "\n", " def compute_cer(self, pred_str, label_str):\n", " cer = 100 * self.metric.compute(predictions=pred_str, references=label_str)\n", " return cer\n", "\n", " def on_evaluate(self, args, state, control, metrics=None, **kwargs):\n", " if metrics is not None:\n", " for key, value in metrics.items():\n", " if key.startswith(\"eval_\"):\n", " self.tb_writer.add_scalar(key, value, state.global_step)\n", " print(f\"Step {state.global_step} - {key}: {value}\")\n", "\n", " if self.predictions is not None and self.label_ids is not None:\n", " pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)\n", " label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)\n", "\n", " sample_index = 1\n", " self.tb_writer.add_text(\"Prediction\", pred_str[sample_index], state.global_step)\n", " self.tb_writer.add_text(\"Label\", label_str[sample_index], state.global_step)\n", "\n", " print(f\"Step {state.global_step} - Sample Prediction: {pred_str[sample_index]}\")\n", " print(f\"Step {state.global_step} - Sample Label: {label_str[sample_index]}\")\n", "\n", " self.predictions = None\n", " self.label_ids = None\n", "\n", "def create_compute_metrics(callback_instance):\n", " def compute_metrics(eval_pred):\n", " pred_logits = eval_pred.predictions\n", " label_ids = eval_pred.label_ids\n", "\n", " if isinstance(pred_logits, tuple):\n", " pred_ids = pred_logits[0]\n", " else:\n", " pred_ids = pred_logits\n", " if pred_ids.ndim == 3:\n", " pred_ids = np.argmax(pred_ids, axis=-1)\n", "\n", " label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id\n", " callback_instance.predictions = pred_ids\n", " callback_instance.label_ids = label_ids\n", "\n", " pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", " cer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)\n", "\n", " pred_flat = pred_ids.flatten()\n", " labels_flat = label_ids.flatten()\n", " mask = labels_flat != callback_instance.tokenizer.pad_token_id\n", "\n", " accuracy = accuracy_score(labels_flat[mask], pred_flat[mask])\n", " precision = precision_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)\n", " recall = recall_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)\n", " f1 = f1_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)\n", "\n", " return {\n", " \"cer\": cer,\n", " \"accuracy\": accuracy,\n", " \"precision\": precision,\n", " \"recall\": recall,\n", " \"f1\": f1\n", " }\n", " return compute_metrics\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=log_dir,\n", " logging_dir=log_dir,\n", " overwrite_output_dir=True,\n", " per_device_train_batch_size=1, \n", " gradient_accumulation_steps=1,\n", " eval_accumulation_steps=1,\n", " num_train_epochs=1,\n", " tf32=True,\n", " bf16=True,\n", " max_steps=10000,\n", " save_steps=1000,\n", " eval_steps=20,\n", " 
eval_strategy=\"steps\",\n", " eval_on_start=False,\n", " warmup_steps=100,\n", " logging_steps=10,\n", " logging_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " report_to=[\"tensorboard\"],\n", " push_to_hub=False,\n", " remove_unused_columns=False,\n", " label_names=[\"labels\"],\n", " hub_private_repo=True,\n", " metric_for_best_model=\"cer\",\n", " greater_is_better=False,\n", " load_best_model_at_end=True,\n", " optim=\"adafactor\",\n", " weight_decay=0.00025,\n", " disable_tqdm=False,\n", " save_total_limit=2,\n", " use_cpu=False,\n", " torch_empty_cache_steps=10\n", ")\n", "\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "torch.backends.cudnn.allow_tf32 = True\n", "torch.cuda.empty_cache()\n", "torch.cuda.set_device(0)\n", "\n", "cer_metric = evaluate.load(\"cer\")\n", "tb_writer = SummaryWriter(log_dir)\n", "\n", "metrics_callback = MetricsCallback(tb_writer, tokenizer, cer_metric, log_every_n_steps=30)\n", "compute_metrics = create_compute_metrics(metrics_callback)\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=train,\n", " eval_dataset=test,\n", " data_collator=data_collator,\n", " tokenizer=processor.feature_extractor,\n", " compute_metrics=compute_metrics,\n", " callbacks=[metrics_callback]\n", ")\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.train()\n", "tb_writer.close()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = \"./models/echo2_4k\"\n", "model.save_pretrained(path)\n", "processor.save_pretrained(path)\n", "tokenizer.save_pretrained(path)\n", "feature_extractor.save_pretrained(path)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }