Update model3.py
fixed issue with hf trainer / evaluate
model3.py
CHANGED
@@ -1,1055 +1,938 @@
- … (old file lines 1–938: imports, device setup, the custom attention/encoder/decoder classes, and the start of the metrics-logging callback; the removed text is truncated in this capture)
-            print(f"Step {state.global_step} - Sample Prediction: {pred_str[sample_index]}")
-            print(f"Step {state.global_step} - Sample Label: {label_str[sample_index]}")
-
-        self.predictions = None
-        self.label_ids = None
-
-def create_compute_metrics(callback_instance):
-    def compute_metrics(eval_pred):
-        pred_logits = eval_pred.predictions
-        label_ids = eval_pred.label_ids
-
-        if isinstance(pred_logits, tuple):
-            pred_ids = pred_logits[0]
-        else:
-            pred_ids = pred_logits
-        if pred_ids.ndim == 3:
-            pred_ids = np.argmax(pred_ids, axis=-1)
-
-        label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id
-        callback_instance.predictions = pred_ids
-        callback_instance.label_ids = label_ids
-
-        pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
-        label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
-        cer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)
-
-        pred_flat = pred_ids.flatten()
-        labels_flat = label_ids.flatten()
-        mask = labels_flat != callback_instance.tokenizer.pad_token_id
-
-        accuracy = accuracy_score(labels_flat[mask], pred_flat[mask])
-        precision = precision_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
-        recall = recall_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
-        f1 = f1_score(labels_flat[mask], pred_flat[mask], average='weighted', zero_division=0)
-
-        return {
-            "cer": cer,
-            "accuracy": accuracy,
-            "precision": precision,
-            "recall": recall,
-            "f1": f1
-        }
-    return compute_metrics
-
-training_args = Seq2SeqTrainingArguments(
-    output_dir=log_dir,
-    logging_dir=log_dir,
-    overwrite_output_dir=True,
-    per_device_train_batch_size=1,
-    gradient_accumulation_steps=1,
-    eval_accumulation_steps=1,
-    num_train_epochs=1,
-    tf32=True,
-    bf16=True,
-    max_steps=10000,
-    save_steps=1000,
-    eval_steps=20,
-    eval_strategy="steps",
-    eval_on_start=False,
-    warmup_steps=100,
-    logging_steps=10,
-    logging_strategy="steps",
-    save_strategy="steps",
-    report_to=["tensorboard"],
-    push_to_hub=False,
-    remove_unused_columns=False,
-    label_names=["labels"],
-    hub_private_repo=True,
-    metric_for_best_model="cer",
-    greater_is_better=False,
-    load_best_model_at_end=True,
-    optim="adafactor",
-    weight_decay=0.00025,
-    disable_tqdm=False,
-    save_total_limit=2,
-    use_cpu=False,
-    torch_empty_cache_steps=10
-
-)
-
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.cuda.empty_cache()
-torch.cuda.set_device(0)
-
-cer_metric = evaluate.load("cer")
-tb_writer = SummaryWriter(log_dir)
-
-metrics_callback = MetricsCallback(tb_writer, tokenizer, cer_metric, log_every_n_steps=30)
-compute_metrics = create_compute_metrics(metrics_callback)
-
-trainer = Seq2SeqTrainer(
-    args=training_args,
-    model=model,
-    train_dataset=train,
-    eval_dataset=test,
-    data_collator=data_collator,
-    tokenizer=processor.feature_extractor,
-    compute_metrics=compute_metrics,
-    callbacks=[metrics_callback]
-)
-
-
-
-
-trainer.train(resume_from_checkpoint=True)
-tb_writer.close()
-from torch.utils.tensorboard import SummaryWriter
-
-
-path = "./models/echo2_4k"
-model.save_pretrained(path)
-processor.save_pretrained(path)
-tokenizer.save_pretrained(path)
-feature_extractor.save_pretrained(path)
-
-
+
+import base64, gzip, torch, evaluate, math, os, sys, time
+import gzip
+from torch import amp, Tensor, optim
+from torch.utils.checkpoint import checkpoint
+from contextlib import contextmanager
+from dataclasses import dataclass
+from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel
+from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+from transformers.optimization import Adafactor, AdafactorSchedule
+from huggingface_hub import PyTorchModelHubMixin
+from datasets import IterableDatasetDict, Audio, load_dataset
+import numpy as np
+import torch, transformers, warnings
+from typing import Dict, Iterable, Optional, Tuple, Union, List, Any, Type
+import torch.nn.functional as F
+from torch import Tensor, nn
+import torchaudio, torchaudio.transforms as T
+from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, WhisperTokenizer, WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration
+from whisper.decoding import decode as decode_function
+from whisper.decoding import detect_language as detect_language_function
+from whisper.transcribe import transcribe as transcribe_function
+
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
+transformers.utils.logging.set_verbosity_error()
+warnings.filterwarnings(action="ignore")
+warnings.warn = lambda *args,**kwargs: None
+device = "cuda"
+
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, num_features, eps=1e-6):
+        super(LayerNorm, self).__init__()
+        self.gamma = nn.Parameter(torch.ones(num_features))
+        self.beta = nn.Parameter(torch.zeros(num_features))
+        self.eps = eps
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        std = x.std(dim=-1, keepdim=True)
+        x = (x - mean) / (std + self.eps)
+        return self.gamma * x + self.beta
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class Linear(nn.Module):
+    def __init__(self, in_features: int, out_features: int, dropout_rate = 0.01, use_batchnorm: bool = True, activation: str = 'relu'):
+        super(Linear, self).__init__()
+        self.linear = nn.Linear(in_features, out_features)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.use_batchnorm = use_batchnorm
+        self.activation = activation
+
+        if self.use_batchnorm:
+            self.batchnorm = nn.BatchNorm1d(out_features)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)
+        if self.linear.bias is not None:
+            nn.init.zeros_(self.linear.bias)
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.size()
+        x = x.view(-1, x.size(-1))
+        x = self.linear(x)
+
+        if self.use_batchnorm:
+            x = self.batchnorm(x)
+
+        x = self.apply_activation(x)
+        x = self.dropout(x)
+        x = x.view(batch_size, seq_len, -1)
+
+        return x
+
+    def apply_activation(self, x):
+        if self.activation == 'relu':
+            return F.relu(x)
+        elif self.activation == 'tanh':
+            return torch.tanh(x)
+        elif self.activation == 'sigmoid':
+            return torch.sigmoid(x)
+        else:
+            raise ValueError(f'Unsupported activation function: {self.activation}')
+
+class Conv1d(nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def _conv_forward(self, x, weight, bias) -> Tensor:
+        weight = self.weight.to(x.dtype)
+        bias = None if self.bias is None else self.bias.to(x.dtype)
+        return super()._conv_forward(x, weight, bias)
+
+class BiasedCrossAttention(nn.Module):
+    def __init__(self, n_state, n_head, dropout_rate=0.1):
+        super().__init__()
+        self.n_head = n_head
+        self.n_state = n_state
+        self.head_dim = n_state // n_head
+
+        self.query = nn.Linear(n_state, n_state)
+        self.key = nn.Linear(n_state, n_state, bias=False)
+        self.value = nn.Linear(n_state, n_state)
+        self.out = nn.Linear(n_state, n_state)
+
+        self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))
+        self.dropout = nn.Dropout(dropout_rate)
+        self.norm = LayerNorm(n_state)
+
+    def forward(self, q, k, v, mask=None):
+        batch_size, seq_length, _ = q.size()
+
+        q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)
+        k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)
+        v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)
+
+        qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias
+        if mask is not None:
+            qk = qk.masked_fill(mask == 0, float('-inf'))
+
+        w = F.softmax(qk, dim=-1)
+        w = self.dropout(w)
+
+        out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
+        out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))
+        return out
+
+class DynamicConvAttention(nn.Module):
+    def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.1):
+        super().__init__()
+        self.n_state = n_state
+        self.n_head = n_head
+        self.kernel_size = kernel_size
+
+        self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self.query = nn.Linear(n_state, n_state)
+        self.key = nn.Linear(n_state, n_state, bias=False)
+        self.value = nn.Linear(n_state, n_state)
+        self.out_proj = nn.Linear(n_state, n_state)
+
+        self.norm = LayerNorm(n_state)
+
+    def forward(self, x):
+        batch_size, seq_len, embed_dim = x.size()
+        if embed_dim != self.n_state:
+            raise ValueError(f"Expected embed_dim of {self.n_state}, but got {embed_dim}")
+
+        q = self.query(x)
+        k = self.key(x)
+        v = self.value(x)
+
+        x = x.permute(0, 2, 1)
+        conv_out = self.conv(x)
+        conv_out = conv_out.permute(0, 2, 1)
+        conv_out = self.norm(conv_out)
+        conv_out = self.dropout(conv_out)
+
+        attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)
+        attention_out = torch.matmul(attention_out, v)
+
+        combined_out = conv_out + attention_out
+        combined_out = self.norm(combined_out)
+
+        return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)
+
+class HybridAttention(nn.Module):
+    def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.1):
+        super().__init__()
+        self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
+        self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
+        self.ln_local = LayerNorm(n_state)
+        self.ln_global = LayerNorm(n_state)
+
+        self.dropout = nn.Dropout(dropout_rate)
+        self.window_size = window_size
+
+    def forward(self, x):
+        x_local = self.ln_local(x)
+        x_global = self.ln_global(x)
+        x_local = x_local.permute(1, 0, 2)
+        x_global = x_global.permute(1, 0, 2)
+        local_out = self.sliding_window_attention(x_local)
+        global_out, _ = self.global_attn(x_global, x_global, x_global)
+        combined_out = local_out + global_out
+        combined_out = combined_out.permute(1, 0, 2)
+        return self.dropout(combined_out)
+
+    def sliding_window_attention(self, x):
+        seq_len, batch_size, n_state = x.size()
+        window_size = min(self.window_size, max(1, seq_len // 4))
+        output = torch.zeros_like(x, device=x.device, dtype=x.dtype)
+
+        for i in range(0, seq_len, window_size):
+            end = min(i + window_size, seq_len)
+            query = x[i:end, :, :]
+            start = max(0, i - window_size)
+            key = x[start:end, :, :]
+            value = x[start:end, :, :]
+            attn_output, _ = self.local_attn(query, key, value)
+            output[i:end, :, :] = attn_output[:end - i, :, :]
+
+        return output
+
+def givens_rotation_matrix(n_state, i, j, theta):
+    G = torch.eye(n_state)
+    G[i, i] = math.cos(theta)
+    G[i, j] = -math.sin(theta)
+    G[j, i] = math.sin(theta)
+    G[j, j] = math.cos(theta)
+    return G
+
+class GivensRotations(nn.Module):
+    def __init__(self, h_dim, num_rotations):
+        super().__init__()
+        self.h_dim = h_dim
+        self.num_rotations = num_rotations
+        self.thetas = nn.Parameter(torch.zeros(num_rotations))
+
+    def forward(self, x):
+        if x.dim() != 4:
+            raise ValueError(f"Expected input tensor to be 4D, but got {x.dim()}D")
+
+        batch_size, seq_len, n_head, h_dim = x.size()
+
+        if h_dim != self.h_dim:
+            raise ValueError(f"Expected h_dim of {self.h_dim}, but got {h_dim}")
+
+        x = x.view(-1, h_dim)
+        for k in range(self.num_rotations):
+            i, j = k % self.h_dim, (k + 1) % self.h_dim
+            G = givens_rotation_matrix(self.h_dim, i, j, self.thetas[k])
+            x = torch.matmul(x, G.to(x.device))
+
+        x = x.view(batch_size, seq_len, n_head, h_dim)
+        return x
+
+class RotaryEmbeddingWithRotation(nn.Module):
+    def __init__(self, n_state, n_head, base=10000, checkpointing=False):
+        super().__init__()
+        self.n_state = n_state
+        self.n_head = n_head
+        self.h_dim = n_state // n_head
+        self.base = base  # Initialize base
+        self.checkpointing = checkpointing
+
+        self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def update_base(self, new_base):
+        self.base = new_base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def reset_parameters(self):
+        nn.init.orthogonal_(self.rotation_matrix)
+
+    def forward(self, x):
+        if self.checkpointing:
+            return checkpoint(self._forward, x)
+        else:
+            return self._forward(x)
+
+    def _forward(self, x):
+        if x.dim() == 3:
+            batch_size, seq_len, n_state = x.size()
+        elif x.dim() == 4:
+            batch_size, seq_len, n_head, h_dim = x.size()
+            n_state = n_head * h_dim
+            x = x.view(batch_size, seq_len, n_state)
+        else:
+            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
+
+        if n_state != self.n_state:
+            raise ValueError(f"Expected n_state of {self.n_state}, but got {n_state}")
+
+        x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
+        x = x.reshape(-1, self.h_dim)
+        rotated_x = torch.matmul(x, self.rotation_matrix)
+        rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
+
+        sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))
+        sin = sinusoid_inp.sin()[None, :, None, :]
+        cos = sinusoid_inp.cos()[None, :, None, :]
+        x1, x2 = rotated_x[..., ::2], rotated_x[..., 1::2]
+        rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+
+        rotated_x = rotated_x.reshape(batch_size, seq_len, self.n_state)
+        return rotated_x
+
+class LearnedSinusoidalEmbeddings(nn.Module):
+    def __init__(self, n_ctx, n_state, checkpointing=False):
+        super().__init__()
+        self.n_ctx = n_ctx
+        self.n_state = n_state
+        self.checkpointing = checkpointing
+
+        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
+        features = torch.zeros(n_ctx, n_state)
+        features[:, 0::2] = torch.sin(position * div_term)
+        features[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('sinusoidal_features', features)
+
+        self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())
+
+    def forward(self, positions):
+        if self.checkpointing:
+            position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)
+        else:
+            position_embeddings = self.positional_embeddings[positions]
+
+        position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)
+        return position_embeddings
+
+class MultiHeadAttention(nn.Module):
+    use_sdpa = True
+
+    def __init__(self, n_state: int, n_head: int, base: int = 10000, max_rel_dist: int = 1):
+        super().__init__()
+        assert n_state % n_head == 0, "n_state must be divisible by n_head"
+        self.n_head = n_head
+        self.h_dim = n_state // n_head
+        assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"
+
+        self.positional_scaling = nn.Parameter(torch.ones(1))
+
+        self.query = nn.Linear(n_state, n_state)
+        self.key = nn.Linear(n_state, n_state, bias=False)
+        self.value = nn.Linear(n_state, n_state)
+        self.out = nn.Linear(n_state, n_state)
+
+        self.max_rel_dist = max_rel_dist
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
+
+        self.rotation_matrix = nn.Parameter(torch.empty(self.h_dim, self.h_dim))
+        nn.init.orthogonal_(self.rotation_matrix)
+
+        self.givens_rotations = GivensRotations(self.h_dim, num_rotations=self.h_dim // 2)
+
+        self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)
+        self.rel_pos_bias.weight.data.fill_(0)
+
+        if device:
+            self.to(device)
+
+    def update_base(self, new_base):
+        self.base = new_base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))
+        self.register_buffer('inv_freq', inv_freq)
+        self.rotary_embedding.update_base(new_base)
+
+    def apply_rotary_embedding(self, x: torch.Tensor) -> torch.Tensor:
+        seq_len = x.shape[1]
+        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        scaled_positions = self.positional_scaling * positions
+        sinusoid_inp = torch.outer(scaled_positions, self.inv_freq.to(x.device))
+        sin = sinusoid_inp.sin()[None, :, None, :]
+        cos = sinusoid_inp.cos()[None, :, None, :]
+
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        x_rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+        return x_rotated
+
+    def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
+        q = self.query(x)
+
+        if kv_cache is None or xa is None or 'k' not in kv_cache:
+            k_input = x if xa is None else xa
+            k = self.key(k_input)
+            v = self.value(k_input)
+            if kv_cache is not None:
+                kv_cache['k'] = k
+                kv_cache['v'] = v
+        else:
+            k = kv_cache['k']
+            v = kv_cache['v']
+
+        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
+        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)
+        v = v.view(v.shape[0], v.shape[1], self.n_head, -1)
+
+        q = self.apply_rotary_embedding(q)
+        k = self.apply_rotary_embedding(k)
+
+        q = torch.matmul(q, self.rotation_matrix)
+        k = torch.matmul(k, self.rotation_matrix)
+
+        q = self.givens_rotations(q)
+        k = self.givens_rotations(k)
+
+        q = q.view(q.shape[0], q.shape[1], -1)
+        k = k.view(k.shape[0], k.shape[1], -1)
+
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
+
+    def qkv_attention(self, q, k, v, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        n_batch, n_ctx, n_state = q.shape
+
+        scale = (n_state // self.n_head) ** -0.25
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+        qk = (q * scale) @ (k * scale).transpose(-1, -2)
+
+        seq_len_q = q.size(2)
+        seq_len_k = k.size(2)
+
+        positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)
+        positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1
+        rel_bias = self.rel_pos_bias(positions)
+        rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0)
+
+        qk = qk + rel_bias
+
+        if mask is not None:
+            qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
+
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        qk = qk.detach()
+
+        return out, qk
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, max_rel_dist = 1, checkpointing=False):
+        super().__init__()
+
+        self.attn = MultiHeadAttention(n_state, n_head)
+        self.attn_ln = LayerNorm(n_state)
+        self.checkpointing = checkpointing
+        self.max_rel_dist = max_rel_dist
+
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
+        self.mlp_ln = LayerNorm(n_state)
+
+    def forward(self, x, xa: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[dict] = None):
+        if self.checkpointing:
+            x = checkpoint(self._attn_forward, x, mask, kv_cache)
+        else:
+            x = self._attn_forward(x, mask, kv_cache)
+
+        if self.cross_attn:
+            if self.checkpointing:
+                x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)
+            else:
+                x = self._cross_attn_forward(x, xa, kv_cache)
+
+        if self.checkpointing:
+            x = checkpoint(self._mlp_forward, x)
+        else:
+            x = self._mlp_forward(x)
+
+        return x
+
+    def _attn_forward(self, x, mask, kv_cache):
+        residual = x
+        x = self.attn_ln(x)
+        x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]
+        return x
+
+    def _cross_attn_forward(self, x, xa, kv_cache):
+        residual = x
+        x = self.cross_attn_ln(x)
+        x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]
+        return x
+
+    def _mlp_forward(self, x):
+        residual = x
+        x = self.mlp_ln(x)
+        x = residual + self.mlp(x)
+        return x
+
+class AudioEncoder(nn.Module):
+    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist, checkpointing=False):
+        super().__init__()
+        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
+        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
+        self.checkpointing = checkpointing
+
+        self.blocks = nn.ModuleList(
+            [ResidualAttentionBlock(n_state, n_head, max_rel_dist=max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]
+        )
+        self.ln_post = LayerNorm(n_state)
+
+    def update_base(self, new_base):
+        self.rotary_embedding.update_base(new_base)
+        for block in self.blocks:
+            if isinstance(block.attn, MultiHeadAttention):
+                block.attn.update_base(new_base)
+            if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
+                block.cross_attn.update_base(new_base)
+
+    def forward(self, x):
+        if self.checkpointing:
+            x = checkpoint(self._conv_forward, x)
+        else:
+            x = self._conv_forward(x)
+
+        for block in self.blocks:
+            if self.checkpointing:
+                x = checkpoint(block, x)
+            else:
+                x = block(x)
+
+        x = self.ln_post(x)
+        return x
+
+    def _conv_forward(self, x):
+        x = F.gelu(self.conv1(x))
+        x = F.gelu(self.conv2(x))
+        x = x.permute(0, 2, 1)
+        x = self.rotary_embedding(x)
+
+        pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)
+        x = x + pos_emb
+        return x
+
+class TextDecoder(nn.Module):
+    def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist, cross_attention, checkpointing=False):
+        super().__init__()
+        self.token_embedding = nn.Embedding(vocab_size, n_state)
+        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, checkpointing=checkpointing)
+        self.rotary_embedding = RotaryEmbeddingWithRotation(n_state, n_head, base=10000)
+        self.checkpointing = checkpointing
+        self.n_head = n_head
+
+        self.blocks = nn.ModuleList([
+            ResidualAttentionBlock(n_state, n_head, cross_attention=cross_attention, max_rel_dist=max_rel_dist, checkpointing=checkpointing)
+            for _ in range(n_layer)
+        ])
+        self.ln = LayerNorm(n_state)
+        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+        self.register_buffer("mask", mask, persistent=False)
+
+    def update_base(self, new_base):
+        self.rotary_embedding.update_base(new_base)
+        for block in self.blocks:
+            if isinstance(block.attn, MultiHeadAttention):
+                block.attn.update_base(new_base)
+            if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention):
+                block.cross_attn.update_base(new_base)
+
+    def forward(self, x, xa, kv_cache: Optional[dict] = None):
+        if self.checkpointing:
+            x = checkpoint(self._embedding_forward, x, xa, kv_cache)
+        else:
+            x = self._embedding_forward(x, xa, kv_cache)
+
+        for block in self.blocks:
+            if self.checkpointing:
+                x = checkpoint(block, x, xa, self.mask, kv_cache)
+            else:
+                x = block(x, xa, self.mask, kv_cache)
+
+        x = self.ln(x)
+        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+
+        return logits
+
+    def _embedding_forward(self, x, xa, kv_cache):
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        positions = torch.arange(x.shape[1], device=x.device) + offset
+        pos_emb = self.positional_embedding(positions).unsqueeze(0)
+
+        x = self.token_embedding(x) + pos_emb
+        x = x.to(xa.dtype)
+
+        batch_size, seq_length, embedding_dim = x.shape
+        num_heads = self.n_head
+        head_dim = embedding_dim // num_heads
+        x = x.view(batch_size, seq_length, num_heads, head_dim)
+
+        x = self.rotary_embedding(x)
+        x = x.view(batch_size, seq_length, embedding_dim)
+        return x
+
+class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):
+    config_class = WhisperConfig
+
+    def __init__(self, config: WhisperConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.n_mels = self.config.num_mel_bins
+        self.n_audio_ctx = self.config.max_source_positions
+        self.n_audio_state = self.config.d_model
+        self.n_audio_head = self.config.encoder_attention_heads
+        self.n_audio_layer = self.config.encoder_layers
+        self.vocab_size = self.config.vocab_size
+        self.n_text_ctx = self.config.max_target_positions
+        self.n_text_state = self.config.d_model
+        self.n_text_head = self.config.decoder_attention_heads
+        self.n_text_layer = self.config.decoder_layers
+        self.max_rel_dist = self.config.max_rel_dist
+        self.checkpointing = self.config.checkpointing
+        self.base = self.config.base
+
+        self.encoder = AudioEncoder(
+            self.config.n_mels,
+            self.config.n_audio_ctx,
+            self.config.n_audio_state,
+            self.config.n_audio_head,
+            self.config.n_audio_layer,
+            max_rel_dist=self.config.max_rel_dist,
+            checkpointing=self.config.checkpointing
+        )
+        self.decoder = TextDecoder(
+            self.config.vocab_size,
+            self.config.n_text_ctx,
+            self.config.n_text_state,
+            self.config.n_text_head,
+            self.config.n_text_layer,
+            max_rel_dist=self.config.max_rel_dist,
+            cross_attention=self.config.cross_attention,
+            checkpointing=self.config.checkpointing
+        )
+
+        all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)
+        all_heads[self.config.n_text_layer // 2:] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+        self.best_loss = float('inf')
+        self.base = 10000
+
+    def update_base(self, new_base):
+        self.encoder.rotary_embedding.update_base(new_base)
+        self.decoder.rotary_embedding.update_base(new_base)
+        for name, module in self.encoder.named_modules():
+            if isinstance(module, MultiHeadAttention):
+                module.update_base(new_base)
+        for name, module in self.decoder.named_modules():
+            if isinstance(module, MultiHeadAttention):
+                module.update_base(new_base)
+
+    def adjust_base(self, loss, factor=1.05):
+        if loss < self.best_loss:
+            new_base = self.base * factor
+        else:
+            new_base = self.base / factor
+
+        self.update_base(new_base)
+        self.best_loss = loss
+        # print(f"Adjusted base: {new_base}")
+
+    @staticmethod
+    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[:, 1:] = input_ids[:, :-1]
+        shifted_input_ids[:, 0] = decoder_start_token_id
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+        return shifted_input_ids
+
+    def forward(self, input_features, labels=None, dec_input_ids=None):
+        if labels is not None:
+            if dec_input_ids is None:
+                dec_input_ids = self.shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        encoded_features = self.encoder(input_features).to(device)
+        logits = self.decoder(dec_input_ids, encoded_features)
+
+        loss = None
+        if labels is not None:
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
+            labels = labels.to(logits.device).long()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+            self.adjust_base(loss.item())
+
+        return {
+            "loss": loss,
+            "logits": logits,
+            "input_features": encoded_features,
+            "labels": labels,
+            "decoder_input_ids": dec_input_ids
+        }
+
+    def _initialize_weights(self):
+        nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)
+        if hasattr(self.decoder.positional_embedding, 'weight'):
+            nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)
+        for block in self.decoder.blocks:
+            for layer in block.children():
+                if isinstance(layer, nn.Linear):
+                    nn.init.xavier_normal_(layer.weight)
+                    if layer.bias is not None:
+                        nn.init.zeros_(layer.bias)
+
+        nn.init.constant_(self.decoder.ln.gamma, 1)
+        if self.decoder.ln.beta is not None:
+            nn.init.constant_(self.decoder.ln.beta, 0)
+
+        nn.init.xavier_normal_(self.encoder.conv1.weight)
+        if self.encoder.conv1.bias is not None:
+            nn.init.zeros_(self.encoder.conv1.bias)
+
+        nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
+        if self.encoder.conv2.bias is not None:
+            nn.init.zeros_(self.encoder.conv2.bias)
+
+        nn.init.constant_(self.encoder.ln_post.gamma, 1)
+        if self.encoder.ln_post.beta is not None:
+            nn.init.constant_(self.encoder.ln_post.beta, 0)
+
+    def apply_initialization(self):
+        self._initialize_weights()
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.config.n_text_layer, self.config.n_text_head
+        )
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+    def embed_audio(self, mel):
+        return self.encoder(mel)
+
+    def logits(self, labels, input_features):
+        return self.decoder(labels, input_features)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def is_multilingual(self):
+        return self.config.vocab_size >= len(tokenizer)
+
+    @property
+    def num_languages(self):
+        return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)
+
+    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+        cache = {**cache} if cache is not None else {}
+        hooks = []
+
+        def save_to_cache(module, _, output):
+            if module not in cache or output.shape[1] > self.config.n_text_ctx:
+                cache[module] = output
+            else:
+                cache[module] = torch.cat([cache[module], output], dim=1).detach()
+            return cache[module]
+
+        def install_hooks(layer: nn.Module):
+            if isinstance(layer, MultiHeadAttention):
+                hooks.append(layer.key.register_forward_hook(save_to_cache))
+                hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+        self.decoder.apply(install_hooks)
+        return cache, hooks
+
+    detect_language = detect_language_function
+    transcribe = transcribe_function
+    decode = decode_function
+
+    def get_encoder(self):
+        return self.encoder
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {'input_features': input_ids}
+
+    def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):
+        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id
+
+    def can_generate(self):
+        return True
+
+    def generate(self, inputs, **kwargs):
+        encoder_outputs = self.encoder(inputs)
+        decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)
+        outputs = self.decoder(decoder_input_ids, encoder_outputs)
+        return outputs.argmax(dim=-1)
+
+tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")
+processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")
+
+config = WhisperConfig(
+    n_mels=80,
+    n_audio_ctx=1500,
+    n_audio_state=1024,
+    n_audio_head=16,
+    n_audio_layer=20,
+    vocab_size=(len(tokenizer)),
+    n_text_ctx=448,
+    n_text_state=1024,
+    n_text_head=16,
+    n_text_layer=16,
+    max_rel_dist=10,
+    cross_attention=True,
+    checkpointing=True,
+    base=10000
+)
+
+model = Echo(config).to(device)
+model.apply_initialization()
+
+
+class CustomCallback(TrainerCallback):
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+        print(f"Evaluation metrics at step {state.global_step}: {metrics}")
+
+raw_datasets = IterableDatasetDict()
+
+raw_datasets["train"] = load_dataset("mozilla-foundation/common_voice_17_0", "ja", split="train", trust_remote_code=True, streaming=True)
+raw_datasets["test"] = load_dataset("mozilla-foundation/common_voice_17_0", "ja", split="test", trust_remote_code=True, streaming=True).take(100)
+
+raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))
+
+tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")
+processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="japanese", task="transcribe")
+
+def prepare_dataset(batch):
+    audio = batch["audio"]
+    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
+    transcription = batch["sentence"]
+    batch["labels"] = processor.tokenizer(transcription).input_ids
+    return batch
+
+vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)).with_format("torch")
+
+@dataclass
+class DataCollatorSpeechSeq2SeqWithPadding:
+    processor: Any
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        input_features = [{"input_features": feature["input_features"]} for feature in features]
+        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
+        label_features = [{"input_ids": feature["labels"]} for feature in features]
+        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
+        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
+            labels = labels[:, 1:]
+        batch["labels"] = labels
+        return batch
+
+data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
+
+metric = evaluate.load("cer")
+
+def compute_metrics(pred):
+    pred_logits = pred.predictions
+    label_ids = pred.label_ids
+
+    if isinstance(pred_logits, tuple):
+        pred_ids = pred_logits[0]
+    else:
+        pred_ids = pred_logits
+    if pred_ids.ndim == 3:
+        pred_ids = np.argmax(pred_ids, axis=-1)
+
+    label_ids[label_ids == -100] = tokenizer.pad_token_id
+    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
+    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
+    return {"cer": cer}
+
+training_args = Seq2SeqTrainingArguments(
+    output_dir="./test",
+    per_device_train_batch_size=1,
+    per_device_eval_batch_size=1,
+    gradient_accumulation_steps=1,
+    eval_accumulation_steps=1,
+    num_train_epochs=1,
+    tf32=True,
+    bf16=True,
+    learning_rate=1e-5,
+    # warmup_steps=500,
+    evaluation_strategy="steps",
+    # predict_with_generate=True,
+    # generation_max_length=225,
+    max_steps=100,
+    save_steps=100,
+    eval_steps=10,
+    logging_steps=5,
+    report_to=["tensorboard"],
+    load_best_model_at_end=True,
+    metric_for_best_model="cer",
+    greater_is_better=False,
+    push_to_hub=False,
+    optim="adafactor",
+    weight_decay=0.0025,
+    disable_tqdm=False,
+    save_total_limit=2,
+    torch_empty_cache_steps=10,
+)
+
+trainer = Seq2SeqTrainer(
+    args=training_args,
+    model=model,
+    train_dataset=vectorized_datasets["train"],
+    eval_dataset=vectorized_datasets["test"],
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+    tokenizer=processor,
+)
+
+trainer.add_callback(CustomCallback)
+
+trainer.train()
+
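For context, a minimal sketch of how the new compute_metrics path can be exercised outside the Trainer. It is illustrative only, not part of the commit: EvalPrediction is the standard transformers container, the dummy inputs are made up, and the tokenizer, metric, and compute_metrics objects defined above are assumed to be in scope.

import numpy as np
from transformers import EvalPrediction

# Build a tiny fake eval batch: labels are padded and then masked with -100
# (as the collator does), predictions are identical to the unmasked labels,
# so the character error rate should come out as 0.0.
encoded = tokenizer(["こんにちは", "テスト"], padding=True)
label_ids = np.array(encoded.input_ids)
pred_ids = label_ids.copy()
label_ids[label_ids == tokenizer.pad_token_id] = -100  # mimic the collator's masking

print(compute_metrics(EvalPrediction(predictions=pred_ids, label_ids=label_ids)))
# expected: {'cer': 0.0}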