{
"cells": [
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Root of the local Clotho copy; set the CLOTHO_DIR environment variable to point elsewhere.\n",
"CLOTHO_DIR = os.environ.get(\"CLOTHO_DIR\", \"/data/jyk/aac_dataset/clotho\")\n",
"datapath = os.path.join(CLOTHO_DIR, \"validation\", \"01 A pug struggles to breathe 1_14_2008.wav\")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"from datasets import Audio"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"from encodec import EncodecModel\n",
"from encodec.utils import convert_audio\n",
"\n",
"import torchaudio\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# Load the 24 kHz EnCodec model (model.sample_rate / model.channels are printed further down).\n",
"model = EncodecModel.encodec_model_24khz()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14851810\n"
]
}
],
"source": [
"from utils import count_parameters\n",
"\n",
"# count_parameters is a project helper from utils; its exact counting rule is not shown here.\n",
"n_params = count_parameters(model)\n",
"print(n_params)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"# Target bandwidth in kbps; at 6.0 the encoder below produces 8 codebooks per frame.\n",
"TARGET_BANDWIDTH = 6.0\n",
"model.set_target_bandwidth(TARGET_BANDWIDTH)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18.8\n",
"44100\n",
"24000\n",
"1\n",
"18.8\n"
]
}
],
"source": [
"wav_raw, sr = torchaudio.load(datapath)\n",
"print(wav_raw.shape[-1] / sr)  # duration in seconds at the file's own sample rate\n",
"\n",
"# Convert to the model's sample rate and channel count, then add a leading batch dimension.\n",
"wav = convert_audio(wav_raw, sr, model.sample_rate, model.channels).unsqueeze(0)\n",
"for value in (sr, model.sample_rate, model.channels):\n",
"    print(value)\n",
"print(wav.shape[-1] / model.sample_rate)  # duration after conversion; expected to match the first print"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# Inference only: disable gradient tracking while encoding.\n",
"with torch.set_grad_enabled(False):\n",
"    encoded_frames = model.encode(wav)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(tensor([[[675, 798, 635, ..., 281, 169, 457],\n",
" [184, 740, 961, ..., 603, 831, 857],\n",
" [996, 832, 967, ..., 273, 599, 771],\n",
" ...,\n",
" [763, 611, 140, ..., 18, 95, 918],\n",
" [938, 862, 674, ..., 661, 193, 364],\n",
" [412, 326, 339, ..., 614, 424, 428]]]), None)]\n"
]
}
],
"source": [
"# A list of tuples; for this clip it prints as a single (codes tensor, None) entry.\n",
"print(f\"{encoded_frames}\")"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 8, 1410])\n"
]
}
],
"source": [
"# Unpack the first (codes, second element) tuple and show the code tensor's shape.\n",
"first_codes, first_extra = encoded_frames[0]\n",
"print(first_codes.shape)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Join the code tensors of all encoded frames along the last (time) axis -> [B, n_q, T].\n",
"codes = torch.cat([frame_codes for frame_codes, _ in encoded_frames], dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 8, 1410])\n"
]
}
],
"source": [
"# [B, n_q, T]; printed as torch.Size([1, 8, 1410]) for this clip.\n",
"print(codes.size())"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[184, 740, 961, ..., 603, 831, 857]])\n"
]
}
],
"source": [
"# All time steps of codebook index 1 (the second codebook).\n",
"print(codes[:, 1])"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0)\n"
]
}
],
"source": [
"# Smallest code index used by codebook 1.\n",
"# NOTE(review): only codebook 1 is inspected; codes.min() would check every codebook.\n",
"print(codes[:, 1].min())"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"# Shift every code index up by one (presumably to free index 0, e.g. for padding; TODO confirm).\n",
"code_1 = torch.add(codes, 1)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[185, 741, 962, ..., 604, 832, 858]])\n"
]
}
],
"source": [
"# Codebook index 1 again, after the +1 shift.\n",
"print(code_1[:, 1])"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"# Decode back to a waveform. Like the encode cell, no gradients are needed here,\n",
"# so avoid building an autograd graph over the whole clip.\n",
"with torch.no_grad():\n",
"    decoded_wav = model.decode(encoded_frames)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1, 451200])"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expected [1, 1, samples]; 451200 samples at 24 kHz is the same 18.8 s as the input.\n",
"decoded_wav.size()"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"# Convert the decoded [1, 1, samples] tensor into a 1-D numpy array for IPython playback.\n",
"# A single squeeze() already drops every size-1 dimension; .cpu() makes this work for GPU tensors too.\n",
"# The guard keeps the cell re-runnable: once decoded_wav is a numpy array there is nothing left to convert.\n",
"if torch.is_tensor(decoded_wav):\n",
"    decoded_wav = decoded_wav.squeeze().detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Audio"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Play the reconstruction at the model's own sample rate (24000 for this model), not a copied constant.\n",
"Audio(decoded_wav, rate=model.sample_rate)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
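"source": [
"## Summary\n",
"\n",
"EnCodec round-trip (24 kHz model, 6 kbps target bandwidth) on one Clotho validation clip:\n",
"\n",
"- The 18.8 s input, recorded at 44.1 kHz, is converted to 24 kHz mono. It is encoded into `codes` of shape `[1, 8, 1410]`, i.e. 8 codebooks at 75 frames per second.\n",
"- The smallest code index seen in codebook 1 is 0. `code_1 = codes + 1` shifts every index up by one.\n",
"- Decoding yields 451,200 samples (18.8 s at 24 kHz), which are played back for a listening check."
]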
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}