asigalov61 commited on
Commit
be361a6
·
verified ·
1 Parent(s): 7c965e1

Upload Score_2_Performance_Transformer_Eval_Colab.ipynb

Browse files
code/Score_2_Performance_Transformer_Eval_Colab.ipynb ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "VGrGd6__l5ch"
7
+ },
8
+ "source": [
9
+ "# Score 2 Performance Transformer Eval Colab (ver. 1.0)\n",
10
+ "\n",
11
+ "***\n",
12
+ "\n",
13
+ "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
14
+ "\n",
15
+ "***\n",
16
+ "\n",
17
+ "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n",
18
+ "\n",
19
+ "***\n",
20
+ "\n",
21
+ "#### Project Los Angeles\n",
22
+ "\n",
23
+ "#### Tegridy Code 2024\n",
24
+ "\n",
25
+ "***"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {
31
+ "id": "shLrgoXdl5cj"
32
+ },
33
+ "source": [
34
+ "# GPU check"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "X3rABEpKCO02"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "!nvidia-smi"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {
51
+ "id": "0RcVC4btl5ck"
52
+ },
53
+ "source": [
54
+ "# Setup environment"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {
61
+ "id": "viHgEaNACPTs"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {
72
+ "id": "vK40g6V_BTNj"
73
+ },
74
+ "outputs": [],
75
+ "source": [
76
+ "!sudo pip install torch\n",
77
+ "!sudo pip install einops\n",
78
+ "!sudo pip install torch-summary\n",
79
+ "!sudo pip install tqdm\n",
80
+ "!sudo pip install huggingface_hub\n",
81
+ "!sudo pip install hf-transfer\n",
82
+ "!sudo pip install ipywidgets"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "# Import modules"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {
96
+ "id": "DzCOZU_gBiQV"
97
+ },
98
+ "outputs": [],
99
+ "source": [
100
+ "# Load modules and make data dir\n",
101
+ "\n",
102
+ "print('Loading modules...')\n",
103
+ "\n",
104
+ "import os\n",
105
+ "import pickle\n",
106
+ "import random\n",
107
+ "import secrets\n",
108
+ "import tqdm\n",
109
+ "import math\n",
110
+ "\n",
111
+ "!set USE_FLASH_ATTENTION=1\n",
112
+ "os.environ['USE_FLASH_ATTENTION'] = '1'\n",
113
+ "os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'\n",
114
+ "\n",
115
+ "import torch\n",
116
+ "\n",
117
+ "import matplotlib.pyplot as plt\n",
118
+ "\n",
119
+ "from torchsummary import summary\n",
120
+ "\n",
121
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n",
122
+ "\n",
123
+ "import TMIDIX\n",
124
+ "\n",
125
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n",
126
+ "\n",
127
+ "from x_transformer_1_23_2 import *\n",
128
+ "\n",
129
+ "torch.set_float32_matmul_precision('high')\n",
130
+ "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
131
+ "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
132
+ "torch.backends.cuda.enable_flash_sdp(True)\n",
133
+ "torch.backends.cuda.enable_cudnn_sdp(False)\n",
134
+ "\n",
135
+ "!set USE_FLASH_ATTENTION=1\n",
136
+ "\n",
137
+ "%cd /home/ubuntu/\n",
138
+ "\n",
139
+ "if not os.path.exists('/home/ubuntu/INTS'):\n",
140
+ " os.makedirs('/home/ubuntu/INTS')\n",
141
+ "\n",
142
+ "import random\n",
143
+ "\n",
144
+ "from huggingface_hub import hf_hub_download\n",
145
+ "\n",
146
+ "print('Done')\n",
147
+ "\n",
148
+ "print('Torch version:', torch.__version__)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "metadata": {},
154
+ "source": [
155
+ "# Download model"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {
162
+ "id": "SA8qQSzbWslM"
163
+ },
164
+ "outputs": [],
165
+ "source": [
166
+ "hf_hub_download(repo_id='asigalov61/Score-2-Performance-Transformer',\n",
167
+ " filename='Score_2_Performance_Transformer_Small_Trained_Model_5280_steps_1.5374_loss_0.5525_acc.pth',\n",
168
+ " local_dir='/home/ubuntu/Model/',\n",
169
+ " )"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "# Load model"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {
183
+ "id": "gSvqSRLaWslM"
184
+ },
185
+ "outputs": [],
186
+ "source": [
187
+ "SEQ_LEN = 1802\n",
188
+ "PAD_IDX = 771\n",
189
+ "\n",
190
+ "model = TransformerWrapper(\n",
191
+ " num_tokens = PAD_IDX+1,\n",
192
+ " max_seq_len = SEQ_LEN,\n",
193
+ " attn_layers = Decoder(dim = 1024,\n",
194
+ " depth = 8,\n",
195
+ " heads = 8,\n",
196
+ " rotary_pos_emb = True,\n",
197
+ " attn_flash = True\n",
198
+ " )\n",
199
+ " )\n",
200
+ "\n",
201
+ "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
202
+ "\n",
203
+ "print('=' * 70)\n",
204
+ "print('Loading model checkpoint...')\n",
205
+ "\n",
206
+ "model_path = '/home/ubuntu/Model/Score_2_Performance_Transformer_Small_Trained_Model_5280_steps_1.5374_loss_0.5525_acc.pth'\n",
207
+ "\n",
208
+ "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
209
+ "\n",
210
+ "print('=' * 70)\n",
211
+ "\n",
212
+ "model = torch.compile(model, mode='max-autotune')\n",
213
+ "\n",
214
+ "model.cuda()\n",
215
+ "model.eval()\n",
216
+ "\n",
217
+ "print('Done!')\n",
218
+ "\n",
219
+ "summary(model)\n",
220
+ "\n",
221
+ "dtype = torch.bfloat16\n",
222
+ "\n",
223
+ "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "metadata": {
229
+ "id": "feXay_Ed7mG5"
230
+ },
231
+ "source": [
232
+ "# Eval"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "markdown",
237
+ "metadata": {},
238
+ "source": [
239
+ "## Load source MIDI composition"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "metadata": {
246
+ "id": "enHpaHxaWslM"
247
+ },
248
+ "outputs": [],
249
+ "source": [
250
+ "#=================================================================\n",
251
+ "\n",
252
+ "# This can be a score or performance\n",
253
+ "# MIDI will be converted to solo Piano without drums\n",
254
+ "\n",
255
+ "# PLEASE NOTE THAT the MIDI composition MUST HAVE at least 300 notes for this demo to work properly!\n",
256
+ "\n",
257
+ "#=================================================================\n",
258
+ "\n",
259
+ "midi_file = '/home/ubuntu/tegridy-tools/tegridy-tools/seed2.mid'\n",
260
+ "# midi_file = 'midi_score.mid'\n",
261
+ "\n",
262
+ "#=================================================================\n",
263
+ "\n",
264
+ "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n",
265
+ "\n",
266
+ "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)\n",
267
+ "\n",
268
+ "if escore_notes[0]:\n",
269
+ "\n",
270
+ " escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)\n",
271
+ "\n",
272
+ " pe = escore_notes[0]\n",
273
+ "\n",
274
+ " melody_chords = []\n",
275
+ "\n",
276
+ " seen = []\n",
277
+ "\n",
278
+ " for e in escore_notes:\n",
279
+ "\n",
280
+ " if e[3] != 9:\n",
281
+ " \n",
282
+ " #=======================================================\n",
283
+ " \n",
284
+ " dtime = max(0, min(255, e[1]-pe[1]))\n",
285
+ " \n",
286
+ " if dtime != 0:\n",
287
+ " seen = []\n",
288
+ " \n",
289
+ " # Durations\n",
290
+ " dur = max(1, min(255, e[2]))\n",
291
+ " \n",
292
+ " # Pitches\n",
293
+ " ptc = max(1, min(127, e[4]))\n",
294
+ " \n",
295
+ " vel = max(1, min(127, e[5]))\n",
296
+ " \n",
297
+ " if ptc not in seen:\n",
298
+ " \n",
299
+ " melody_chords.append([dtime, dur, ptc, vel])\n",
300
+ " \n",
301
+ " seen.append(ptc)\n",
302
+ " \n",
303
+ " pe = e\n",
304
+ "\n",
305
+ "print('=' * 70)\n",
306
+ "print('Number of notes in a composition:', len(melody_chords))\n",
307
+ "print('=' * 70)\n",
308
+ "\n",
309
+ "src_melody_chords_f = []\n",
310
+ "melody_chords_f = []\n",
311
+ "\n",
312
+ "for i in range(0, len(melody_chords), 300):\n",
313
+ " \n",
314
+ " chunk = melody_chords[i:i+300]\n",
315
+ " \n",
316
+ " src = []\n",
317
+ " src1 = []\n",
318
+ " trg = []\n",
319
+ " \n",
320
+ " if len(chunk) == 300:\n",
321
+ "\n",
322
+ " for mm in chunk:\n",
323
+ " src.extend([mm[0], mm[2]+256])\n",
324
+ " src1.append([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])\n",
325
+ " trg.extend([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])\n",
326
+ "\n",
327
+ " src_melody_chords_f.append(src1)\n",
328
+ " melody_chords_f.append([768] + src + [769] + trg + [770])\n",
329
+ " \n",
330
+ "print('Done!')\n",
331
+ "print('=' * 70)\n",
332
+ "print('Number of composition chunks:', len(melody_chords_f))\n",
333
+ "print('=' * 70)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "markdown",
338
+ "metadata": {},
339
+ "source": [
340
+ "# Generate new durations and velocities"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": null,
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "model.eval()\n",
350
+ "\n",
351
+ "#================================================================\n",
352
+ "\n",
353
+ "composition_chunk_idx = 0 # Composition chunk idx to generate durations and velocities for. Each chunk is 300 notes\n",
354
+ "\n",
355
+ "num_prime_notes = 4 # Priming improves the results but it is not necessary and you can set it to zero\n",
356
+ "dur_top_k = 2 # Use k == 1 if src composition is score and k > 1 if src composition is performance\n",
357
+ "\n",
358
+ "dur_temperature = 1.3 # For best results, durations temperature should be more than 1.0 but less than velocities temperature\n",
359
+ "vel_temperature = 1.5 # For best results, velocities temperature must be larger than 1.3 and larger than durations temperature\n",
360
+ "\n",
361
+ "#================================================================\n",
362
+ "\n",
363
+ "song_chunk = src_melody_chords_f[composition_chunk_idx]\n",
364
+ "\n",
365
+ "song = [768]\n",
366
+ "\n",
367
+ "for m in song_chunk:\n",
368
+ " song.extend(m[:2])\n",
369
+ "\n",
370
+ "song.append(769)\n",
371
+ "\n",
372
+ "for i in tqdm.tqdm(range(len(song_chunk))):\n",
373
+ "\n",
374
+ " song.extend(song_chunk[i][:2])\n",
375
+ "\n",
376
+ " # Durations\n",
377
+ "\n",
378
+ " if i < num_prime_notes:\n",
379
+ " song.append(song_chunk[i][2])\n",
380
+ "\n",
381
+ " else:\n",
382
+ "\n",
383
+ " x = torch.LongTensor(song).cuda()\n",
384
+ "\n",
385
+ " y = 0 \n",
386
+ "\n",
387
+ " while not 384 < y < 640:\n",
388
+ " \n",
389
+ " with ctx:\n",
390
+ " out = model.generate(x,\n",
391
+ " 1,\n",
392
+ " temperature=dur_temperature,\n",
393
+ " filter_logits_fn=top_k,\n",
394
+ " filter_kwargs={'k': dur_top_k},\n",
395
+ " return_prime=False,\n",
396
+ " verbose=False)\n",
397
+ " \n",
398
+ " y = out.tolist()[0][0]\n",
399
+ " \n",
400
+ " song.append(y)\n",
401
+ "\n",
402
+ "\n",
403
+ " # Velocities\n",
404
+ " \n",
405
+ " if i < num_prime_notes:\n",
406
+ " song.append(song_chunk[i][3])\n",
407
+ "\n",
408
+ " else:\n",
409
+ "\n",
410
+ " x = torch.LongTensor(song).cuda()\n",
411
+ " \n",
412
+ " y = 0 \n",
413
+ "\n",
414
+ " while not 640 < y < 768:\n",
415
+ " \n",
416
+ " with ctx:\n",
417
+ " out = model.generate(x,\n",
418
+ " 1,\n",
419
+ " temperature=vel_temperature,\n",
420
+ " #filter_logits_fn=top_k,\n",
421
+ " #filter_kwargs={'k': 10},\n",
422
+ " return_prime=False,\n",
423
+ " verbose=False)\n",
424
+ " \n",
425
+ " y = out.tolist()[0][0]\n",
426
+ " \n",
427
+ " song.append(y)\n",
428
+ "\n",
429
+ "\n",
430
+ "print('---------------')\n",
431
+ "\n",
432
+ "#===========================================================================\n",
433
+ "# Convert model output to MIDI\n",
434
+ "#===========================================================================\n",
435
+ "\n",
436
+ "song1 = song[602:]\n",
437
+ "\n",
438
+ "print('Sample INTs', song1[:15])\n",
439
+ "\n",
440
+ "song_f = []\n",
441
+ "\n",
442
+ "time = 0\n",
443
+ "dur = 0\n",
444
+ "vel = 90\n",
445
+ "pitch = 60\n",
446
+ "channel = 0\n",
447
+ "patch = 0\n",
448
+ "\n",
449
+ "patches = [0] * 16\n",
450
+ "\n",
451
+ "for ss in song1:\n",
452
+ "\n",
453
+ " if 0 <= ss < 256:\n",
454
+ "\n",
455
+ " time += ss * 16\n",
456
+ "\n",
457
+ " if 256 <= ss < 384:\n",
458
+ "\n",
459
+ " pitch = ss-256\n",
460
+ "\n",
461
+ " if 384 <= ss < 640:\n",
462
+ "\n",
463
+ " dur = (ss-384) * 16\n",
464
+ "\n",
465
+ " if 640 <= ss < 768:\n",
466
+ " \n",
467
+ " vel = (ss-640)\n",
468
+ " \n",
469
+ " song_f.append(['note', time, dur, channel, pitch, vel, patch])\n",
470
+ "\n",
471
+ "detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
472
+ " output_signature = 'Score 2 Performance Transformer', \n",
473
+ " output_file_name = '/home/ubuntu/Score-2-Performance-Transformer-Music-Composition', \n",
474
+ " track_name='Project Los Angeles',\n",
475
+ " list_of_MIDI_patches=patches\n",
476
+ " )\n",
477
+ "\n",
478
+ "print('Done!')"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "markdown",
483
+ "metadata": {
484
+ "id": "z87TlDTVl5cp"
485
+ },
486
+ "source": [
487
+ "# Congrats! You did it! :)"
488
+ ]
489
+ }
490
+ ],
491
+ "metadata": {
492
+ "accelerator": "GPU",
493
+ "colab": {
494
+ "gpuClass": "premium",
495
+ "gpuType": "T4",
496
+ "private_outputs": true,
497
+ "provenance": []
498
+ },
499
+ "kernelspec": {
500
+ "display_name": "Python 3 (ipykernel)",
501
+ "language": "python",
502
+ "name": "python3"
503
+ },
504
+ "language_info": {
505
+ "codemirror_mode": {
506
+ "name": "ipython",
507
+ "version": 3
508
+ },
509
+ "file_extension": ".py",
510
+ "mimetype": "text/x-python",
511
+ "name": "python",
512
+ "nbconvert_exporter": "python",
513
+ "pygments_lexer": "ipython3",
514
+ "version": "3.10.12"
515
+ }
516
+ },
517
+ "nbformat": 4,
518
+ "nbformat_minor": 4
519
+ }