kevinwang676 commited on
Commit
be9690e
·
verified ·
1 Parent(s): 7cb2bfb

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/ISSUE_TEMPLATE/bug_report.md +38 -0
  2. .github/ISSUE_TEMPLATE/feature_request.md +20 -0
  3. .gitignore +49 -0
  4. .gitmodules +3 -0
  5. CODE_OF_CONDUCT.md +76 -0
  6. FAQ.md +16 -0
  7. LICENSE +201 -0
  8. README.md +159 -13
  9. asset/dingding.png +0 -0
  10. cosyvoice/__init__.py +0 -0
  11. cosyvoice/bin/inference.py +114 -0
  12. cosyvoice/bin/train.py +136 -0
  13. cosyvoice/cli/__init__.py +0 -0
  14. cosyvoice/cli/cosyvoice.py +83 -0
  15. cosyvoice/cli/frontend.py +168 -0
  16. cosyvoice/cli/model.py +60 -0
  17. cosyvoice/dataset/__init__.py +0 -0
  18. cosyvoice/dataset/dataset.py +160 -0
  19. cosyvoice/dataset/processor.py +369 -0
  20. cosyvoice/flow/decoder.py +222 -0
  21. cosyvoice/flow/flow.py +135 -0
  22. cosyvoice/flow/flow_matching.py +138 -0
  23. cosyvoice/flow/length_regulator.py +49 -0
  24. cosyvoice/hifigan/f0_predictor.py +55 -0
  25. cosyvoice/hifigan/generator.py +391 -0
  26. cosyvoice/llm/llm.py +206 -0
  27. cosyvoice/transformer/__init__.py +0 -0
  28. cosyvoice/transformer/activation.py +84 -0
  29. cosyvoice/transformer/attention.py +326 -0
  30. cosyvoice/transformer/convolution.py +145 -0
  31. cosyvoice/transformer/decoder.py +396 -0
  32. cosyvoice/transformer/decoder_layer.py +132 -0
  33. cosyvoice/transformer/embedding.py +293 -0
  34. cosyvoice/transformer/encoder.py +472 -0
  35. cosyvoice/transformer/encoder_layer.py +236 -0
  36. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  37. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  38. cosyvoice/transformer/subsampling.py +383 -0
  39. cosyvoice/utils/__init__.py +0 -0
  40. cosyvoice/utils/class_utils.py +70 -0
  41. cosyvoice/utils/common.py +103 -0
  42. cosyvoice/utils/executor.py +110 -0
  43. cosyvoice/utils/file_utils.py +41 -0
  44. cosyvoice/utils/frontend_utils.py +125 -0
  45. cosyvoice/utils/mask.py +227 -0
  46. cosyvoice/utils/scheduler.py +739 -0
  47. cosyvoice/utils/train_utils.py +289 -0
  48. examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml +198 -0
  49. examples/libritts/cosyvoice/conf/cosyvoice.yaml +198 -0
  50. examples/libritts/cosyvoice/conf/ds_stage2.json +42 -0
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. Go to '...'
16
+ 2. Click on '....'
17
+ 3. Scroll down to '....'
18
+ 4. See error
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. iOS]
28
+ - Browser [e.g. chrome, safari]
29
+ - Version [e.g. 22]
30
+
31
+ **Smartphone (please complete the following information):**
32
+ - Device: [e.g. iPhone6]
33
+ - OS: [e.g. iOS8.1]
34
+ - Browser [e.g. stock browser, safari]
35
+ - Version [e.g. 22]
36
+
37
+ **Additional context**
38
+ Add any other context about the problem here.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context**
20
+ Add any other context or screenshots about the feature request here.
.gitignore ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Visual Studio Code files
7
+ .vscode
8
+ .vs
9
+
10
+ # PyCharm files
11
+ .idea
12
+
13
+ # Eclipse Project settings
14
+ *.*project
15
+ .settings
16
+
17
+ # Sublime Text settings
18
+ *.sublime-workspace
19
+ *.sublime-project
20
+
21
+ # Editor temporaries
22
+ *.swn
23
+ *.swo
24
+ *.swp
25
+ *.swm
26
+ *~
27
+
28
+ # IPython notebook checkpoints
29
+ .ipynb_checkpoints
30
+
31
+ # macOS dir files
32
+ .DS_Store
33
+
34
+ exp
35
+ data
36
+ raw_wav
37
+ tensorboard
38
+ **/*build*
39
+
40
+ # Clangd files
41
+ .cache
42
+ compile_commands.json
43
+
44
+ # train/inference files
45
+ *.wav
46
+ *.pt
47
+ pretrained_models/*
48
+ *_pb2_grpc.py
49
+ *_pb2.py
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "third_party/Matcha-TTS"]
2
+ path = third_party/Matcha-TTS
3
+ url = https://github.com/shivammehta25/Matcha-TTS.git
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to making participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies both within project spaces and in public spaces
49
+ when an individual is representing the project or its community. Examples of
50
+ representing a project or community include using an official project e-mail
51
+ address, posting via an official social media account, or acting as an appointed
52
+ representative at an online or offline event. Representation of a project may be
53
+ further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at [email protected]. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
FAQ.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ModuleNotFoundError: No module named 'matcha'
2
+
3
+ Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
4
+
5
+ run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
6
+
7
+ ## cannot find resource.zip or cannot unzip resource.zip
8
+
9
+ Please make sure you have git-lfs installed. Execute
10
+
11
+ ```sh
12
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
13
+ cd pretrained_models/CosyVoice-ttsfrd/
14
+ unzip resource.zip -d .
15
+ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
16
+ ```
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,159 @@
1
- ---
2
- title: CosyVoice Instruct
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 4.38.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CosyVoice
2
+ ## 👉🏻 [CosyVoice Demos](https://fun-audio-llm.github.io/) 👈🏻
3
+ [[CosyVoice Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)][[CosyVoice Code](https://github.com/FunAudioLLM/CosyVoice)]
4
+
5
+ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
6
+
7
+ ## Install
8
+
9
+ **Clone and install**
10
+
11
+ - Clone the repo
12
+ ``` sh
13
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
14
+ # If you failed to clone submodule due to network failures, please run following command until success
15
+ cd CosyVoice
16
+ git submodule update --init --recursive
17
+ ```
18
+
19
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
20
+ - Create Conda env:
21
+
22
+ ``` sh
23
+ conda create -n cosyvoice python=3.8
24
+ conda activate cosyvoice
25
+ # pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
26
+ conda install -y -c conda-forge pynini==2.1.5
27
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
28
+
29
+ # If you encounter sox compatibility issues
30
+ # ubuntu
31
+ sudo apt-get install sox libsox-dev
32
+ # centos
33
+ sudo yum install sox sox-devel
34
+ ```
35
+
36
+ **Model download**
37
+
38
+ We strongly recommend that you download our pretrained `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
39
+
40
+ If you are expert in this field, and you are only interested in training your own CosyVoice model from scratch, you can skip this step.
41
+
42
+ ``` python
43
+ # SDK模型下载
44
+ from modelscope import snapshot_download
45
+ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
46
+ snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
47
+ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
48
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
49
+ ```
50
+
51
+ ``` sh
52
+ # git模型下载,请确保已安装git lfs
53
+ mkdir -p pretrained_models
54
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
55
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
56
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
57
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
58
+ ```
59
+
60
+ Optionaly, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
61
+
62
+ Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
63
+
64
+ ``` sh
65
+ cd pretrained_models/CosyVoice-ttsfrd/
66
+ unzip resource.zip -d .
67
+ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
68
+ ```
69
+
70
+ **Basic Usage**
71
+
72
+ For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
73
+ For sft inference, please use `CosyVoice-300M-SFT` model.
74
+ For instruct inference, please use `CosyVoice-300M-Instruct` model.
75
+ First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
76
+
77
+ ``` sh
78
+ export PYTHONPATH=third_party/Matcha-TTS
79
+ ```
80
+
81
+ ``` python
82
+ from cosyvoice.cli.cosyvoice import CosyVoice
83
+ from cosyvoice.utils.file_utils import load_wav
84
+ import torchaudio
85
+
86
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
87
+ # sft usage
88
+ print(cosyvoice.list_avaliable_spks())
89
+ output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
90
+ torchaudio.save('sft.wav', output['tts_speech'], 22050)
91
+
92
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
93
+ # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
94
+ prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
95
+ output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
96
+ torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
97
+ # cross_lingual usage
98
+ prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
99
+ output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
100
+ torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
101
+
102
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
103
+ # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
104
+ output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
105
+ torchaudio.save('instruct.wav', output['tts_speech'], 22050)
106
+ ```
107
+
108
+ **Start web demo**
109
+
110
+ You can use our web demo page to get familiar with CosyVoice quickly.
111
+ We support sft/zero_shot/cross_lingual/instruct inference in web demo.
112
+
113
+ Please see the demo website for details.
114
+
115
+ ``` python
116
+ # change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
117
+ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
118
+ ```
119
+
120
+ **Advanced Usage**
121
+
122
+ For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
123
+ You can get familiar with CosyVoice following this recipie.
124
+
125
+ **Build for deployment**
126
+
127
+ Optionally, if you want to use grpc for service deployment,
128
+ you can run following steps. Otherwise, you can just ignore this step.
129
+
130
+ ``` sh
131
+ cd runtime/python
132
+ docker build -t cosyvoice:v1.0 .
133
+ # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
134
+ # for grpc usage
135
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
136
+ python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
137
+ # for fastapi usage
138
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
139
+ python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
140
+ ```
141
+
142
+ ## Discussion & Communication
143
+
144
+ You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
145
+
146
+ You can also scan the QR code to join our official Dingding chat group.
147
+
148
+ <img src="./asset/dingding.png" width="250px">
149
+
150
+ ## Acknowledge
151
+
152
+ 1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
153
+ 2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
154
+ 3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
155
+ 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
156
+ 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
157
+
158
+ ## Disclaimer
159
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
asset/dingding.png ADDED
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+
22
+ import torch
23
+ from torch.utils.data import DataLoader
24
+ import torchaudio
25
+ from hyperpyyaml import load_hyperpyyaml
26
+ from tqdm import tqdm
27
+ from cosyvoice.cli.model import CosyVoiceModel
28
+
29
+ from cosyvoice.dataset.dataset import Dataset
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser(description='inference with your model')
33
+ parser.add_argument('--config', required=True, help='config file')
34
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
35
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
36
+ parser.add_argument('--tts_text', required=True, help='tts input file')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59
+
60
+ # Init cosyvoice models from configs
61
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62
+ device = torch.device('cuda' if use_cuda else 'cpu')
63
+ with open(args.config, 'r') as f:
64
+ configs = load_hyperpyyaml(f)
65
+
66
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
67
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
68
+
69
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for batch_idx, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only support batchsize 1"
80
+ text = batch["text"]
81
+ text_token = batch["text_token"].to(device)
82
+ text_token_len = batch["text_token_len"].to(device)
83
+ tts_text = batch["tts_text"]
84
+ tts_index = batch["tts_index"]
85
+ tts_text_token = batch["tts_text_token"].to(device)
86
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
87
+ speech_token = batch["speech_token"].to(device)
88
+ speech_token_len = batch["speech_token_len"].to(device)
89
+ speech_feat = batch["speech_feat"].to(device)
90
+ speech_feat_len = batch["speech_feat_len"].to(device)
91
+ utt_embedding = batch["utt_embedding"].to(device)
92
+ spk_embedding = batch["spk_embedding"].to(device)
93
+ if args.mode == 'sft':
94
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
95
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
96
+ else:
97
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
98
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
99
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
100
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
+ model_output = model.inference(**model_input)
104
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
105
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
106
+ torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
107
+ f.write('{} {}\n'.format(tts_key, tts_fn))
108
+ f.flush()
109
+ f.close()
110
+ logging.info('Result wav.scp saved in {}'.format(fn))
111
+
112
+
113
+ if __name__ == '__main__':
114
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import torch
22
+ import torch.distributed as dist
23
+ import deepspeed
24
+
25
+ from hyperpyyaml import load_hyperpyyaml
26
+
27
+ from torch.distributed.elastic.multiprocessing.errors import record
28
+
29
+ from cosyvoice.utils.executor import Executor
30
+ from cosyvoice.utils.train_utils import (
31
+ init_distributed,
32
+ init_dataset_and_dataloader,
33
+ init_optimizer_and_scheduler,
34
+ init_summarywriter, save_model,
35
+ wrap_cuda_model, check_modify_and_save_config)
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='training your network')
40
+ parser.add_argument('--train_engine',
41
+ default='torch_ddp',
42
+ choices=['torch_ddp', 'deepspeed'],
43
+ help='Engine for paralleled training')
44
+ parser.add_argument('--model', required=True, help='model which will be trained')
45
+ parser.add_argument('--config', required=True, help='config file')
46
+ parser.add_argument('--train_data', required=True, help='train data file')
47
+ parser.add_argument('--cv_data', required=True, help='cv data file')
48
+ parser.add_argument('--checkpoint', help='checkpoint model')
49
+ parser.add_argument('--model_dir', required=True, help='save model dir')
50
+ parser.add_argument('--tensorboard_dir',
51
+ default='tensorboard',
52
+ help='tensorboard log dir')
53
+ parser.add_argument('--ddp.dist_backend',
54
+ dest='dist_backend',
55
+ default='nccl',
56
+ choices=['nccl', 'gloo'],
57
+ help='distributed backend')
58
+ parser.add_argument('--num_workers',
59
+ default=0,
60
+ type=int,
61
+ help='num of subprocess workers for reading')
62
+ parser.add_argument('--prefetch',
63
+ default=100,
64
+ type=int,
65
+ help='prefetch number')
66
+ parser.add_argument('--pin_memory',
67
+ action='store_true',
68
+ default=False,
69
+ help='Use pinned memory buffers used for reading')
70
+ parser.add_argument('--deepspeed.save_states',
71
+ dest='save_states',
72
+ default='model_only',
73
+ choices=['model_only', 'model+optimizer'],
74
+ help='save model/optimizer states')
75
+ parser.add_argument('--timeout',
76
+ default=30,
77
+ type=int,
78
+ help='timeout (in seconds) of cosyvoice_join.')
79
+ parser = deepspeed.add_config_arguments(parser)
80
+ args = parser.parse_args()
81
+ return args
82
+
83
+
84
+ @record
85
+ def main():
86
+ args = get_args()
87
+ logging.basicConfig(level=logging.DEBUG,
88
+ format='%(asctime)s %(levelname)s %(message)s')
89
+
90
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
91
+ with open(args.config, 'r') as f:
92
+ configs = load_hyperpyyaml(f, overrides=override_dict)
93
+ configs['train_conf'].update(vars(args))
94
+
95
+ # Init env for ddp
96
+ init_distributed(args)
97
+
98
+ # Get dataset & dataloader
99
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
100
+ init_dataset_and_dataloader(args, configs)
101
+
102
+ # Do some sanity checks and save config to arsg.model_dir
103
+ configs = check_modify_and_save_config(args, configs)
104
+
105
+ # Tensorboard summary
106
+ writer = init_summarywriter(args)
107
+
108
+ # load checkpoint
109
+ model = configs[args.model]
110
+ if args.checkpoint is not None:
111
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
112
+
113
+ # Dispatch model from cpu to gpu
114
+ model = wrap_cuda_model(args, model)
115
+
116
+ # Get optimizer & scheduler
117
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
118
+
119
+ # Save init checkpoints
120
+ info_dict = deepcopy(configs['train_conf'])
121
+ save_model(model, 'init', info_dict)
122
+
123
+ # Get executor
124
+ executor = Executor()
125
+
126
+ # Start training loop
127
+ for epoch in range(info_dict['max_epoch']):
128
+ executor.epoch = epoch
129
+ train_dataset.set_epoch(epoch)
130
+ dist.barrier()
131
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
132
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
133
+ dist.destroy_process_group(group_join)
134
+
135
+ if __name__ == '__main__':
136
+ main()
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import torch
16
+ from hyperpyyaml import load_hyperpyyaml
17
+ from modelscope import snapshot_download
18
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
19
+ from cosyvoice.cli.model import CosyVoiceModel
20
+
21
+ class CosyVoice:
22
+
23
+ def __init__(self, model_dir):
24
+ instruct = True if '-Instruct' in model_dir else False
25
+ self.model_dir = model_dir
26
+ if not os.path.exists(model_dir):
27
+ model_dir = snapshot_download(model_dir)
28
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
29
+ configs = load_hyperpyyaml(f)
30
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
31
+ configs['feat_extractor'],
32
+ '{}/campplus.onnx'.format(model_dir),
33
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
34
+ '{}/spk2info.pt'.format(model_dir),
35
+ instruct,
36
+ configs['allowed_special'])
37
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
38
+ self.model.load('{}/llm.pt'.format(model_dir),
39
+ '{}/flow.pt'.format(model_dir),
40
+ '{}/hift.pt'.format(model_dir))
41
+ del configs
42
+
43
+ def list_avaliable_spks(self):
44
+ spks = list(self.frontend.spk2info.keys())
45
+ return spks
46
+
47
+ def inference_sft(self, tts_text, spk_id):
48
+ tts_speeches = []
49
+ for i in self.frontend.text_normalize(tts_text, split=True):
50
+ model_input = self.frontend.frontend_sft(i, spk_id)
51
+ model_output = self.model.inference(**model_input)
52
+ tts_speeches.append(model_output['tts_speech'])
53
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
54
+
55
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
56
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
57
+ tts_speeches = []
58
+ for i in self.frontend.text_normalize(tts_text, split=True):
59
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
60
+ model_output = self.model.inference(**model_input)
61
+ tts_speeches.append(model_output['tts_speech'])
62
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
63
+
64
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k):
65
+ if self.frontend.instruct is True:
66
+ raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
67
+ tts_speeches = []
68
+ for i in self.frontend.text_normalize(tts_text, split=True):
69
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
70
+ model_output = self.model.inference(**model_input)
71
+ tts_speeches.append(model_output['tts_speech'])
72
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
73
+
74
+ def inference_instruct(self, tts_text, spk_id, instruct_text):
75
+ if self.frontend.instruct is False:
76
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
77
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
78
+ tts_speeches = []
79
+ for i in self.frontend.text_normalize(tts_text, split=True):
80
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
81
+ model_output = self.model.inference(**model_input)
82
+ tts_speeches.append(model_output['tts_speech'])
83
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import onnxruntime
16
+ import torch
17
+ import numpy as np
18
+ import whisper
19
+ from typing import Callable
20
+ import torchaudio.compliance.kaldi as kaldi
21
+ import torchaudio
22
+ import os
23
+ import re
24
+ import inflect
25
+ try:
26
+ import ttsfrd
27
+ use_ttsfrd = True
28
+ except ImportError:
29
+ print("failed to import ttsfrd, use WeTextProcessing instead")
30
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
31
+ from tn.english.normalizer import Normalizer as EnNormalizer
32
+ use_ttsfrd = False
33
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
34
+
35
+
36
+ class CosyVoiceFrontEnd:
37
+
38
+ def __init__(self,
39
+ get_tokenizer: Callable,
40
+ feat_extractor: Callable,
41
+ campplus_model: str,
42
+ speech_tokenizer_model: str,
43
+ spk2info: str = '',
44
+ instruct: bool = False,
45
+ allowed_special: str = 'all'):
46
+ self.tokenizer = get_tokenizer()
47
+ self.feat_extractor = feat_extractor
48
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49
+ option = onnxruntime.SessionOptions()
50
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
51
+ option.intra_op_num_threads = 1
52
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
53
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
54
+ if os.path.exists(spk2info):
55
+ self.spk2info = torch.load(spk2info, map_location=self.device)
56
+ self.instruct = instruct
57
+ self.allowed_special = allowed_special
58
+ self.inflect_parser = inflect.engine()
59
+ self.use_ttsfrd = use_ttsfrd
60
+ if self.use_ttsfrd:
61
+ self.frd = ttsfrd.TtsFrontendEngine()
62
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
63
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
64
+ self.frd.set_lang_type('pinyin')
65
+ self.frd.enable_pinyin_mix(True)
66
+ self.frd.set_breakmodel_index(1)
67
+ else:
68
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
69
+ self.en_tn_model = EnNormalizer()
70
+
71
+ def _extract_text_token(self, text):
72
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
73
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
74
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
75
+ return text_token, text_token_len
76
+
77
+ def _extract_speech_token(self, speech):
78
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
79
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
80
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
81
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
82
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
83
+ return speech_token, speech_token_len
84
+
85
+ def _extract_spk_embedding(self, speech):
86
+ feat = kaldi.fbank(speech,
87
+ num_mel_bins=80,
88
+ dither=0,
89
+ sample_frequency=16000)
90
+ feat = feat - feat.mean(dim=0, keepdim=True)
91
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
92
+ embedding = torch.tensor([embedding]).to(self.device)
93
+ return embedding
94
+
95
+ def _extract_speech_feat(self, speech):
96
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
97
+ speech_feat = speech_feat.unsqueeze(dim=0)
98
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
99
+ return speech_feat, speech_feat_len
100
+
101
+ def text_normalize(self, text, split=True):
102
+ text = text.strip()
103
+ if contains_chinese(text):
104
+ if self.use_ttsfrd:
105
+ text = self.frd.get_frd_extra_info(text, 'input')
106
+ else:
107
+ text = self.zh_tn_model.normalize(text)
108
+ text = text.replace("\n", "")
109
+ text = replace_blank(text)
110
+ text = replace_corner_mark(text)
111
+ text = text.replace(".", "、")
112
+ text = text.replace(" - ", ",")
113
+ text = remove_bracket(text)
114
+ text = re.sub(r'[,,]+$', '。', text)
115
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
116
+ token_min_n=60, merge_len=20,
117
+ comma_split=False)]
118
+ else:
119
+ if self.use_ttsfrd:
120
+ text = self.frd.get_frd_extra_info(text, 'input')
121
+ else:
122
+ text = self.en_tn_model.normalize(text)
123
+ text = spell_out_number(text, self.inflect_parser)
124
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
125
+ token_min_n=60, merge_len=20,
126
+ comma_split=False)]
127
+ if split is False:
128
+ return text
129
+ return texts
130
+
131
+ def frontend_sft(self, tts_text, spk_id):
132
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
133
+ embedding = self.spk2info[spk_id]['embedding']
134
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
135
+ return model_input
136
+
137
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
138
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
139
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
140
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
141
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
142
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
143
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
144
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
145
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
146
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
147
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
148
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
149
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
150
+ return model_input
151
+
152
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
153
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
154
+ # in cross lingual mode, we remove prompt in llm
155
+ del model_input['prompt_text']
156
+ del model_input['prompt_text_len']
157
+ del model_input['llm_prompt_speech_token']
158
+ del model_input['llm_prompt_speech_token_len']
159
+ return model_input
160
+
161
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
162
+ model_input = self.frontend_sft(tts_text, spk_id)
163
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
164
+ del model_input['llm_embedding']
165
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
166
+ model_input['prompt_text'] = instruct_text_token
167
+ model_input['prompt_text_len'] = instruct_text_token_len
168
+ return model_input
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ class CosyVoiceModel:
17
+
18
+ def __init__(self,
19
+ llm: torch.nn.Module,
20
+ flow: torch.nn.Module,
21
+ hift: torch.nn.Module):
22
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ self.llm = llm
24
+ self.flow = flow
25
+ self.hift = hift
26
+
27
+ def load(self, llm_model, flow_model, hift_model):
28
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
29
+ self.llm.to(self.device).eval()
30
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
31
+ self.flow.to(self.device).eval()
32
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
33
+ self.hift.to(self.device).eval()
34
+
35
+ def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
36
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
37
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
38
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
40
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
41
+ text_len=text_len.to(self.device),
42
+ prompt_text=prompt_text.to(self.device),
43
+ prompt_text_len=prompt_text_len.to(self.device),
44
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
45
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
46
+ embedding=llm_embedding.to(self.device),
47
+ beam_size=1,
48
+ sampling=25,
49
+ max_token_text_ratio=30,
50
+ min_token_text_ratio=3)
51
+ tts_mel = self.flow.inference(token=tts_speech_token,
52
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
53
+ prompt_token=flow_prompt_speech_token.to(self.device),
54
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
55
+ prompt_feat=prompt_speech_feat.to(self.device),
56
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
57
+ embedding=flow_embedding.to(self.device))
58
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
59
+ torch.cuda.empty_cache()
60
+ return {'tts_speech': tts_speech}
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force datalist even
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ shuffle=True,
130
+ partition=True,
131
+ tts_file='',
132
+ prompt_utt2data=''):
133
+ """ Construct dataset from arguments
134
+
135
+ We have two shuffle stage in the Dataset. The first is global
136
+ shuffle at shards tar/raw file level. The second is global shuffle
137
+ at training samples level.
138
+
139
+ Args:
140
+ data_type(str): raw/shard
141
+ tokenizer (BaseTokenizer): tokenizer to tokenize
142
+ partition(bool): whether to do data partition in terms of rank
143
+ """
144
+ assert mode in ['train', 'inference']
145
+ lists = read_lists(data_list_file)
146
+ if mode == 'inference':
147
+ with open(tts_file) as f:
148
+ tts_data = json.load(f)
149
+ utt2lists = read_json_lists(prompt_utt2data)
150
+ # filter unnecessary file in inference mode
151
+ lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
152
+ dataset = DataList(lists,
153
+ shuffle=shuffle,
154
+ partition=partition)
155
+ if mode == 'inference':
156
+ # map partial arg tts_data in inference mode
157
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
158
+ for func in data_pipeline:
159
+ dataset = Processor(dataset, func, mode=mode)
160
+ return dataset
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+
24
+ torchaudio.set_audio_backend('soundfile')
25
+
26
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
27
+
28
+
29
+ def parquet_opener(data, mode='train', tts_data={}):
30
+ """ Give url or local file, return file descriptor
31
+ Inplace operation.
32
+
33
+ Args:
34
+ data(Iterable[str]): url or local file list
35
+
36
+ Returns:
37
+ Iterable[{src, stream}]
38
+ """
39
+ for sample in data:
40
+ assert 'src' in sample
41
+ url = sample['src']
42
+ try:
43
+ df = pq.read_table(url).to_pandas()
44
+ for i in range(len(df)):
45
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
46
+ continue
47
+ sample.update(dict(df.loc[i]))
48
+ if mode == 'train':
49
+ # NOTE do not return sample directly, must initialize a new dict
50
+ yield {**sample}
51
+ else:
52
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
53
+ yield {**sample, 'tts_index': index, 'tts_text': text}
54
+ except Exception as ex:
55
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
56
+
57
+ def filter(data,
58
+ max_length=10240,
59
+ min_length=10,
60
+ token_max_length=200,
61
+ token_min_length=1,
62
+ min_output_input_ratio=0.0005,
63
+ max_output_input_ratio=1,
64
+ mode='train'):
65
+ """ Filter sample according to feature and label length
66
+ Inplace operation.
67
+
68
+ Args::
69
+ data: Iterable[{key, wav, label, sample_rate}]
70
+ max_length: drop utterance which is greater than max_length(10ms)
71
+ min_length: drop utterance which is less than min_length(10ms)
72
+ token_max_length: drop utterance which is greater than
73
+ token_max_length, especially when use char unit for
74
+ english modeling
75
+ token_min_length: drop utterance which is
76
+ less than token_max_length
77
+ min_output_input_ratio: minimal ration of
78
+ token_length / feats_length(10ms)
79
+ max_output_input_ratio: maximum ration of
80
+ token_length / feats_length(10ms)
81
+
82
+ Returns:
83
+ Iterable[{key, wav, label, sample_rate}]
84
+ """
85
+ for sample in data:
86
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
87
+ del sample['audio_data']
88
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
89
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
90
+ if num_frames < min_length:
91
+ continue
92
+ if num_frames > max_length:
93
+ continue
94
+ if len(sample['text_token']) < token_min_length:
95
+ continue
96
+ if len(sample['text_token']) > token_max_length:
97
+ continue
98
+ if len(sample['speech_token']) == 0:
99
+ continue
100
+ if num_frames != 0:
101
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
102
+ continue
103
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
104
+ continue
105
+ yield sample
106
+
107
+
108
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
109
+ """ Resample data.
110
+ Inplace operation.
111
+
112
+ Args:
113
+ data: Iterable[{key, wav, label, sample_rate}]
114
+ resample_rate: target resample rate
115
+
116
+ Returns:
117
+ Iterable[{key, wav, label, sample_rate}]
118
+ """
119
+ for sample in data:
120
+ assert 'sample_rate' in sample
121
+ assert 'speech' in sample
122
+ sample_rate = sample['sample_rate']
123
+ waveform = sample['speech']
124
+ if sample_rate != resample_rate:
125
+ if sample_rate < min_sample_rate:
126
+ continue
127
+ sample['sample_rate'] = resample_rate
128
+ sample['speech'] = torchaudio.transforms.Resample(
129
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
130
+ max_val = sample['speech'].abs().max()
131
+ if max_val > 1:
132
+ sample['speech'] /= max_val
133
+ yield sample
134
+
135
+
136
+ def compute_fbank(data,
137
+ feat_extractor,
138
+ mode='train'):
139
+ """ Extract fbank
140
+
141
+ Args:
142
+ data: Iterable[{key, wav, label, sample_rate}]
143
+
144
+ Returns:
145
+ Iterable[{key, feat, label}]
146
+ """
147
+ for sample in data:
148
+ assert 'sample_rate' in sample
149
+ assert 'speech' in sample
150
+ assert 'utt' in sample
151
+ assert 'text_token' in sample
152
+ waveform = sample['speech']
153
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
154
+ sample['speech_feat'] = mat
155
+ del sample['speech']
156
+ yield sample
157
+
158
+
159
+ def parse_embedding(data, normalize, mode='train'):
160
+ """ Parse utt_embedding/spk_embedding
161
+
162
+ Args:
163
+ data: Iterable[{key, wav, label, sample_rate}]
164
+
165
+ Returns:
166
+ Iterable[{key, feat, label}]
167
+ """
168
+ for sample in data:
169
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
170
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
171
+ if normalize:
172
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
173
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
174
+ yield sample
175
+
176
+
177
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
178
+ """ Decode text to chars or BPE
179
+ Inplace operation
180
+
181
+ Args:
182
+ data: Iterable[{key, wav, txt, sample_rate}]
183
+
184
+ Returns:
185
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
186
+ """
187
+ tokenizer = get_tokenizer()
188
+ for sample in data:
189
+ assert 'text' in sample
190
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
191
+ if mode == 'inference':
192
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
193
+ yield sample
194
+
195
+
196
+ def shuffle(data, shuffle_size=10000, mode='train'):
197
+ """ Local shuffle the data
198
+
199
+ Args:
200
+ data: Iterable[{key, feat, label}]
201
+ shuffle_size: buffer size for shuffle
202
+
203
+ Returns:
204
+ Iterable[{key, feat, label}]
205
+ """
206
+ buf = []
207
+ for sample in data:
208
+ buf.append(sample)
209
+ if len(buf) >= shuffle_size:
210
+ random.shuffle(buf)
211
+ for x in buf:
212
+ yield x
213
+ buf = []
214
+ # The sample left over
215
+ random.shuffle(buf)
216
+ for x in buf:
217
+ yield x
218
+
219
+
220
+ def sort(data, sort_size=500, mode='train'):
221
+ """ Sort the data by feature length.
222
+ Sort is used after shuffle and before batch, so we can group
223
+ utts with similar lengths into a batch, and `sort_size` should
224
+ be less than `shuffle_size`
225
+
226
+ Args:
227
+ data: Iterable[{key, feat, label}]
228
+ sort_size: buffer size for sort
229
+
230
+ Returns:
231
+ Iterable[{key, feat, label}]
232
+ """
233
+
234
+ buf = []
235
+ for sample in data:
236
+ buf.append(sample)
237
+ if len(buf) >= sort_size:
238
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
239
+ for x in buf:
240
+ yield x
241
+ buf = []
242
+ # The sample left over
243
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
244
+ for x in buf:
245
+ yield x
246
+
247
+
248
+ def static_batch(data, batch_size=16):
249
+ """ Static batch the data by `batch_size`
250
+
251
+ Args:
252
+ data: Iterable[{key, feat, label}]
253
+ batch_size: batch size
254
+
255
+ Returns:
256
+ Iterable[List[{key, feat, label}]]
257
+ """
258
+ buf = []
259
+ for sample in data:
260
+ buf.append(sample)
261
+ if len(buf) >= batch_size:
262
+ yield buf
263
+ buf = []
264
+ if len(buf) > 0:
265
+ yield buf
266
+
267
+
268
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
269
+ """ Dynamic batch the data until the total frames in batch
270
+ reach `max_frames_in_batch`
271
+
272
+ Args:
273
+ data: Iterable[{key, feat, label}]
274
+ max_frames_in_batch: max_frames in one batch
275
+
276
+ Returns:
277
+ Iterable[List[{key, feat, label}]]
278
+ """
279
+ buf = []
280
+ longest_frames = 0
281
+ for sample in data:
282
+ assert 'speech_feat' in sample
283
+ assert isinstance(sample['speech_feat'], torch.Tensor)
284
+ new_sample_frames = sample['speech_feat'].size(0)
285
+ longest_frames = max(longest_frames, new_sample_frames)
286
+ frames_after_padding = longest_frames * (len(buf) + 1)
287
+ if frames_after_padding > max_frames_in_batch:
288
+ yield buf
289
+ buf = [sample]
290
+ longest_frames = new_sample_frames
291
+ else:
292
+ buf.append(sample)
293
+ if len(buf) > 0:
294
+ yield buf
295
+
296
+
297
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
298
+ """ Wrapper for static/dynamic batch
299
+ """
300
+ if mode == 'inference':
301
+ return static_batch(data, 1)
302
+ else:
303
+ if batch_type == 'static':
304
+ return static_batch(data, batch_size)
305
+ elif batch_type == 'dynamic':
306
+ return dynamic_batch(data, max_frames_in_batch)
307
+ else:
308
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
309
+
310
+
311
+ def padding(data, use_spk_embedding, mode='train'):
312
+ """ Padding the data into training data
313
+
314
+ Args:
315
+ data: Iterable[List[{key, feat, label}]]
316
+
317
+ Returns:
318
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
319
+ """
320
+ for sample in data:
321
+ assert isinstance(sample, list)
322
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
323
+ dtype=torch.int32)
324
+ order = torch.argsort(speech_feat_len, descending=True)
325
+
326
+ utts = [sample[i]['utt'] for i in order]
327
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
328
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
329
+ speech_token = pad_sequence(speech_token,
330
+ batch_first=True,
331
+ padding_value=0)
332
+ speech_feat = [sample[i]['speech_feat'] for i in order]
333
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
334
+ speech_feat = pad_sequence(speech_feat,
335
+ batch_first=True,
336
+ padding_value=0)
337
+ text = [sample[i]['text'] for i in order]
338
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
339
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
340
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
341
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
342
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
343
+ batch = {
344
+ "utts": utts,
345
+ "speech_token": speech_token,
346
+ "speech_token_len": speech_token_len,
347
+ "speech_feat": speech_feat,
348
+ "speech_feat_len": speech_feat_len,
349
+ "text": text,
350
+ "text_token": text_token,
351
+ "text_token_len": text_token_len,
352
+ "utt_embedding": utt_embedding,
353
+ "spk_embedding": spk_embedding,
354
+ }
355
+ if mode == 'inference':
356
+ tts_text = [sample[i]['tts_text'] for i in order]
357
+ tts_index = [sample[i]['tts_index'] for i in order]
358
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
359
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
360
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
361
+ batch.update({'tts_text': tts_text,
362
+ 'tts_index': tts_index,
363
+ 'tts_text_token': tts_text_token,
364
+ 'tts_text_token_len': tts_text_token_len})
365
+ if use_spk_embedding is True:
366
+ batch["embedding"] = batch["spk_embedding"]
367
+ else:
368
+ batch["embedding"] = batch["utt_embedding"]
369
+ yield batch
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import pack, rearrange, repeat
17
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
18
+ from matcha.models.components.transformer import BasicTransformerBlock
19
+
20
+
21
+ class ConditionalDecoder(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_channels,
25
+ out_channels,
26
+ channels=(256, 256),
27
+ dropout=0.05,
28
+ attention_head_dim=64,
29
+ n_blocks=1,
30
+ num_mid_blocks=2,
31
+ num_heads=4,
32
+ act_fn="snake",
33
+ ):
34
+ """
35
+ This decoder requires an input with the same shape of the target. So, if your text content
36
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
37
+ """
38
+ super().__init__()
39
+ channels = tuple(channels)
40
+ self.in_channels = in_channels
41
+ self.out_channels = out_channels
42
+
43
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
44
+ time_embed_dim = channels[0] * 4
45
+ self.time_mlp = TimestepEmbedding(
46
+ in_channels=in_channels,
47
+ time_embed_dim=time_embed_dim,
48
+ act_fn="silu",
49
+ )
50
+ self.down_blocks = nn.ModuleList([])
51
+ self.mid_blocks = nn.ModuleList([])
52
+ self.up_blocks = nn.ModuleList([])
53
+
54
+ output_channel = in_channels
55
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
56
+ input_channel = output_channel
57
+ output_channel = channels[i]
58
+ is_last = i == len(channels) - 1
59
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
60
+ transformer_blocks = nn.ModuleList(
61
+ [
62
+ BasicTransformerBlock(
63
+ dim=output_channel,
64
+ num_attention_heads=num_heads,
65
+ attention_head_dim=attention_head_dim,
66
+ dropout=dropout,
67
+ activation_fn=act_fn,
68
+ )
69
+ for _ in range(n_blocks)
70
+ ]
71
+ )
72
+ downsample = (
73
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
74
+ )
75
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
76
+
77
+ for i in range(num_mid_blocks):
78
+ input_channel = channels[-1]
79
+ out_channels = channels[-1]
80
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
81
+
82
+ transformer_blocks = nn.ModuleList(
83
+ [
84
+ BasicTransformerBlock(
85
+ dim=output_channel,
86
+ num_attention_heads=num_heads,
87
+ attention_head_dim=attention_head_dim,
88
+ dropout=dropout,
89
+ activation_fn=act_fn,
90
+ )
91
+ for _ in range(n_blocks)
92
+ ]
93
+ )
94
+
95
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
96
+
97
+ channels = channels[::-1] + (channels[0],)
98
+ for i in range(len(channels) - 1):
99
+ input_channel = channels[i] * 2
100
+ output_channel = channels[i + 1]
101
+ is_last = i == len(channels) - 2
102
+ resnet = ResnetBlock1D(
103
+ dim=input_channel,
104
+ dim_out=output_channel,
105
+ time_emb_dim=time_embed_dim,
106
+ )
107
+ transformer_blocks = nn.ModuleList(
108
+ [
109
+ BasicTransformerBlock(
110
+ dim=output_channel,
111
+ num_attention_heads=num_heads,
112
+ attention_head_dim=attention_head_dim,
113
+ dropout=dropout,
114
+ activation_fn=act_fn,
115
+ )
116
+ for _ in range(n_blocks)
117
+ ]
118
+ )
119
+ upsample = (
120
+ Upsample1D(output_channel, use_conv_transpose=True)
121
+ if not is_last
122
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
123
+ )
124
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
125
+ self.final_block = Block1D(channels[-1], channels[-1])
126
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
127
+ self.initialize_weights()
128
+
129
+
130
+ def initialize_weights(self):
131
+ for m in self.modules():
132
+ if isinstance(m, nn.Conv1d):
133
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
134
+ if m.bias is not None:
135
+ nn.init.constant_(m.bias, 0)
136
+ elif isinstance(m, nn.GroupNorm):
137
+ nn.init.constant_(m.weight, 1)
138
+ nn.init.constant_(m.bias, 0)
139
+ elif isinstance(m, nn.Linear):
140
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
141
+ if m.bias is not None:
142
+ nn.init.constant_(m.bias, 0)
143
+
144
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
145
+ """Forward pass of the UNet1DConditional model.
146
+
147
+ Args:
148
+ x (torch.Tensor): shape (batch_size, in_channels, time)
149
+ mask (_type_): shape (batch_size, 1, time)
150
+ t (_type_): shape (batch_size)
151
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
152
+ cond (_type_, optional): placeholder for future use. Defaults to None.
153
+
154
+ Raises:
155
+ ValueError: _description_
156
+ ValueError: _description_
157
+
158
+ Returns:
159
+ _type_: _description_
160
+ """
161
+
162
+ t = self.time_embeddings(t)
163
+ t = self.time_mlp(t)
164
+
165
+ x = pack([x, mu], "b * t")[0]
166
+
167
+ if spks is not None:
168
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
169
+ x = pack([x, spks], "b * t")[0]
170
+ if cond is not None:
171
+ x = pack([x, cond], "b * t")[0]
172
+
173
+ hiddens = []
174
+ masks = [mask]
175
+ for resnet, transformer_blocks, downsample in self.down_blocks:
176
+ mask_down = masks[-1]
177
+ x = resnet(x, mask_down, t)
178
+ x = rearrange(x, "b c t -> b t c").contiguous()
179
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
180
+ for transformer_block in transformer_blocks:
181
+ x = transformer_block(
182
+ hidden_states=x,
183
+ attention_mask=attn_mask,
184
+ timestep=t,
185
+ )
186
+ x = rearrange(x, "b t c -> b c t").contiguous()
187
+ hiddens.append(x) # Save hidden states for skip connections
188
+ x = downsample(x * mask_down)
189
+ masks.append(mask_down[:, :, ::2])
190
+ masks = masks[:-1]
191
+ mask_mid = masks[-1]
192
+
193
+ for resnet, transformer_blocks in self.mid_blocks:
194
+ x = resnet(x, mask_mid, t)
195
+ x = rearrange(x, "b c t -> b t c").contiguous()
196
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
197
+ for transformer_block in transformer_blocks:
198
+ x = transformer_block(
199
+ hidden_states=x,
200
+ attention_mask=attn_mask,
201
+ timestep=t,
202
+ )
203
+ x = rearrange(x, "b t c -> b c t").contiguous()
204
+
205
+ for resnet, transformer_blocks, upsample in self.up_blocks:
206
+ mask_up = masks.pop()
207
+ skip = hiddens.pop()
208
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
209
+ x = resnet(x, mask_up, t)
210
+ x = rearrange(x, "b c t -> b t c").contiguous()
211
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
212
+ for transformer_block in transformer_blocks:
213
+ x = transformer_block(
214
+ hidden_states=x,
215
+ attention_mask=attn_mask,
216
+ timestep=t,
217
+ )
218
+ x = rearrange(x, "b t c -> b c t").contiguous()
219
+ x = upsample(x * mask_up)
220
+ x = self.final_block(x, mask_up)
221
+ output = self.final_proj(x * mask_up)
222
+ return output * mask
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ from typing import Dict, Optional
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn import functional as F
19
+ from omegaconf import DictConfig
20
+ from cosyvoice.utils.mask import make_pad_mask
21
+
22
+
23
+ class MaskedDiffWithXvec(torch.nn.Module):
24
+ def __init__(self,
25
+ input_size: int = 512,
26
+ output_size: int = 80,
27
+ spk_embed_dim: int = 192,
28
+ output_type: str = "mel",
29
+ vocab_size: int = 4096,
30
+ input_frame_rate: int = 50,
31
+ only_mask_loss: bool = True,
32
+ encoder: torch.nn.Module = None,
33
+ length_regulator: torch.nn.Module = None,
34
+ decoder: torch.nn.Module = None,
35
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
36
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
37
+ super().__init__()
38
+ self.input_size = input_size
39
+ self.output_size = output_size
40
+ self.decoder_conf = decoder_conf
41
+ self.mel_feat_conf = mel_feat_conf
42
+ self.vocab_size = vocab_size
43
+ self.output_type = output_type
44
+ self.input_frame_rate = input_frame_rate
45
+ logging.info(f"input frame rate={self.input_frame_rate}")
46
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
47
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
48
+ self.encoder = encoder
49
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
50
+ self.decoder = decoder
51
+ self.length_regulator = length_regulator
52
+ self.only_mask_loss = only_mask_loss
53
+
54
+ def forward(
55
+ self,
56
+ batch: dict,
57
+ device: torch.device,
58
+ ) -> Dict[str, Optional[torch.Tensor]]:
59
+ token = batch['speech_token'].to(device)
60
+ token_len = batch['speech_token_len'].to(device)
61
+ feat = batch['speech_feat'].to(device)
62
+ feat_len = batch['speech_feat_len'].to(device)
63
+ embedding = batch['embedding'].to(device)
64
+
65
+ # xvec projection
66
+ embedding = F.normalize(embedding, dim=1)
67
+ embedding = self.spk_embed_affine_layer(embedding)
68
+
69
+ # concat text and prompt_text
70
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
71
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
72
+
73
+ # text encode
74
+ h, h_lengths = self.encoder(token, token_len)
75
+ h = self.encoder_proj(h)
76
+ h, h_lengths = self.length_regulator(h, feat_len)
77
+
78
+ # get conditions
79
+ conds = torch.zeros(feat.shape, device=token.device)
80
+ conds = conds.transpose(1, 2)
81
+
82
+ mask = (~make_pad_mask(feat_len)).to(h)
83
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
84
+ loss, _ = self.decoder.compute_loss(
85
+ feat.transpose(1, 2).contiguous(),
86
+ mask.unsqueeze(1),
87
+ h.transpose(1, 2).contiguous(),
88
+ embedding,
89
+ cond=conds
90
+ )
91
+ return {'loss': loss}
92
+
93
+ @torch.inference_mode()
94
+ def inference(self,
95
+ token,
96
+ token_len,
97
+ prompt_token,
98
+ prompt_token_len,
99
+ prompt_feat,
100
+ prompt_feat_len,
101
+ embedding):
102
+ assert token.shape[0] == 1
103
+ # xvec projection
104
+ embedding = F.normalize(embedding, dim=1)
105
+ embedding = self.spk_embed_affine_layer(embedding)
106
+
107
+ # concat text and prompt_text
108
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
109
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
110
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
111
+
112
+ # text encode
113
+ h, h_lengths = self.encoder(token, token_len)
114
+ h = self.encoder_proj(h)
115
+ feat_len = (token_len / 50 * 22050 / 256).int()
116
+ h, h_lengths = self.length_regulator(h, feat_len)
117
+
118
+ # get conditions
119
+ conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
120
+ if prompt_feat.shape[1] != 0:
121
+ for i, j in enumerate(prompt_feat_len):
122
+ conds[i, :j] = prompt_feat[i]
123
+ conds = conds.transpose(1, 2)
124
+
125
+ mask = (~make_pad_mask(feat_len)).to(h)
126
+ feat = self.decoder(
127
+ mu=h.transpose(1, 2).contiguous(),
128
+ mask=mask.unsqueeze(1),
129
+ spks=embedding,
130
+ cond=conds,
131
+ n_timesteps=10
132
+ )
133
+ if prompt_feat.shape[1] != 0:
134
+ feat = feat[:, :, prompt_feat.shape[1]:]
135
+ return feat
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from matcha.models.components.flow_matching import BASECFM
17
+
18
+ class ConditionalCFM(BASECFM):
19
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
20
+ super().__init__(
21
+ n_feats=in_channels,
22
+ cfm_params=cfm_params,
23
+ n_spks=n_spks,
24
+ spk_emb_dim=spk_emb_dim,
25
+ )
26
+ self.t_scheduler = cfm_params.t_scheduler
27
+ self.training_cfg_rate = cfm_params.training_cfg_rate
28
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
29
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
30
+ # Just change the architecture of the estimator here
31
+ self.estimator = estimator
32
+
33
+ @torch.inference_mode()
34
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
35
+ """Forward diffusion
36
+
37
+ Args:
38
+ mu (torch.Tensor): output of encoder
39
+ shape: (batch_size, n_feats, mel_timesteps)
40
+ mask (torch.Tensor): output_mask
41
+ shape: (batch_size, 1, mel_timesteps)
42
+ n_timesteps (int): number of diffusion steps
43
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
44
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
45
+ shape: (batch_size, spk_emb_dim)
46
+ cond: Not used but kept for future purposes
47
+
48
+ Returns:
49
+ sample: generated mel-spectrogram
50
+ shape: (batch_size, n_feats, mel_timesteps)
51
+ """
52
+ z = torch.randn_like(mu) * temperature
53
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
54
+ if self.t_scheduler == 'cosine':
55
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
56
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
57
+
58
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
59
+ """
60
+ Fixed euler solver for ODEs.
61
+ Args:
62
+ x (torch.Tensor): random noise
63
+ t_span (torch.Tensor): n_timesteps interpolated
64
+ shape: (n_timesteps + 1,)
65
+ mu (torch.Tensor): output of encoder
66
+ shape: (batch_size, n_feats, mel_timesteps)
67
+ mask (torch.Tensor): output_mask
68
+ shape: (batch_size, 1, mel_timesteps)
69
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
70
+ shape: (batch_size, spk_emb_dim)
71
+ cond: Not used but kept for future purposes
72
+ """
73
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
74
+
75
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
76
+ # Or in future might add like a return_all_steps flag
77
+ sol = []
78
+
79
+ for step in range(1, len(t_span)):
80
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
81
+ # Classifier-Free Guidance inference introduced in VoiceBox
82
+ if self.inference_cfg_rate > 0:
83
+ cfg_dphi_dt = self.estimator(
84
+ x, mask,
85
+ torch.zeros_like(mu), t,
86
+ torch.zeros_like(spks) if spks is not None else None,
87
+ torch.zeros_like(cond)
88
+ )
89
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
90
+ self.inference_cfg_rate * cfg_dphi_dt)
91
+ x = x + dt * dphi_dt
92
+ t = t + dt
93
+ sol.append(x)
94
+ if step < len(t_span) - 1:
95
+ dt = t_span[step + 1] - t
96
+
97
+ return sol[-1]
98
+
99
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
100
+ """Computes diffusion loss
101
+
102
+ Args:
103
+ x1 (torch.Tensor): Target
104
+ shape: (batch_size, n_feats, mel_timesteps)
105
+ mask (torch.Tensor): target mask
106
+ shape: (batch_size, 1, mel_timesteps)
107
+ mu (torch.Tensor): output of encoder
108
+ shape: (batch_size, n_feats, mel_timesteps)
109
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
110
+ shape: (batch_size, spk_emb_dim)
111
+
112
+ Returns:
113
+ loss: conditional flow matching loss
114
+ y: conditional flow
115
+ shape: (batch_size, n_feats, mel_timesteps)
116
+ """
117
+ b, _, t = mu.shape
118
+
119
+ # random timestep
120
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
121
+ if self.t_scheduler == 'cosine':
122
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
123
+ # sample noise p(x_0)
124
+ z = torch.randn_like(x1)
125
+
126
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
127
+ u = x1 - (1 - self.sigma_min) * z
128
+
129
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
130
+ if self.training_cfg_rate > 0:
131
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
132
+ mu = mu * cfg_mask.view(-1, 1, 1)
133
+ spks = spks * cfg_mask.view(-1, 1)
134
+ cond = cond * cfg_mask.view(-1, 1, 1)
135
+
136
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
137
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
138
+ return loss, y
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from cosyvoice.utils.mask import make_pad_mask
18
+
19
+
20
+ class InterpolateRegulator(nn.Module):
21
+ def __init__(
22
+ self,
23
+ channels: int,
24
+ sampling_ratios: Tuple,
25
+ out_channels: int = None,
26
+ groups: int = 1,
27
+ ):
28
+ super().__init__()
29
+ self.sampling_ratios = sampling_ratios
30
+ out_channels = out_channels or channels
31
+ model = nn.ModuleList([])
32
+ if len(sampling_ratios) > 0:
33
+ for _ in sampling_ratios:
34
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
35
+ norm = nn.GroupNorm(groups, channels)
36
+ act = nn.Mish()
37
+ model.extend([module, norm, act])
38
+ model.append(
39
+ nn.Conv1d(channels, out_channels, 1, 1)
40
+ )
41
+ self.model = nn.Sequential(*model)
42
+
43
+ def forward(self, x, ylens=None):
44
+ # x in (B, T, D)
45
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
46
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
47
+ out = self.model(x).transpose(1, 2).contiguous()
48
+ olens = ylens
49
+ return out * mask, olens
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from cosyvoice.transformer.activation import Snake
30
+ from cosyvoice.utils.common import get_padding
31
+ from cosyvoice.utils.common import init_weights
32
+
33
+
34
+ """hifigan based generator implementation.
35
+
36
+ This code is modified from https://github.com/jik876/hifi-gan
37
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
38
+ https://github.com/NVIDIA/BigVGAN
39
+
40
+ """
41
+ class ResBlock(torch.nn.Module):
42
+ """Residual block module in HiFiGAN/BigVGAN."""
43
+ def __init__(
44
+ self,
45
+ channels: int = 512,
46
+ kernel_size: int = 3,
47
+ dilations: tp.List[int] = [1, 3, 5],
48
+ ):
49
+ super(ResBlock, self).__init__()
50
+ self.convs1 = nn.ModuleList()
51
+ self.convs2 = nn.ModuleList()
52
+
53
+ for dilation in dilations:
54
+ self.convs1.append(
55
+ weight_norm(
56
+ Conv1d(
57
+ channels,
58
+ channels,
59
+ kernel_size,
60
+ 1,
61
+ dilation=dilation,
62
+ padding=get_padding(kernel_size, dilation)
63
+ )
64
+ )
65
+ )
66
+ self.convs2.append(
67
+ weight_norm(
68
+ Conv1d(
69
+ channels,
70
+ channels,
71
+ kernel_size,
72
+ 1,
73
+ dilation=1,
74
+ padding=get_padding(kernel_size, 1)
75
+ )
76
+ )
77
+ )
78
+ self.convs1.apply(init_weights)
79
+ self.convs2.apply(init_weights)
80
+ self.activations1 = nn.ModuleList([
81
+ Snake(channels, alpha_logscale=False)
82
+ for _ in range(len(self.convs1))
83
+ ])
84
+ self.activations2 = nn.ModuleList([
85
+ Snake(channels, alpha_logscale=False)
86
+ for _ in range(len(self.convs2))
87
+ ])
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ for idx in range(len(self.convs1)):
91
+ xt = self.activations1[idx](x)
92
+ xt = self.convs1[idx](xt)
93
+ xt = self.activations2[idx](xt)
94
+ xt = self.convs2[idx](xt)
95
+ x = xt + x
96
+ return x
97
+
98
+ def remove_weight_norm(self):
99
+ for idx in range(len(self.convs1)):
100
+ remove_weight_norm(self.convs1[idx])
101
+ remove_weight_norm(self.convs2[idx])
102
+
103
+ class SineGen(torch.nn.Module):
104
+ """ Definition of sine generator
105
+ SineGen(samp_rate, harmonic_num = 0,
106
+ sine_amp = 0.1, noise_std = 0.003,
107
+ voiced_threshold = 0,
108
+ flag_for_pulse=False)
109
+ samp_rate: sampling rate in Hz
110
+ harmonic_num: number of harmonic overtones (default 0)
111
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
112
+ noise_std: std of Gaussian noise (default 0.003)
113
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
114
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
115
+ Note: when flag_for_pulse is True, the first time step of a voiced
116
+ segment is always sin(np.pi) or cos(0)
117
+ """
118
+
119
+ def __init__(self, samp_rate, harmonic_num=0,
120
+ sine_amp=0.1, noise_std=0.003,
121
+ voiced_threshold=0):
122
+ super(SineGen, self).__init__()
123
+ self.sine_amp = sine_amp
124
+ self.noise_std = noise_std
125
+ self.harmonic_num = harmonic_num
126
+ self.sampling_rate = samp_rate
127
+ self.voiced_threshold = voiced_threshold
128
+
129
+ def _f02uv(self, f0):
130
+ # generate uv signal
131
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
132
+ return uv
133
+
134
+ @torch.no_grad()
135
+ def forward(self, f0):
136
+ """
137
+ :param f0: [B, 1, sample_len], Hz
138
+ :return: [B, 1, sample_len]
139
+ """
140
+
141
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
142
+ for i in range(self.harmonic_num + 1):
143
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
144
+
145
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
146
+ u_dist = Uniform(low=-np.pi, high=np.pi)
147
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
148
+ phase_vec[:, 0, :] = 0
149
+
150
+ # generate sine waveforms
151
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
152
+
153
+ # generate uv signal
154
+ uv = self._f02uv(f0)
155
+
156
+ # noise: for unvoiced should be similar to sine_amp
157
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
158
+ # . for voiced regions is self.noise_std
159
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
160
+ noise = noise_amp * torch.randn_like(sine_waves)
161
+
162
+ # first: set the unvoiced part to 0 by uv
163
+ # then: additive noise
164
+ sine_waves = sine_waves * uv + noise
165
+ return sine_waves, uv, noise
166
+
167
+
168
+ class SourceModuleHnNSF(torch.nn.Module):
169
+ """ SourceModule for hn-nsf
170
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
171
+ add_noise_std=0.003, voiced_threshod=0)
172
+ sampling_rate: sampling_rate in Hz
173
+ harmonic_num: number of harmonic above F0 (default: 0)
174
+ sine_amp: amplitude of sine source signal (default: 0.1)
175
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
176
+ note that amplitude of noise in unvoiced is decided
177
+ by sine_amp
178
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
179
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
180
+ F0_sampled (batchsize, length, 1)
181
+ Sine_source (batchsize, length, 1)
182
+ noise_source (batchsize, length 1)
183
+ uv (batchsize, length, 1)
184
+ """
185
+
186
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
187
+ add_noise_std=0.003, voiced_threshod=0):
188
+ super(SourceModuleHnNSF, self).__init__()
189
+
190
+ self.sine_amp = sine_amp
191
+ self.noise_std = add_noise_std
192
+
193
+ # to produce sine waveforms
194
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
195
+ sine_amp, add_noise_std, voiced_threshod)
196
+
197
+ # to merge source harmonics into a single excitation
198
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
199
+ self.l_tanh = torch.nn.Tanh()
200
+
201
+ def forward(self, x):
202
+ """
203
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
204
+ F0_sampled (batchsize, length, 1)
205
+ Sine_source (batchsize, length, 1)
206
+ noise_source (batchsize, length 1)
207
+ """
208
+ # source for harmonic branch
209
+ with torch.no_grad():
210
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
211
+ sine_wavs = sine_wavs.transpose(1, 2)
212
+ uv = uv.transpose(1, 2)
213
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
214
+
215
+ # source for noise branch, in the same shape as uv
216
+ noise = torch.randn_like(uv) * self.sine_amp / 3
217
+ return sine_merge, noise, uv
218
+
219
+
220
+ class HiFTGenerator(nn.Module):
221
+ """
222
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
223
+ https://arxiv.org/abs/2309.09493
224
+ """
225
+ def __init__(
226
+ self,
227
+ in_channels: int = 80,
228
+ base_channels: int = 512,
229
+ nb_harmonics: int = 8,
230
+ sampling_rate: int = 22050,
231
+ nsf_alpha: float = 0.1,
232
+ nsf_sigma: float = 0.003,
233
+ nsf_voiced_threshold: float = 10,
234
+ upsample_rates: tp.List[int] = [8, 8],
235
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
236
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
237
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
238
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
239
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
240
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
241
+ lrelu_slope: float = 0.1,
242
+ audio_limit: float = 0.99,
243
+ f0_predictor: torch.nn.Module = None,
244
+ ):
245
+ super(HiFTGenerator, self).__init__()
246
+
247
+ self.out_channels = 1
248
+ self.nb_harmonics = nb_harmonics
249
+ self.sampling_rate = sampling_rate
250
+ self.istft_params = istft_params
251
+ self.lrelu_slope = lrelu_slope
252
+ self.audio_limit = audio_limit
253
+
254
+ self.num_kernels = len(resblock_kernel_sizes)
255
+ self.num_upsamples = len(upsample_rates)
256
+ self.m_source = SourceModuleHnNSF(
257
+ sampling_rate=sampling_rate,
258
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
259
+ harmonic_num=nb_harmonics,
260
+ sine_amp=nsf_alpha,
261
+ add_noise_std=nsf_sigma,
262
+ voiced_threshod=nsf_voiced_threshold)
263
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
264
+
265
+ self.conv_pre = weight_norm(
266
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
267
+ )
268
+
269
+ # Up
270
+ self.ups = nn.ModuleList()
271
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
272
+ self.ups.append(
273
+ weight_norm(
274
+ ConvTranspose1d(
275
+ base_channels // (2**i),
276
+ base_channels // (2**(i + 1)),
277
+ k,
278
+ u,
279
+ padding=(k - u) // 2,
280
+ )
281
+ )
282
+ )
283
+
284
+ # Down
285
+ self.source_downs = nn.ModuleList()
286
+ self.source_resblocks = nn.ModuleList()
287
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
288
+ downsample_cum_rates = np.cumprod(downsample_rates)
289
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
290
+ source_resblock_dilation_sizes)):
291
+ if u == 1:
292
+ self.source_downs.append(
293
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
294
+ )
295
+ else:
296
+ self.source_downs.append(
297
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
298
+ )
299
+
300
+ self.source_resblocks.append(
301
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
302
+ )
303
+
304
+ self.resblocks = nn.ModuleList()
305
+ for i in range(len(self.ups)):
306
+ ch = base_channels // (2**(i + 1))
307
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
308
+ self.resblocks.append(ResBlock(ch, k, d))
309
+
310
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
311
+ self.ups.apply(init_weights)
312
+ self.conv_post.apply(init_weights)
313
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
314
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
315
+ self.f0_predictor = f0_predictor
316
+
317
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
318
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
319
+
320
+ har_source, _, _ = self.m_source(f0)
321
+ return har_source.transpose(1, 2)
322
+
323
+ def _stft(self, x):
324
+ spec = torch.stft(
325
+ x,
326
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
327
+ return_complex=True)
328
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
329
+ return spec[..., 0], spec[..., 1]
330
+
331
+ def _istft(self, magnitude, phase):
332
+ magnitude = torch.clip(magnitude, max=1e2)
333
+ real = magnitude * torch.cos(phase)
334
+ img = magnitude * torch.sin(phase)
335
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
336
+ return inverse_transform
337
+
338
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
339
+ f0 = self.f0_predictor(x)
340
+ s = self._f02source(f0)
341
+
342
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
343
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
344
+
345
+ x = self.conv_pre(x)
346
+ for i in range(self.num_upsamples):
347
+ x = F.leaky_relu(x, self.lrelu_slope)
348
+ x = self.ups[i](x)
349
+
350
+ if i == self.num_upsamples - 1:
351
+ x = self.reflection_pad(x)
352
+
353
+ # fusion
354
+ si = self.source_downs[i](s_stft)
355
+ si = self.source_resblocks[i](si)
356
+ x = x + si
357
+
358
+ xs = None
359
+ for j in range(self.num_kernels):
360
+ if xs is None:
361
+ xs = self.resblocks[i * self.num_kernels + j](x)
362
+ else:
363
+ xs += self.resblocks[i * self.num_kernels + j](x)
364
+ x = xs / self.num_kernels
365
+
366
+ x = F.leaky_relu(x)
367
+ x = self.conv_post(x)
368
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
369
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
370
+
371
+ x = self._istft(magnitude, phase)
372
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
373
+ return x
374
+
375
+ def remove_weight_norm(self):
376
+ print('Removing weight norm...')
377
+ for l in self.ups:
378
+ remove_weight_norm(l)
379
+ for l in self.resblocks:
380
+ l.remove_weight_norm()
381
+ remove_weight_norm(self.conv_pre)
382
+ remove_weight_norm(self.conv_post)
383
+ self.source_module.remove_weight_norm()
384
+ for l in self.source_downs:
385
+ remove_weight_norm(l)
386
+ for l in self.source_resblocks:
387
+ l.remove_weight_norm()
388
+
389
+ @torch.inference_mode()
390
+ def inference(self, mel: torch.Tensor) -> torch.Tensor:
391
+ return self.forward(x=mel)
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Union
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
19
+ from cosyvoice.utils.common import IGNORE_ID
20
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
21
+ from cosyvoice.utils.common import th_accuracy
22
+
23
+
24
+ class TransformerLM(torch.nn.Module):
25
+ def __init__(
26
+ self,
27
+ text_encoder_input_size: int,
28
+ llm_input_size: int,
29
+ llm_output_size: int,
30
+ text_token_size: int,
31
+ speech_token_size: int,
32
+ text_encoder: torch.nn.Module,
33
+ llm: torch.nn.Module,
34
+ length_normalized_loss: bool = True,
35
+ lsm_weight: float = 0.0,
36
+ spk_embed_dim: int = 192,
37
+ ):
38
+ super().__init__()
39
+ self.llm_input_size = llm_input_size
40
+ self.speech_token_size = speech_token_size
41
+ # 1. build text token inputs related modules
42
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
43
+ self.text_encoder = text_encoder
44
+ self.text_encoder_affine_layer = nn.Linear(
45
+ self.text_encoder.output_size(),
46
+ llm_input_size
47
+ )
48
+
49
+ # 2. build speech token language model related modules
50
+ self.sos_eos = 0
51
+ self.task_id = 1
52
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
53
+ self.llm = llm
54
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
55
+ self.criterion_ce = LabelSmoothingLoss(
56
+ size=speech_token_size + 1,
57
+ padding_idx=IGNORE_ID,
58
+ smoothing=lsm_weight,
59
+ normalize_length=length_normalized_loss,
60
+ )
61
+
62
+ # 3. [Optional] build speech token related modules
63
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
64
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
65
+
66
+ def encode(
67
+ self,
68
+ text: torch.Tensor,
69
+ text_lengths: torch.Tensor,
70
+ ):
71
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
72
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
73
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
74
+ return encoder_out, encoder_out_lens
75
+
76
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
77
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
78
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
79
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
80
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
81
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
82
+ return lm_input, lm_input_len
83
+
84
+ def forward(
85
+ self,
86
+ batch: dict,
87
+ device: torch.device,
88
+ ) -> Dict[str, Optional[torch.Tensor]]:
89
+ """
90
+ Args:
91
+ text: (B, L, D)
92
+ text_lengths: (B,)
93
+ audio: (B, T, N) or (B, T)
94
+ audio_lengths: (B,)
95
+ """
96
+ text_token = batch['text_token'].to(device)
97
+ text_token_len = batch['text_token_len'].to(device)
98
+ speech_token = batch['speech_token'].to(device)
99
+ speech_token_len = batch['speech_token_len'].to(device)
100
+ embedding = batch['embedding'].to(device)
101
+
102
+ # 1. prepare llm_target
103
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
104
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
105
+
106
+ # 1. encode text_token
107
+ text_token = self.text_embedding(text_token)
108
+ text_token, text_token_len = self.encode(text_token, text_token_len)
109
+
110
+ # 2. embedding projection
111
+ embedding = F.normalize(embedding, dim=1)
112
+ embedding = self.spk_embed_affine_layer(embedding)
113
+ embedding = embedding.unsqueeze(1)
114
+
115
+ # 3. eos and task_id
116
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
117
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
118
+
119
+ # 4. encode speech_token
120
+ speech_token = self.speech_embedding(speech_token)
121
+
122
+ # 5. unpad and pad
123
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
124
+
125
+ # 6. run lm forward
126
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
127
+ logits = self.llm_decoder(lm_output)
128
+ loss = self.criterion_ce(logits, lm_target)
129
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
130
+ return {'loss': loss, 'acc': acc}
131
+
132
+ def sampling_ids(
133
+ self,
134
+ weighted_scores: torch.Tensor,
135
+ sampling: Union[bool, int, float] = True,
136
+ beam_size: int = 1,
137
+ ignore_eos: bool = True,
138
+ ):
139
+ while True:
140
+ prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
141
+ top_ids = prob.multinomial(beam_size, replacement=True)
142
+ top_ids = indices[top_ids]
143
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
144
+ break
145
+ return top_ids
146
+
147
+ @torch.inference_mode()
148
+ def inference(
149
+ self,
150
+ text: torch.Tensor,
151
+ text_len: torch.Tensor,
152
+ prompt_text: torch.Tensor,
153
+ prompt_text_len: torch.Tensor,
154
+ prompt_speech_token: torch.Tensor,
155
+ prompt_speech_token_len: torch.Tensor,
156
+ embedding: torch.Tensor,
157
+ beam_size: int = 1,
158
+ sampling: int = 25,
159
+ max_token_text_ratio: float = 20,
160
+ min_token_text_ratio: float = 2,
161
+ ) -> torch.Tensor:
162
+ device = text.device
163
+ text = torch.concat([prompt_text, text], dim=1)
164
+ text_len += prompt_text_len
165
+ text = self.text_embedding(text)
166
+
167
+ # 1. encode text
168
+ text, text_len = self.encode(text, text_len)
169
+
170
+ # 2. encode embedding
171
+ if embedding.shape[0] != 0:
172
+ embedding = F.normalize(embedding, dim=1)
173
+ embedding = self.spk_embed_affine_layer(embedding)
174
+ embedding = embedding.unsqueeze(dim=1)
175
+ else:
176
+ embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
177
+
178
+ # 3. concat llm_input
179
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
180
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
181
+ if prompt_speech_token_len != 0:
182
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
183
+ else:
184
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
185
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
186
+
187
+ # 4. cal min/max_length
188
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
189
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
190
+
191
+ # 5. step by step decode
192
+ out_tokens = []
193
+ offset = 0
194
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
195
+ for i in range(max_len):
196
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
197
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
198
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
199
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
200
+ if top_ids == self.speech_token_size:
201
+ break
202
+ out_tokens.append(top_ids)
203
+ offset += lm_input.size(1)
204
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
205
+
206
+ return torch.tensor([out_tokens], dtype=torch.int64, device=device)
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
175
+ # and we will always do splitting and
176
+ # concatnation(this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x):
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
237
+ x_padded = torch.cat([zero_pad, x], dim=-1)
238
+
239
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
240
+ x = x_padded[:, :, 1:].view_as(x)[
241
+ :, :, :, : x.size(-1) // 2 + 1
242
+ ] # only keep the positions from 0 to time2
243
+ return x
244
+
245
+ def forward(
246
+ self,
247
+ query: torch.Tensor,
248
+ key: torch.Tensor,
249
+ value: torch.Tensor,
250
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
251
+ pos_emb: torch.Tensor = torch.empty(0),
252
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
253
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
254
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
255
+ Args:
256
+ query (torch.Tensor): Query tensor (#batch, time1, size).
257
+ key (torch.Tensor): Key tensor (#batch, time2, size).
258
+ value (torch.Tensor): Value tensor (#batch, time2, size).
259
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
260
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
261
+ pos_emb (torch.Tensor): Positional embedding tensor
262
+ (#batch, time2, size).
263
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
264
+ where `cache_t == chunk_size * num_decoding_left_chunks`
265
+ and `head * d_k == size`
266
+ Returns:
267
+ torch.Tensor: Output tensor (#batch, time1, d_model).
268
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
269
+ where `cache_t == chunk_size * num_decoding_left_chunks`
270
+ and `head * d_k == size`
271
+ """
272
+ q, k, v = self.forward_qkv(query, key, value)
273
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
274
+
275
+ # NOTE(xcsong):
276
+ # when export onnx model, for 1st chunk, we feed
277
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
278
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
279
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
280
+ # and we will always do splitting and
281
+ # concatnation(this will simplify onnx export). Note that
282
+ # it's OK to concat & split zero-shaped tensors(see code below).
283
+ # when export jit model, for 1st chunk, we always feed
284
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
285
+ # >>> a = torch.ones((1, 2, 0, 4))
286
+ # >>> b = torch.ones((1, 2, 3, 4))
287
+ # >>> c = torch.cat((a, b), dim=2)
288
+ # >>> torch.equal(b, c) # True
289
+ # >>> d = torch.split(a, 2, dim=-1)
290
+ # >>> torch.equal(d[0], d[1]) # True
291
+ if cache.size(0) > 0:
292
+ key_cache, value_cache = torch.split(cache,
293
+ cache.size(-1) // 2,
294
+ dim=-1)
295
+ k = torch.cat([key_cache, k], dim=2)
296
+ v = torch.cat([value_cache, v], dim=2)
297
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
298
+ # non-trivial to calculate `next_cache_start` here.
299
+ new_cache = torch.cat((k, v), dim=-1)
300
+
301
+ n_batch_pos = pos_emb.size(0)
302
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
303
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
304
+
305
+ # (batch, head, time1, d_k)
306
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
307
+ # (batch, head, time1, d_k)
308
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
309
+
310
+ # compute attention score
311
+ # first compute matrix a and matrix c
312
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
313
+ # (batch, head, time1, time2)
314
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
315
+
316
+ # compute matrix b and matrix d
317
+ # (batch, head, time1, time2)
318
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
319
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
320
+ if matrix_ac.shape != matrix_bd.shape:
321
+ matrix_bd = self.rel_shift(matrix_bd)
322
+
323
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
324
+ self.d_k) # (batch, head, time1, time2)
325
+
326
+ return self.forward_attention(v, scores, mask), new_cache
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (int): Whether use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for none causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) meas fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
25
+ from cosyvoice.utils.class_utils import (
26
+ COSYVOICE_EMB_CLASSES,
27
+ COSYVOICE_ATTENTION_CLASSES,
28
+ COSYVOICE_ACTIVATION_CLASSES,
29
+ )
30
+ from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
31
+
32
+
33
+ class TransformerDecoder(torch.nn.Module):
34
+ """Base class of Transfomer decoder module.
35
+ Args:
36
+ vocab_size: output dim
37
+ encoder_output_size: dimension of attention
38
+ attention_heads: the number of heads of multi head attention
39
+ linear_units: the hidden units number of position-wise feedforward
40
+ num_blocks: the number of decoder blocks
41
+ dropout_rate: dropout rate
42
+ self_attention_dropout_rate: dropout rate for attention
43
+ input_layer: input layer type
44
+ use_output_layer: whether to use output layer
45
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
46
+ normalize_before:
47
+ True: use layer_norm before each sub-block of a layer.
48
+ False: use layer_norm after each sub-block of a layer.
49
+ src_attention: if false, encoder-decoder cross attention is not
50
+ applied, such as CIF model
51
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
52
+ gradient_checkpointing: rerunning a forward-pass segment for each
53
+ checkpointed segment during backward.
54
+ tie_word_embedding: Tie or clone module weights depending of whether we are
55
+ using TorchScript or not
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_size: int,
61
+ encoder_output_size: int,
62
+ attention_heads: int = 4,
63
+ linear_units: int = 2048,
64
+ num_blocks: int = 6,
65
+ dropout_rate: float = 0.1,
66
+ positional_dropout_rate: float = 0.1,
67
+ self_attention_dropout_rate: float = 0.0,
68
+ src_attention_dropout_rate: float = 0.0,
69
+ input_layer: str = "embed",
70
+ use_output_layer: bool = True,
71
+ normalize_before: bool = True,
72
+ src_attention: bool = True,
73
+ key_bias: bool = True,
74
+ activation_type: str = "relu",
75
+ gradient_checkpointing: bool = False,
76
+ tie_word_embedding: bool = False,
77
+ ):
78
+ super().__init__()
79
+ attention_dim = encoder_output_size
80
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
81
+
82
+ self.embed = torch.nn.Sequential(
83
+ torch.nn.Identity() if input_layer == "no_pos" else
84
+ torch.nn.Embedding(vocab_size, attention_dim),
85
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
86
+ positional_dropout_rate),
87
+ )
88
+
89
+ self.normalize_before = normalize_before
90
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
91
+ self.use_output_layer = use_output_layer
92
+ if use_output_layer:
93
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
94
+ else:
95
+ self.output_layer = torch.nn.Identity()
96
+ self.num_blocks = num_blocks
97
+ self.decoders = torch.nn.ModuleList([
98
+ DecoderLayer(
99
+ attention_dim,
100
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
101
+ attention_heads, attention_dim,
102
+ self_attention_dropout_rate, key_bias),
103
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
104
+ attention_heads, attention_dim, src_attention_dropout_rate,
105
+ key_bias) if src_attention else None,
106
+ PositionwiseFeedForward(attention_dim, linear_units,
107
+ dropout_rate, activation),
108
+ dropout_rate,
109
+ normalize_before,
110
+ ) for _ in range(self.num_blocks)
111
+ ])
112
+
113
+ self.gradient_checkpointing = gradient_checkpointing
114
+ self.tie_word_embedding = tie_word_embedding
115
+
116
+ def forward(
117
+ self,
118
+ memory: torch.Tensor,
119
+ memory_mask: torch.Tensor,
120
+ ys_in_pad: torch.Tensor,
121
+ ys_in_lens: torch.Tensor,
122
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
123
+ reverse_weight: float = 0.0,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
125
+ """Forward decoder.
126
+ Args:
127
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
128
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
129
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
130
+ ys_in_lens: input lengths of this batch (batch)
131
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
132
+ with bidirectional decoder
133
+ reverse_weight: not used in transformer decoder, in order to unify
134
+ api with bidirectional decode
135
+ Returns:
136
+ (tuple): tuple containing:
137
+ x: decoded token score before softmax (batch, maxlen_out,
138
+ vocab_size) if use_output_layer is True,
139
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
140
+ olens: (batch, )
141
+ NOTE(xcsong):
142
+ We pass the `__call__` method of the modules instead of `forward` to the
143
+ checkpointing API because `__call__` attaches all the hooks of the module.
144
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
145
+ """
146
+ tgt = ys_in_pad
147
+ maxlen = tgt.size(1)
148
+ # tgt_mask: (B, 1, L)
149
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
150
+ tgt_mask = tgt_mask.to(tgt.device)
151
+ # m: (1, L, L)
152
+ m = subsequent_mask(tgt_mask.size(-1),
153
+ device=tgt_mask.device).unsqueeze(0)
154
+ # tgt_mask: (B, L, L)
155
+ tgt_mask = tgt_mask & m
156
+ x, _ = self.embed(tgt)
157
+ if self.gradient_checkpointing and self.training:
158
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
159
+ memory_mask)
160
+ else:
161
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
162
+ if self.normalize_before:
163
+ x = self.after_norm(x)
164
+ if self.use_output_layer:
165
+ x = self.output_layer(x)
166
+ olens = tgt_mask.sum(1)
167
+ return x, torch.tensor(0.0), olens
168
+
169
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
170
+ memory: torch.Tensor,
171
+ memory_mask: torch.Tensor) -> torch.Tensor:
172
+ for layer in self.decoders:
173
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
174
+ memory_mask)
175
+ return x
176
+
177
+ @torch.jit.ignore(drop=True)
178
+ def forward_layers_checkpointed(self, x: torch.Tensor,
179
+ tgt_mask: torch.Tensor,
180
+ memory: torch.Tensor,
181
+ memory_mask: torch.Tensor) -> torch.Tensor:
182
+ for layer in self.decoders:
183
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
184
+ layer.__call__, x, tgt_mask, memory, memory_mask)
185
+ return x
186
+
187
+ def forward_one_step(
188
+ self,
189
+ memory: torch.Tensor,
190
+ memory_mask: torch.Tensor,
191
+ tgt: torch.Tensor,
192
+ tgt_mask: torch.Tensor,
193
+ cache: Optional[List[torch.Tensor]] = None,
194
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
195
+ """Forward one step.
196
+ This is only used for decoding.
197
+ Args:
198
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
199
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
200
+ tgt: input token ids, int64 (batch, maxlen_out)
201
+ tgt_mask: input token mask, (batch, maxlen_out)
202
+ dtype=torch.uint8 in PyTorch 1.2-
203
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
204
+ cache: cached output list of (batch, max_time_out-1, size)
205
+ Returns:
206
+ y, cache: NN output value and cache per `self.decoders`.
207
+ y.shape` is (batch, maxlen_out, token)
208
+ """
209
+ x, _ = self.embed(tgt)
210
+ new_cache = []
211
+ for i, decoder in enumerate(self.decoders):
212
+ if cache is None:
213
+ c = None
214
+ else:
215
+ c = cache[i]
216
+ x, tgt_mask, memory, memory_mask = decoder(x,
217
+ tgt_mask,
218
+ memory,
219
+ memory_mask,
220
+ cache=c)
221
+ new_cache.append(x)
222
+ if self.normalize_before:
223
+ y = self.after_norm(x[:, -1])
224
+ else:
225
+ y = x[:, -1]
226
+ if self.use_output_layer:
227
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
228
+ return y, new_cache
229
+
230
+ def tie_or_clone_weights(self, jit_mode: bool = True):
231
+ """Tie or clone module weights (between word_emb and output_layer)
232
+ depending of whether we are using TorchScript or not"""
233
+ if not self.use_output_layer:
234
+ return
235
+ if jit_mode:
236
+ logging.info("clone emb.weight to output.weight")
237
+ self.output_layer.weight = torch.nn.Parameter(
238
+ self.embed[0].weight.clone())
239
+ else:
240
+ logging.info("tie emb.weight with output.weight")
241
+ self.output_layer.weight = self.embed[0].weight
242
+
243
+ if getattr(self.output_layer, "bias", None) is not None:
244
+ self.output_layer.bias.data = torch.nn.functional.pad(
245
+ self.output_layer.bias.data,
246
+ (
247
+ 0,
248
+ self.output_layer.weight.shape[0] -
249
+ self.output_layer.bias.shape[0],
250
+ ),
251
+ "constant",
252
+ 0,
253
+ )
254
+
255
+
256
+ class BiTransformerDecoder(torch.nn.Module):
257
+ """Base class of Transfomer decoder module.
258
+ Args:
259
+ vocab_size: output dim
260
+ encoder_output_size: dimension of attention
261
+ attention_heads: the number of heads of multi head attention
262
+ linear_units: the hidden units number of position-wise feedforward
263
+ num_blocks: the number of decoder blocks
264
+ r_num_blocks: the number of right to left decoder blocks
265
+ dropout_rate: dropout rate
266
+ self_attention_dropout_rate: dropout rate for attention
267
+ input_layer: input layer type
268
+ use_output_layer: whether to use output layer
269
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
270
+ normalize_before:
271
+ True: use layer_norm before each sub-block of a layer.
272
+ False: use layer_norm after each sub-block of a layer.
273
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ vocab_size: int,
279
+ encoder_output_size: int,
280
+ attention_heads: int = 4,
281
+ linear_units: int = 2048,
282
+ num_blocks: int = 6,
283
+ r_num_blocks: int = 0,
284
+ dropout_rate: float = 0.1,
285
+ positional_dropout_rate: float = 0.1,
286
+ self_attention_dropout_rate: float = 0.0,
287
+ src_attention_dropout_rate: float = 0.0,
288
+ input_layer: str = "embed",
289
+ use_output_layer: bool = True,
290
+ normalize_before: bool = True,
291
+ key_bias: bool = True,
292
+ gradient_checkpointing: bool = False,
293
+ tie_word_embedding: bool = False,
294
+ ):
295
+
296
+ super().__init__()
297
+ self.tie_word_embedding = tie_word_embedding
298
+ self.left_decoder = TransformerDecoder(
299
+ vocab_size,
300
+ encoder_output_size,
301
+ attention_heads,
302
+ linear_units,
303
+ num_blocks,
304
+ dropout_rate,
305
+ positional_dropout_rate,
306
+ self_attention_dropout_rate,
307
+ src_attention_dropout_rate,
308
+ input_layer,
309
+ use_output_layer,
310
+ normalize_before,
311
+ key_bias=key_bias,
312
+ gradient_checkpointing=gradient_checkpointing,
313
+ tie_word_embedding=tie_word_embedding)
314
+
315
+ self.right_decoder = TransformerDecoder(
316
+ vocab_size,
317
+ encoder_output_size,
318
+ attention_heads,
319
+ linear_units,
320
+ r_num_blocks,
321
+ dropout_rate,
322
+ positional_dropout_rate,
323
+ self_attention_dropout_rate,
324
+ src_attention_dropout_rate,
325
+ input_layer,
326
+ use_output_layer,
327
+ normalize_before,
328
+ key_bias=key_bias,
329
+ gradient_checkpointing=gradient_checkpointing,
330
+ tie_word_embedding=tie_word_embedding)
331
+
332
+ def forward(
333
+ self,
334
+ memory: torch.Tensor,
335
+ memory_mask: torch.Tensor,
336
+ ys_in_pad: torch.Tensor,
337
+ ys_in_lens: torch.Tensor,
338
+ r_ys_in_pad: torch.Tensor,
339
+ reverse_weight: float = 0.0,
340
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
341
+ """Forward decoder.
342
+ Args:
343
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
344
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
345
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
346
+ ys_in_lens: input lengths of this batch (batch)
347
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
348
+ used for right to left decoder
349
+ reverse_weight: used for right to left decoder
350
+ Returns:
351
+ (tuple): tuple containing:
352
+ x: decoded token score before softmax (batch, maxlen_out,
353
+ vocab_size) if use_output_layer is True,
354
+ r_x: x: decoded token score (right to left decoder)
355
+ before softmax (batch, maxlen_out, vocab_size)
356
+ if use_output_layer is True,
357
+ olens: (batch, )
358
+ """
359
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
360
+ ys_in_lens)
361
+ r_x = torch.tensor(0.0)
362
+ if reverse_weight > 0.0:
363
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
364
+ r_ys_in_pad, ys_in_lens)
365
+ return l_x, r_x, olens
366
+
367
+ def forward_one_step(
368
+ self,
369
+ memory: torch.Tensor,
370
+ memory_mask: torch.Tensor,
371
+ tgt: torch.Tensor,
372
+ tgt_mask: torch.Tensor,
373
+ cache: Optional[List[torch.Tensor]] = None,
374
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
375
+ """Forward one step.
376
+ This is only used for decoding.
377
+ Args:
378
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
379
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
380
+ tgt: input token ids, int64 (batch, maxlen_out)
381
+ tgt_mask: input token mask, (batch, maxlen_out)
382
+ dtype=torch.uint8 in PyTorch 1.2-
383
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
384
+ cache: cached output list of (batch, max_time_out-1, size)
385
+ Returns:
386
+ y, cache: NN output value and cache per `self.decoders`.
387
+ y.shape` is (batch, maxlen_out, token)
388
+ """
389
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
390
+ tgt_mask, cache)
391
+
392
+ def tie_or_clone_weights(self, jit_mode: bool = True):
393
+ """Tie or clone module weights (between word_emb and output_layer)
394
+ depending of whether we are using TorchScript or not"""
395
+ self.left_decoder.tie_or_clone_weights(jit_mode)
396
+ self.right_decoder.tie_or_clone_weights(jit_mode)
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class DecoderLayer(nn.Module):
23
+ """Single decoder layer module.
24
+
25
+ Args:
26
+ size (int): Input dimension.
27
+ self_attn (torch.nn.Module): Self-attention module instance.
28
+ `MultiHeadedAttention` instance can be used as the argument.
29
+ src_attn (torch.nn.Module): Inter-attention module instance.
30
+ `MultiHeadedAttention` instance can be used as the argument.
31
+ If `None` is passed, Inter-attention is not used, such as
32
+ CIF, GPT, and other decoder only model.
33
+ feed_forward (torch.nn.Module): Feed-forward module instance.
34
+ `PositionwiseFeedForward` instance can be used as the argument.
35
+ dropout_rate (float): Dropout rate.
36
+ normalize_before (bool):
37
+ True: use layer_norm before each sub-block.
38
+ False: to use layer_norm after each sub-block.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ self_attn: nn.Module,
45
+ src_attn: Optional[nn.Module],
46
+ feed_forward: nn.Module,
47
+ dropout_rate: float,
48
+ normalize_before: bool = True,
49
+ ):
50
+ """Construct an DecoderLayer object."""
51
+ super().__init__()
52
+ self.size = size
53
+ self.self_attn = self_attn
54
+ self.src_attn = src_attn
55
+ self.feed_forward = feed_forward
56
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
+ self.dropout = nn.Dropout(dropout_rate)
60
+ self.normalize_before = normalize_before
61
+
62
+ def forward(
63
+ self,
64
+ tgt: torch.Tensor,
65
+ tgt_mask: torch.Tensor,
66
+ memory: torch.Tensor,
67
+ memory_mask: torch.Tensor,
68
+ cache: Optional[torch.Tensor] = None
69
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """Compute decoded features.
71
+
72
+ Args:
73
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
+ tgt_mask (torch.Tensor): Mask for input tensor
75
+ (#batch, maxlen_out).
76
+ memory (torch.Tensor): Encoded memory
77
+ (#batch, maxlen_in, size).
78
+ memory_mask (torch.Tensor): Encoded memory mask
79
+ (#batch, maxlen_in).
80
+ cache (torch.Tensor): cached tensors.
81
+ (#batch, maxlen_out - 1, size).
82
+
83
+ Returns:
84
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
+
89
+ """
90
+ residual = tgt
91
+ if self.normalize_before:
92
+ tgt = self.norm1(tgt)
93
+
94
+ if cache is None:
95
+ tgt_q = tgt
96
+ tgt_q_mask = tgt_mask
97
+ else:
98
+ # compute only the last frame query keeping dim: max_time_out -> 1
99
+ assert cache.shape == (
100
+ tgt.shape[0],
101
+ tgt.shape[1] - 1,
102
+ self.size,
103
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
+ tgt_q = tgt[:, -1:, :]
105
+ residual = residual[:, -1:, :]
106
+ tgt_q_mask = tgt_mask[:, -1:, :]
107
+
108
+ x = residual + self.dropout(
109
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110
+ if not self.normalize_before:
111
+ x = self.norm1(x)
112
+
113
+ if self.src_attn is not None:
114
+ residual = x
115
+ if self.normalize_before:
116
+ x = self.norm2(x)
117
+ x = residual + self.dropout(
118
+ self.src_attn(x, memory, memory, memory_mask)[0])
119
+ if not self.normalize_before:
120
+ x = self.norm2(x)
121
+
122
+ residual = x
123
+ if self.normalize_before:
124
+ x = self.norm3(x)
125
+ x = residual + self.dropout(self.feed_forward(x))
126
+ if not self.normalize_before:
127
+ x = self.norm3(x)
128
+
129
+ if cache is not None:
130
+ x = torch.cat([cache, x], dim=1)
131
+
132
+ return x, tgt_mask, memory, memory_mask
cosyvoice/transformer/embedding.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(self,
38
+ d_model: int,
39
+ dropout_rate: float,
40
+ max_len: int = 5000,
41
+ reverse: bool = False):
42
+ """Construct an PositionalEncoding object."""
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.xscale = math.sqrt(self.d_model)
46
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
47
+ self.max_len = max_len
48
+
49
+ self.pe = torch.zeros(self.max_len, self.d_model)
50
+ position = torch.arange(0, self.max_len,
51
+ dtype=torch.float32).unsqueeze(1)
52
+ div_term = torch.exp(
53
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
+ -(math.log(10000.0) / self.d_model))
55
+ self.pe[:, 0::2] = torch.sin(position * div_term)
56
+ self.pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.pe = self.pe.unsqueeze(0)
58
+
59
+ def forward(self,
60
+ x: torch.Tensor,
61
+ offset: Union[int, torch.Tensor] = 0) \
62
+ -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """Add positional encoding.
64
+
65
+ Args:
66
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
+ offset (int, torch.tensor): position offset
68
+
69
+ Returns:
70
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
+ torch.Tensor: for compatibility to RelPositionalEncoding
72
+ """
73
+
74
+ self.pe = self.pe.to(x.device)
75
+ pos_emb = self.position_encoding(offset, x.size(1), False)
76
+ x = x * self.xscale + pos_emb
77
+ return self.dropout(x), self.dropout(pos_emb)
78
+
79
+ def position_encoding(self,
80
+ offset: Union[int, torch.Tensor],
81
+ size: int,
82
+ apply_dropout: bool = True) -> torch.Tensor:
83
+ """ For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a none
87
+ streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset:offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset:offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + \
109
+ torch.arange(0, size).to(offset.device) # B X T
110
+ flag = index > 0
111
+ # remove negative offset
112
+ index = index * flag
113
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
+
115
+ if apply_dropout:
116
+ pos_emb = self.dropout(pos_emb)
117
+ return pos_emb
118
+
119
+
120
+ class RelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module.
122
+ See : Appendix B in https://arxiv.org/abs/1901.02860
123
+ Args:
124
+ d_model (int): Embedding dimension.
125
+ dropout_rate (float): Dropout rate.
126
+ max_len (int): Maximum input length.
127
+ """
128
+
129
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
+ """Initialize class."""
131
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
+
133
+ def forward(self,
134
+ x: torch.Tensor,
135
+ offset: Union[int, torch.Tensor] = 0) \
136
+ -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """ Sinusoids position encoding used in openai-whisper.encoder
152
+ """
153
+
154
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
+ super().__init__(d_model, dropout_rate, max_len)
156
+ self.xscale = 1.0
157
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
+ inv_timescales = torch.exp(-log_timescale_increment *
159
+ torch.arange(d_model // 2))
160
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
+ inv_timescales[np.newaxis, :]
162
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
+ delattr(self, "pe")
164
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
+ class LearnablePositionalEncoding(PositionalEncoding):
168
+ """ Learnable position encoding used in openai-whisper.decoder
169
+ """
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """ No position encoding
180
+ """
181
+
182
+ def __init__(self, d_model: int, dropout_rate: float):
183
+ super().__init__()
184
+ self.d_model = d_model
185
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
186
+
187
+ def forward(self,
188
+ x: torch.Tensor,
189
+ offset: Union[int, torch.Tensor] = 0) \
190
+ -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """ Just return zero vector for interface compatibility
192
+ """
193
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
+ return self.dropout(x), pos_emb
195
+
196
+ def position_encoding(self, offset: Union[int, torch.Tensor],
197
+ size: int) -> torch.Tensor:
198
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
+ class EspnetRelPositionalEncoding(torch.nn.Module):
202
+ """Relative positional encoding module (new implementation).
203
+
204
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
205
+
206
+ See : Appendix B in https://arxiv.org/abs/1901.02860
207
+
208
+ Args:
209
+ d_model (int): Embedding dimension.
210
+ dropout_rate (float): Dropout rate.
211
+ max_len (int): Maximum input length.
212
+
213
+ """
214
+
215
+ def __init__(self, d_model, dropout_rate, max_len=5000):
216
+ """Construct an PositionalEncoding object."""
217
+ super(EspnetRelPositionalEncoding, self).__init__()
218
+ self.d_model = d_model
219
+ self.xscale = math.sqrt(self.d_model)
220
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
221
+ self.pe = None
222
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
+
224
+ def extend_pe(self, x):
225
+ """Reset the positional encodings."""
226
+ if self.pe is not None:
227
+ # self.pe contains both positive and negative parts
228
+ # the length of self.pe is 2 * input_len - 1
229
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
230
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
+ return
233
+ # Suppose `i` means to the position of query vecotr and `j` means the
234
+ # position of key vector. We use position relative positions when keys
235
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
236
+ pe_positive = torch.zeros(x.size(1), self.d_model)
237
+ pe_negative = torch.zeros(x.size(1), self.d_model)
238
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
+ div_term = torch.exp(
240
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
+ * -(math.log(10000.0) / self.d_model)
242
+ )
243
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
244
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
245
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
+
248
+ # Reserve the order of positive indices and concat both positive and
249
+ # negative indices. This is used to support the shifting trick
250
+ # as in https://arxiv.org/abs/1901.02860
251
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
+ pe_negative = pe_negative[1:].unsqueeze(0)
253
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
254
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
255
+
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
257
+ """Add positional encoding.
258
+
259
+ Args:
260
+ x (torch.Tensor): Input tensor (batch, time, `*`).
261
+
262
+ Returns:
263
+ torch.Tensor: Encoded tensor (batch, time, `*`).
264
+
265
+ """
266
+ self.extend_pe(x)
267
+ x = x * self.xscale
268
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
269
+ return self.dropout(x), self.dropout(pos_emb)
270
+
271
+ def position_encoding(self,
272
+ offset: Union[int, torch.Tensor],
273
+ size: int) -> torch.Tensor:
274
+ """ For getting encoding in a streaming fashion
275
+
276
+ Attention!!!!!
277
+ we apply dropout only once at the whole utterance level in a none
278
+ streaming way, but will call this function several times with
279
+ increasing input size in a streaming scenario, so the dropout will
280
+ be applied several times.
281
+
282
+ Args:
283
+ offset (int or torch.tensor): start offset
284
+ size (int): required size of position encoding
285
+
286
+ Returns:
287
+ torch.Tensor: Corresponding encoding
288
+ """
289
+ pos_emb = self.pe[
290
+ :,
291
+ self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
292
+ ]
293
+ return pos_emb
cosyvoice/transformer/encoder.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.utils.checkpoint as ckpt
22
+
23
+ from cosyvoice.transformer.convolution import ConvolutionModule
24
+ from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class BaseEncoder(torch.nn.Module):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int = 256,
43
+ attention_heads: int = 4,
44
+ linear_units: int = 2048,
45
+ num_blocks: int = 6,
46
+ dropout_rate: float = 0.1,
47
+ positional_dropout_rate: float = 0.1,
48
+ attention_dropout_rate: float = 0.0,
49
+ input_layer: str = "conv2d",
50
+ pos_enc_layer_type: str = "abs_pos",
51
+ normalize_before: bool = True,
52
+ static_chunk_size: int = 0,
53
+ use_dynamic_chunk: bool = False,
54
+ global_cmvn: torch.nn.Module = None,
55
+ use_dynamic_left_chunk: bool = False,
56
+ gradient_checkpointing: bool = False,
57
+ ):
58
+ """
59
+ Args:
60
+ input_size (int): input dim
61
+ output_size (int): dimension of attention
62
+ attention_heads (int): the number of heads of multi head attention
63
+ linear_units (int): the hidden units number of position-wise feed
64
+ forward
65
+ num_blocks (int): the number of decoder blocks
66
+ dropout_rate (float): dropout rate
67
+ attention_dropout_rate (float): dropout rate in attention
68
+ positional_dropout_rate (float): dropout rate after adding
69
+ positional encoding
70
+ input_layer (str): input layer type.
71
+ optional [linear, conv2d, conv2d6, conv2d8]
72
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
73
+ opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
74
+ normalize_before (bool):
75
+ True: use layer_norm before each sub-block of a layer.
76
+ False: use layer_norm after each sub-block of a layer.
77
+ static_chunk_size (int): chunk size for static chunk training and
78
+ decoding
79
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
80
+ training or not, You can only use fixed chunk(chunk_size > 0)
81
+ or dyanmic chunk size(use_dynamic_chunk = True)
82
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
83
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
84
+ dynamic chunk training
85
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
86
+ gradient_checkpointing: rerunning a forward-pass segment for each
87
+ checkpointed segment during backward.
88
+ """
89
+ super().__init__()
90
+ self._output_size = output_size
91
+
92
+ self.global_cmvn = global_cmvn
93
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
94
+ input_size,
95
+ output_size,
96
+ dropout_rate,
97
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
98
+ positional_dropout_rate),
99
+ )
100
+
101
+ self.normalize_before = normalize_before
102
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
103
+ self.static_chunk_size = static_chunk_size
104
+ self.use_dynamic_chunk = use_dynamic_chunk
105
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
106
+ self.gradient_checkpointing = gradient_checkpointing
107
+
108
+ def output_size(self) -> int:
109
+ return self._output_size
110
+
111
+ def forward(
112
+ self,
113
+ xs: torch.Tensor,
114
+ xs_lens: torch.Tensor,
115
+ decoding_chunk_size: int = 0,
116
+ num_decoding_left_chunks: int = -1,
117
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
118
+ """Embed positions in tensor.
119
+
120
+ Args:
121
+ xs: padded input tensor (B, T, D)
122
+ xs_lens: input length (B)
123
+ decoding_chunk_size: decoding chunk size for dynamic chunk
124
+ 0: default for training, use random dynamic chunk.
125
+ <0: for decoding, use full chunk.
126
+ >0: for decoding, use fixed chunk size as set.
127
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
128
+ the chunk size is decoding_chunk_size.
129
+ >=0: use num_decoding_left_chunks
130
+ <0: use all left chunks
131
+ Returns:
132
+ encoder output tensor xs, and subsampled masks
133
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
134
+ masks: torch.Tensor batch padding mask after subsample
135
+ (B, 1, T' ~= T/subsample_rate)
136
+ NOTE(xcsong):
137
+ We pass the `__call__` method of the modules instead of `forward` to the
138
+ checkpointing API because `__call__` attaches all the hooks of the module.
139
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
140
+ """
141
+ T = xs.size(1)
142
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
143
+ if self.global_cmvn is not None:
144
+ xs = self.global_cmvn(xs)
145
+ xs, pos_emb, masks = self.embed(xs, masks)
146
+ mask_pad = masks # (B, 1, T/subsample_rate)
147
+ chunk_masks = add_optional_chunk_mask(xs, masks,
148
+ self.use_dynamic_chunk,
149
+ self.use_dynamic_left_chunk,
150
+ decoding_chunk_size,
151
+ self.static_chunk_size,
152
+ num_decoding_left_chunks)
153
+ if self.gradient_checkpointing and self.training:
154
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
155
+ mask_pad)
156
+ else:
157
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
158
+ if self.normalize_before:
159
+ xs = self.after_norm(xs)
160
+ # Here we assume the mask is not changed in encoder layers, so just
161
+ # return the masks before encoder layers, and the masks will be used
162
+ # for cross attention with decoder later
163
+ return xs, masks
164
+
165
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
166
+ pos_emb: torch.Tensor,
167
+ mask_pad: torch.Tensor) -> torch.Tensor:
168
+ for layer in self.encoders:
169
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170
+ return xs
171
+
172
+ @torch.jit.ignore(drop=True)
173
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
174
+ chunk_masks: torch.Tensor,
175
+ pos_emb: torch.Tensor,
176
+ mask_pad: torch.Tensor) -> torch.Tensor:
177
+ for layer in self.encoders:
178
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
179
+ chunk_masks, pos_emb,
180
+ mask_pad)
181
+ return xs
182
+
183
+ def forward_chunk(
184
+ self,
185
+ xs: torch.Tensor,
186
+ offset: int,
187
+ required_cache_size: int,
188
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
189
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
190
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
191
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
192
+ """ Forward just one chunk
193
+
194
+ Args:
195
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
196
+ where `time == (chunk_size - 1) * subsample_rate + \
197
+ subsample.right_context + 1`
198
+ offset (int): current offset in encoder output time stamp
199
+ required_cache_size (int): cache size required for next chunk
200
+ compuation
201
+ >=0: actual cache size
202
+ <0: means all history cache is required
203
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
204
+ transformer/conformer attention, with shape
205
+ (elayers, head, cache_t1, d_k * 2), where
206
+ `head * d_k == hidden-dim` and
207
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
208
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
209
+ (elayers, b=1, hidden-dim, cache_t2), where
210
+ `cache_t2 == cnn.lorder - 1`
211
+
212
+ Returns:
213
+ torch.Tensor: output of current input xs,
214
+ with shape (b=1, chunk_size, hidden-dim).
215
+ torch.Tensor: new attention cache required for next chunk, with
216
+ dynamic shape (elayers, head, ?, d_k * 2)
217
+ depending on required_cache_size.
218
+ torch.Tensor: new conformer cnn cache required for next chunk, with
219
+ same shape as the original cnn_cache.
220
+
221
+ """
222
+ assert xs.size(0) == 1
223
+ # tmp_masks is just for interface compatibility
224
+ tmp_masks = torch.ones(1,
225
+ xs.size(1),
226
+ device=xs.device,
227
+ dtype=torch.bool)
228
+ tmp_masks = tmp_masks.unsqueeze(1)
229
+ if self.global_cmvn is not None:
230
+ xs = self.global_cmvn(xs)
231
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
232
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
233
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
234
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
235
+ chunk_size = xs.size(1)
236
+ attention_key_size = cache_t1 + chunk_size
237
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
238
+ size=attention_key_size)
239
+ if required_cache_size < 0:
240
+ next_cache_start = 0
241
+ elif required_cache_size == 0:
242
+ next_cache_start = attention_key_size
243
+ else:
244
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
245
+ r_att_cache = []
246
+ r_cnn_cache = []
247
+ for i, layer in enumerate(self.encoders):
248
+ # NOTE(xcsong): Before layer.forward
249
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
250
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
251
+ xs, _, new_att_cache, new_cnn_cache = layer(
252
+ xs,
253
+ att_mask,
254
+ pos_emb,
255
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
256
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
257
+ # NOTE(xcsong): After layer.forward
258
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
259
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
260
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
261
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
262
+ if self.normalize_before:
263
+ xs = self.after_norm(xs)
264
+
265
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
266
+ # ? may be larger than cache_t1, it depends on required_cache_size
267
+ r_att_cache = torch.cat(r_att_cache, dim=0)
268
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
269
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
270
+
271
+ return (xs, r_att_cache, r_cnn_cache)
272
+
273
+ def forward_chunk_by_chunk(
274
+ self,
275
+ xs: torch.Tensor,
276
+ decoding_chunk_size: int,
277
+ num_decoding_left_chunks: int = -1,
278
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
279
+ """ Forward input chunk by chunk with chunk_size like a streaming
280
+ fashion
281
+
282
+ Here we should pay special attention to computation cache in the
283
+ streaming style forward chunk by chunk. Three things should be taken
284
+ into account for computation in the current network:
285
+ 1. transformer/conformer encoder layers output cache
286
+ 2. convolution in conformer
287
+ 3. convolution in subsampling
288
+
289
+ However, we don't implement subsampling cache for:
290
+ 1. We can control subsampling module to output the right result by
291
+ overlapping input instead of cache left context, even though it
292
+ wastes some computation, but subsampling only takes a very
293
+ small fraction of computation in the whole model.
294
+ 2. Typically, there are several covolution layers with subsampling
295
+ in subsampling module, it is tricky and complicated to do cache
296
+ with different convolution layers with different subsampling
297
+ rate.
298
+ 3. Currently, nn.Sequential is used to stack all the convolution
299
+ layers in subsampling, we need to rewrite it to make it work
300
+ with cache, which is not prefered.
301
+ Args:
302
+ xs (torch.Tensor): (1, max_len, dim)
303
+ chunk_size (int): decoding chunk size
304
+ """
305
+ assert decoding_chunk_size > 0
306
+ # The model is trained by static or dynamic chunk
307
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
308
+ subsampling = self.embed.subsampling_rate
309
+ context = self.embed.right_context + 1 # Add current frame
310
+ stride = subsampling * decoding_chunk_size
311
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
312
+ num_frames = xs.size(1)
313
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
314
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
315
+ outputs = []
316
+ offset = 0
317
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
318
+
319
+ # Feed forward overlap input step by step
320
+ for cur in range(0, num_frames - context + 1, stride):
321
+ end = min(cur + decoding_window, num_frames)
322
+ chunk_xs = xs[:, cur:end, :]
323
+ (y, att_cache,
324
+ cnn_cache) = self.forward_chunk(chunk_xs, offset,
325
+ required_cache_size, att_cache,
326
+ cnn_cache)
327
+ outputs.append(y)
328
+ offset += y.size(1)
329
+ ys = torch.cat(outputs, 1)
330
+ masks = torch.ones((1, 1, ys.size(1)),
331
+ device=ys.device,
332
+ dtype=torch.bool)
333
+ return ys, masks
334
+
335
+
336
+ class TransformerEncoder(BaseEncoder):
337
+ """Transformer encoder module."""
338
+
339
+ def __init__(
340
+ self,
341
+ input_size: int,
342
+ output_size: int = 256,
343
+ attention_heads: int = 4,
344
+ linear_units: int = 2048,
345
+ num_blocks: int = 6,
346
+ dropout_rate: float = 0.1,
347
+ positional_dropout_rate: float = 0.1,
348
+ attention_dropout_rate: float = 0.0,
349
+ input_layer: str = "conv2d",
350
+ pos_enc_layer_type: str = "abs_pos",
351
+ normalize_before: bool = True,
352
+ static_chunk_size: int = 0,
353
+ use_dynamic_chunk: bool = False,
354
+ global_cmvn: torch.nn.Module = None,
355
+ use_dynamic_left_chunk: bool = False,
356
+ key_bias: bool = True,
357
+ selfattention_layer_type: str = "selfattn",
358
+ activation_type: str = "relu",
359
+ gradient_checkpointing: bool = False,
360
+ ):
361
+ """ Construct TransformerEncoder
362
+
363
+ See Encoder for the meaning of each parameter.
364
+ """
365
+ super().__init__(input_size, output_size, attention_heads,
366
+ linear_units, num_blocks, dropout_rate,
367
+ positional_dropout_rate, attention_dropout_rate,
368
+ input_layer, pos_enc_layer_type, normalize_before,
369
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
370
+ use_dynamic_left_chunk, gradient_checkpointing)
371
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
372
+ self.encoders = torch.nn.ModuleList([
373
+ TransformerEncoderLayer(
374
+ output_size,
375
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
376
+ output_size,
377
+ attention_dropout_rate,
378
+ key_bias),
379
+ PositionwiseFeedForward(output_size, linear_units,
380
+ dropout_rate, activation),
381
+ dropout_rate, normalize_before) for _ in range(num_blocks)
382
+ ])
383
+
384
+
385
+ class ConformerEncoder(BaseEncoder):
386
+ """Conformer encoder module."""
387
+
388
+ def __init__(
389
+ self,
390
+ input_size: int,
391
+ output_size: int = 256,
392
+ attention_heads: int = 4,
393
+ linear_units: int = 2048,
394
+ num_blocks: int = 6,
395
+ dropout_rate: float = 0.1,
396
+ positional_dropout_rate: float = 0.1,
397
+ attention_dropout_rate: float = 0.0,
398
+ input_layer: str = "conv2d",
399
+ pos_enc_layer_type: str = "rel_pos",
400
+ normalize_before: bool = True,
401
+ static_chunk_size: int = 0,
402
+ use_dynamic_chunk: bool = False,
403
+ global_cmvn: torch.nn.Module = None,
404
+ use_dynamic_left_chunk: bool = False,
405
+ positionwise_conv_kernel_size: int = 1,
406
+ macaron_style: bool = True,
407
+ selfattention_layer_type: str = "rel_selfattn",
408
+ activation_type: str = "swish",
409
+ use_cnn_module: bool = True,
410
+ cnn_module_kernel: int = 15,
411
+ causal: bool = False,
412
+ cnn_module_norm: str = "batch_norm",
413
+ key_bias: bool = True,
414
+ gradient_checkpointing: bool = False,
415
+ ):
416
+ """Construct ConformerEncoder
417
+
418
+ Args:
419
+ input_size to use_dynamic_chunk, see in BaseEncoder
420
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
421
+ conv1d layer.
422
+ macaron_style (bool): Whether to use macaron style for
423
+ positionwise layer.
424
+ selfattention_layer_type (str): Encoder attention layer type,
425
+ the parameter has no effect now, it's just for configure
426
+ compatibility.
427
+ activation_type (str): Encoder activation function type.
428
+ use_cnn_module (bool): Whether to use convolution module.
429
+ cnn_module_kernel (int): Kernel size of convolution module.
430
+ causal (bool): whether to use causal convolution or not.
431
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
432
+ """
433
+ super().__init__(input_size, output_size, attention_heads,
434
+ linear_units, num_blocks, dropout_rate,
435
+ positional_dropout_rate, attention_dropout_rate,
436
+ input_layer, pos_enc_layer_type, normalize_before,
437
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
438
+ use_dynamic_left_chunk, gradient_checkpointing)
439
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
440
+
441
+ # self-attention module definition
442
+ encoder_selfattn_layer_args = (
443
+ attention_heads,
444
+ output_size,
445
+ attention_dropout_rate,
446
+ key_bias,
447
+ )
448
+ # feed-forward module definition
449
+ positionwise_layer_args = (
450
+ output_size,
451
+ linear_units,
452
+ dropout_rate,
453
+ activation,
454
+ )
455
+ # convolution module definition
456
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
457
+ cnn_module_norm, causal)
458
+
459
+ self.encoders = torch.nn.ModuleList([
460
+ ConformerEncoderLayer(
461
+ output_size,
462
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
463
+ *encoder_selfattn_layer_args),
464
+ PositionwiseFeedForward(*positionwise_layer_args),
465
+ PositionwiseFeedForward(
466
+ *positionwise_layer_args) if macaron_style else None,
467
+ ConvolutionModule(
468
+ *convolution_layer_args) if use_cnn_module else None,
469
+ dropout_rate,
470
+ normalize_before,
471
+ ) for _ in range(num_blocks)
472
+ ])
cosyvoice/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+ `PositionwiseFeedForward`, instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+ mask_pad (torch.Tensor): does not used in transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+ torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
+ x = residual + self.dropout(x_att)
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ x = residual + self.dropout(self.feed_forward(x))
102
+ if not self.normalize_before:
103
+ x = self.norm2(x)
104
+
105
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
+ return x, mask, new_att_cache, fake_cnn_cache
107
+
108
+
109
+ class ConformerEncoderLayer(nn.Module):
110
+ """Encoder layer module.
111
+ Args:
112
+ size (int): Input dimension.
113
+ self_attn (torch.nn.Module): Self-attention module instance.
114
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
+ instance can be used as the argument.
116
+ feed_forward (torch.nn.Module): Feed-forward module instance.
117
+ `PositionwiseFeedForward` instance can be used as the argument.
118
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
+ instance.
120
+ `PositionwiseFeedForward` instance can be used as the argument.
121
+ conv_module (torch.nn.Module): Convolution module instance.
122
+ `ConvlutionModule` instance can be used as the argument.
123
+ dropout_rate (float): Dropout rate.
124
+ normalize_before (bool):
125
+ True: use layer_norm before each sub-block.
126
+ False: use layer_norm after each sub-block.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size: int,
132
+ self_attn: torch.nn.Module,
133
+ feed_forward: Optional[nn.Module] = None,
134
+ feed_forward_macaron: Optional[nn.Module] = None,
135
+ conv_module: Optional[nn.Module] = None,
136
+ dropout_rate: float = 0.1,
137
+ normalize_before: bool = True,
138
+ ):
139
+ """Construct an EncoderLayer object."""
140
+ super().__init__()
141
+ self.self_attn = self_attn
142
+ self.feed_forward = feed_forward
143
+ self.feed_forward_macaron = feed_forward_macaron
144
+ self.conv_module = conv_module
145
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
146
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
147
+ if feed_forward_macaron is not None:
148
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
149
+ self.ff_scale = 0.5
150
+ else:
151
+ self.ff_scale = 1.0
152
+ if self.conv_module is not None:
153
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
154
+ self.norm_final = nn.LayerNorm(
155
+ size, eps=1e-5) # for the final output of the block
156
+ self.dropout = nn.Dropout(dropout_rate)
157
+ self.size = size
158
+ self.normalize_before = normalize_before
159
+
160
+ def forward(
161
+ self,
162
+ x: torch.Tensor,
163
+ mask: torch.Tensor,
164
+ pos_emb: torch.Tensor,
165
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ """Compute encoded features.
170
+
171
+ Args:
172
+ x (torch.Tensor): (#batch, time, size)
173
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
174
+ (0, 0, 0) means fake mask.
175
+ pos_emb (torch.Tensor): positional encoding, must not be None
176
+ for ConformerEncoderLayer.
177
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
178
+ (#batch, 1,time), (0, 0, 0) means fake mask.
179
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
+ (#batch=1, size, cache_t2)
183
+ Returns:
184
+ torch.Tensor: Output tensor (#batch, time, size).
185
+ torch.Tensor: Mask tensor (#batch, time, time).
186
+ torch.Tensor: att_cache tensor,
187
+ (#batch=1, head, cache_t1 + time, d_k * 2).
188
+ torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
189
+ """
190
+
191
+ # whether to use macaron style
192
+ if self.feed_forward_macaron is not None:
193
+ residual = x
194
+ if self.normalize_before:
195
+ x = self.norm_ff_macaron(x)
196
+ x = residual + self.ff_scale * self.dropout(
197
+ self.feed_forward_macaron(x))
198
+ if not self.normalize_before:
199
+ x = self.norm_ff_macaron(x)
200
+
201
+ # multi-headed self-attention module
202
+ residual = x
203
+ if self.normalize_before:
204
+ x = self.norm_mha(x)
205
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
+ att_cache)
207
+ x = residual + self.dropout(x_att)
208
+ if not self.normalize_before:
209
+ x = self.norm_mha(x)
210
+
211
+ # convolution module
212
+ # Fake new cnn cache here, and then change it in conv_module
213
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
+ if self.conv_module is not None:
215
+ residual = x
216
+ if self.normalize_before:
217
+ x = self.norm_conv(x)
218
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
+ x = residual + self.dropout(x)
220
+
221
+ if not self.normalize_before:
222
+ x = self.norm_conv(x)
223
+
224
+ # feed forward module
225
+ residual = x
226
+ if self.normalize_before:
227
+ x = self.norm_ff(x)
228
+
229
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
+ if not self.normalize_before:
231
+ x = self.norm_ff(x)
232
+
233
+ if self.conv_module is not None:
234
+ x = self.norm_final(x)
235
+
236
+ return x, mask, new_att_cache, new_cnn_cache
cosyvoice/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Label smoothing module."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ class LabelSmoothingLoss(nn.Module):
22
+ """Label-smoothing loss.
23
+
24
+ In a standard CE loss, the label's data distribution is:
25
+ [0,1,2] ->
26
+ [
27
+ [1.0, 0.0, 0.0],
28
+ [0.0, 1.0, 0.0],
29
+ [0.0, 0.0, 1.0],
30
+ ]
31
+
32
+ In the smoothing version CE Loss,some probabilities
33
+ are taken from the true label prob (1.0) and are divided
34
+ among other labels.
35
+
36
+ e.g.
37
+ smoothing=0.1
38
+ [0,1,2] ->
39
+ [
40
+ [0.9, 0.05, 0.05],
41
+ [0.05, 0.9, 0.05],
42
+ [0.05, 0.05, 0.9],
43
+ ]
44
+
45
+ Args:
46
+ size (int): the number of class
47
+ padding_idx (int): padding class id which will be ignored for loss
48
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
49
+ normalize_length (bool):
50
+ normalize loss by sequence length if True
51
+ normalize loss by batch size if False
52
+ """
53
+
54
+ def __init__(self,
55
+ size: int,
56
+ padding_idx: int,
57
+ smoothing: float,
58
+ normalize_length: bool = False):
59
+ """Construct an LabelSmoothingLoss object."""
60
+ super(LabelSmoothingLoss, self).__init__()
61
+ self.criterion = nn.KLDivLoss(reduction="none")
62
+ self.padding_idx = padding_idx
63
+ self.confidence = 1.0 - smoothing
64
+ self.smoothing = smoothing
65
+ self.size = size
66
+ self.normalize_length = normalize_length
67
+
68
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
69
+ """Compute loss between x and target.
70
+
71
+ The model outputs and data labels tensors are flatten to
72
+ (batch*seqlen, class) shape and a mask is applied to the
73
+ padding part which should not be calculated for loss.
74
+
75
+ Args:
76
+ x (torch.Tensor): prediction (batch, seqlen, class)
77
+ target (torch.Tensor):
78
+ target signal masked with self.padding_id (batch, seqlen)
79
+ Returns:
80
+ loss (torch.Tensor) : The KL loss, scalar float value
81
+ """
82
+ assert x.size(2) == self.size
83
+ batch_size = x.size(0)
84
+ x = x.view(-1, self.size)
85
+ target = target.view(-1)
86
+ # use zeros_like instead of torch.no_grad() for true_dist,
87
+ # since no_grad() can not be exported by JIT
88
+ true_dist = torch.zeros_like(x)
89
+ true_dist.fill_(self.smoothing / (self.size - 1))
90
+ ignore = target == self.padding_idx # (B,)
91
+ total = len(target) - ignore.sum().item()
92
+ target = target.masked_fill(ignore, 0) # avoid -1 index
93
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
94
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
95
+ denom = total if self.normalize_length else batch_size
96
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
cosyvoice/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward are appied on each position of the sequence.
24
+ The output dim is same with the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimenstion.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+ Mixture of expert with Positionwise feed forward layer
61
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+ The output dim is same with the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+ n_expert: number of expert.
68
+ n_expert_per_token: The actual number of experts used for each frame
69
+ idim (int): Input dimenstion.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
+ activation) for _ in range(n_expert))
89
+ self.n_expert_per_token = n_expert_per_token
90
+
91
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
+ """Foward function.
93
+ Args:
94
+ xs: input tensor (B, L, D)
95
+ Returns:
96
+ output tensor, (B, L, D)
97
+
98
+ """
99
+ B, L, D = xs.size(
100
+ ) # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # probs:(B*L, n_expert), indices: (B*L, n_expert)
106
+ weights = torch.nn.functional.softmax(
107
+ logits, dim=1,
108
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx])
115
+ return output.view(B, L, D)
cosyvoice/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class EmbedinigNoSubsampling(BaseSubsampling):
36
+ """Embedding input without subsampling
37
+ """
38
+
39
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
40
+ pos_enc_class: torch.nn.Module):
41
+ super().__init__()
42
+ self.embed = torch.nn.Embedding(idim, odim)
43
+ self.pos_enc = pos_enc_class
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ x_mask: torch.Tensor,
49
+ offset: Union[int, torch.Tensor] = 0
50
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
+ """Input x.
52
+
53
+ Args:
54
+ x (torch.Tensor): Input tensor (#batch, time, idim).
55
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
+
57
+ Returns:
58
+ torch.Tensor: linear input tensor (#batch, time', odim),
59
+ where time' = time .
60
+ torch.Tensor: linear input mask (#batch, 1, time'),
61
+ where time' = time .
62
+
63
+ """
64
+ x = self.embed(x)
65
+ x, pos_emb = self.pos_enc(x, offset)
66
+ return x, pos_emb, x_mask
67
+
68
+
69
+ class LinearNoSubsampling(BaseSubsampling):
70
+ """Linear transform the input without subsampling
71
+
72
+ Args:
73
+ idim (int): Input dimension.
74
+ odim (int): Output dimension.
75
+ dropout_rate (float): Dropout rate.
76
+
77
+ """
78
+
79
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
80
+ pos_enc_class: torch.nn.Module):
81
+ """Construct an linear object."""
82
+ super().__init__()
83
+ self.out = torch.nn.Sequential(
84
+ torch.nn.Linear(idim, odim),
85
+ torch.nn.LayerNorm(odim, eps=1e-5),
86
+ torch.nn.Dropout(dropout_rate),
87
+ )
88
+ self.pos_enc = pos_enc_class
89
+ self.right_context = 0
90
+ self.subsampling_rate = 1
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ x_mask: torch.Tensor,
96
+ offset: Union[int, torch.Tensor] = 0
97
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ """Input x.
99
+
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, idim).
102
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
+
104
+ Returns:
105
+ torch.Tensor: linear input tensor (#batch, time', odim),
106
+ where time' = time .
107
+ torch.Tensor: linear input mask (#batch, 1, time'),
108
+ where time' = time .
109
+
110
+ """
111
+ x = self.out(x)
112
+ x, pos_emb = self.pos_enc(x, offset)
113
+ return x, pos_emb, x_mask
114
+
115
+
116
+ class Conv1dSubsampling2(BaseSubsampling):
117
+ """Convolutional 1D subsampling (to 1/2 length).
118
+ It is designed for Whisper, ref:
119
+ https://github.com/openai/whisper/blob/main/whisper/model.py
120
+
121
+ Args:
122
+ idim (int): Input dimension.
123
+ odim (int): Output dimension.
124
+ dropout_rate (float): Dropout rate.
125
+
126
+ """
127
+
128
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
129
+ pos_enc_class: torch.nn.Module):
130
+ """Construct an Conv1dSubsampling2 object."""
131
+ super().__init__()
132
+ self.conv = torch.nn.Sequential(
133
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
+ torch.nn.GELU(),
135
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
+ torch.nn.GELU(),
137
+ )
138
+ self.pos_enc = pos_enc_class
139
+ # The right context for every conv layer is computed by:
140
+ # (kernel_size - 1) * frame_rate_of_this_layer
141
+ self.subsampling_rate = 2
142
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
+ self.right_context = 4
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ x_mask: torch.Tensor,
149
+ offset: Union[int, torch.Tensor] = 0
150
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
+ """Subsample x.
152
+
153
+ Args:
154
+ x (torch.Tensor): Input tensor (#batch, time, idim).
155
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
+
157
+ Returns:
158
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
159
+ where time' = time // 2.
160
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
161
+ where time' = time // 2.
162
+ torch.Tensor: positional encoding
163
+
164
+ """
165
+ time = x.size(1)
166
+ x = x.transpose(1, 2) # (b, f, t)
167
+ x = self.conv(x)
168
+ x = x.transpose(1, 2) # (b, t, f)
169
+ x, pos_emb = self.pos_enc(x, offset)
170
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
+ class Conv2dSubsampling4(BaseSubsampling):
174
+ """Convolutional 2D subsampling (to 1/4 length).
175
+
176
+ Args:
177
+ idim (int): Input dimension.
178
+ odim (int): Output dimension.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
184
+ pos_enc_class: torch.nn.Module):
185
+ """Construct an Conv2dSubsampling4 object."""
186
+ super().__init__()
187
+ self.conv = torch.nn.Sequential(
188
+ torch.nn.Conv2d(1, odim, 3, 2),
189
+ torch.nn.ReLU(),
190
+ torch.nn.Conv2d(odim, odim, 3, 2),
191
+ torch.nn.ReLU(),
192
+ )
193
+ self.out = torch.nn.Sequential(
194
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
+ self.pos_enc = pos_enc_class
196
+ # The right context for every conv layer is computed by:
197
+ # (kernel_size - 1) * frame_rate_of_this_layer
198
+ self.subsampling_rate = 4
199
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
+ self.right_context = 6
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ x_mask: torch.Tensor,
206
+ offset: Union[int, torch.Tensor] = 0
207
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
+ """Subsample x.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor (#batch, time, idim).
212
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
+
214
+ Returns:
215
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
216
+ where time' = time // 4.
217
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
218
+ where time' = time // 4.
219
+ torch.Tensor: positional encoding
220
+
221
+ """
222
+ x = x.unsqueeze(1) # (b, c=1, t, f)
223
+ x = self.conv(x)
224
+ b, c, t, f = x.size()
225
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
+ x, pos_emb = self.pos_enc(x, offset)
227
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
+ class Conv2dSubsampling6(BaseSubsampling):
231
+ """Convolutional 2D subsampling (to 1/6 length).
232
+ Args:
233
+ idim (int): Input dimension.
234
+ odim (int): Output dimension.
235
+ dropout_rate (float): Dropout rate.
236
+ pos_enc (torch.nn.Module): Custom position encoding layer.
237
+ """
238
+
239
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
240
+ pos_enc_class: torch.nn.Module):
241
+ """Construct an Conv2dSubsampling6 object."""
242
+ super().__init__()
243
+ self.conv = torch.nn.Sequential(
244
+ torch.nn.Conv2d(1, odim, 3, 2),
245
+ torch.nn.ReLU(),
246
+ torch.nn.Conv2d(odim, odim, 5, 3),
247
+ torch.nn.ReLU(),
248
+ )
249
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
+ odim)
251
+ self.pos_enc = pos_enc_class
252
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
+ self.subsampling_rate = 6
254
+ self.right_context = 10
255
+
256
+ def forward(
257
+ self,
258
+ x: torch.Tensor,
259
+ x_mask: torch.Tensor,
260
+ offset: Union[int, torch.Tensor] = 0
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Subsample x.
263
+ Args:
264
+ x (torch.Tensor): Input tensor (#batch, time, idim).
265
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
+
267
+ Returns:
268
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
269
+ where time' = time // 6.
270
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
271
+ where time' = time // 6.
272
+ torch.Tensor: positional encoding
273
+ """
274
+ x = x.unsqueeze(1) # (b, c, t, f)
275
+ x = self.conv(x)
276
+ b, c, t, f = x.size()
277
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
+ x, pos_emb = self.pos_enc(x, offset)
279
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
+ class Conv2dSubsampling8(BaseSubsampling):
283
+ """Convolutional 2D subsampling (to 1/8 length).
284
+
285
+ Args:
286
+ idim (int): Input dimension.
287
+ odim (int): Output dimension.
288
+ dropout_rate (float): Dropout rate.
289
+
290
+ """
291
+
292
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
293
+ pos_enc_class: torch.nn.Module):
294
+ """Construct an Conv2dSubsampling8 object."""
295
+ super().__init__()
296
+ self.conv = torch.nn.Sequential(
297
+ torch.nn.Conv2d(1, odim, 3, 2),
298
+ torch.nn.ReLU(),
299
+ torch.nn.Conv2d(odim, odim, 3, 2),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv2d(odim, odim, 3, 2),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.linear = torch.nn.Linear(
305
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
+ self.pos_enc = pos_enc_class
307
+ self.subsampling_rate = 8
308
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
+ self.right_context = 14
310
+
311
+ def forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ x_mask: torch.Tensor,
315
+ offset: Union[int, torch.Tensor] = 0
316
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
+ """Subsample x.
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor (#batch, time, idim).
321
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
+
323
+ Returns:
324
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
325
+ where time' = time // 8.
326
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
327
+ where time' = time // 8.
328
+ torch.Tensor: positional encoding
329
+ """
330
+ x = x.unsqueeze(1) # (b, c, t, f)
331
+ x = self.conv(x)
332
+ b, c, t, f = x.size()
333
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
+ x, pos_emb = self.pos_enc(x, offset)
335
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
+ class LegacyLinearNoSubsampling(BaseSubsampling):
339
+ """Linear transform the input without subsampling
340
+
341
+ Args:
342
+ idim (int): Input dimension.
343
+ odim (int): Output dimension.
344
+ dropout_rate (float): Dropout rate.
345
+
346
+ """
347
+
348
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
349
+ pos_enc_class: torch.nn.Module):
350
+ """Construct an linear object."""
351
+ super().__init__()
352
+ self.out = torch.nn.Sequential(
353
+ torch.nn.Linear(idim, odim),
354
+ torch.nn.LayerNorm(odim, eps=1e-5),
355
+ torch.nn.Dropout(dropout_rate),
356
+ torch.nn.ReLU(),
357
+ )
358
+ self.pos_enc = pos_enc_class
359
+ self.right_context = 0
360
+ self.subsampling_rate = 1
361
+
362
+ def forward(
363
+ self,
364
+ x: torch.Tensor,
365
+ x_mask: torch.Tensor,
366
+ offset: Union[int, torch.Tensor] = 0
367
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
+ """Input x.
369
+
370
+ Args:
371
+ x (torch.Tensor): Input tensor (#batch, time, idim).
372
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
+
374
+ Returns:
375
+ torch.Tensor: linear input tensor (#batch, time', odim),
376
+ where time' = time .
377
+ torch.Tensor: linear input mask (#batch, 1, time'),
378
+ where time' = time .
379
+
380
+ """
381
+ x = self.out(x)
382
+ x, pos_emb = self.pos_enc(x, offset)
383
+ return x, pos_emb, x_mask
cosyvoice/utils/__init__.py ADDED
File without changes
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
27
+ RelPositionalEncoding,
28
+ WhisperPositionalEncoding,
29
+ LearnablePositionalEncoding,
30
+ NoPositionalEncoding)
31
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
32
+ RelPositionMultiHeadedAttention)
33
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35
+
36
+
37
+ COSYVOICE_ACTIVATION_CLASSES = {
38
+ "hardtanh": torch.nn.Hardtanh,
39
+ "tanh": torch.nn.Tanh,
40
+ "relu": torch.nn.ReLU,
41
+ "selu": torch.nn.SELU,
42
+ "swish": getattr(torch.nn, "SiLU", Swish),
43
+ "gelu": torch.nn.GELU,
44
+ }
45
+
46
+ COSYVOICE_SUBSAMPLE_CLASSES = {
47
+ "linear": LinearNoSubsampling,
48
+ "linear_legacy": LegacyLinearNoSubsampling,
49
+ "embed": EmbedinigNoSubsampling,
50
+ "conv1d2": Conv1dSubsampling2,
51
+ "conv2d": Conv2dSubsampling4,
52
+ "conv2d6": Conv2dSubsampling6,
53
+ "conv2d8": Conv2dSubsampling8,
54
+ 'paraformer_dummy': torch.nn.Identity
55
+ }
56
+
57
+ COSYVOICE_EMB_CLASSES = {
58
+ "embed": PositionalEncoding,
59
+ "abs_pos": PositionalEncoding,
60
+ "rel_pos": RelPositionalEncoding,
61
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
62
+ "no_pos": NoPositionalEncoding,
63
+ "abs_pos_whisper": WhisperPositionalEncoding,
64
+ "embed_learnable_pe": LearnablePositionalEncoding,
65
+ }
66
+
67
+ COSYVOICE_ATTENTION_CLASSES = {
68
+ "selfattn": MultiHeadedAttention,
69
+ "rel_selfattn": RelPositionMultiHeadedAttention,
70
+ }
cosyvoice/utils/common.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Unility functions for Transformer."""
17
+
18
+ from typing import List
19
+
20
+ import torch
21
+
22
+ IGNORE_ID = -1
23
+
24
+
25
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
26
+ """Perform padding for the list of tensors.
27
+
28
+ Args:
29
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
30
+ pad_value (float): Value for padding.
31
+
32
+ Returns:
33
+ Tensor: Padded tensor (B, Tmax, `*`).
34
+
35
+ Examples:
36
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
37
+ >>> x
38
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
39
+ >>> pad_list(x, 0)
40
+ tensor([[1., 1., 1., 1.],
41
+ [1., 1., 0., 0.],
42
+ [1., 0., 0., 0.]])
43
+
44
+ """
45
+ max_len = max([len(item) for item in xs])
46
+ batchs = len(xs)
47
+ ndim = xs[0].ndim
48
+ if ndim == 1:
49
+ pad_res = torch.zeros(batchs,
50
+ max_len,
51
+ dtype=xs[0].dtype,
52
+ device=xs[0].device)
53
+ elif ndim == 2:
54
+ pad_res = torch.zeros(batchs,
55
+ max_len,
56
+ xs[0].shape[1],
57
+ dtype=xs[0].dtype,
58
+ device=xs[0].device)
59
+ elif ndim == 3:
60
+ pad_res = torch.zeros(batchs,
61
+ max_len,
62
+ xs[0].shape[1],
63
+ xs[0].shape[2],
64
+ dtype=xs[0].dtype,
65
+ device=xs[0].device)
66
+ else:
67
+ raise ValueError(f"Unsupported ndim: {ndim}")
68
+ pad_res.fill_(pad_value)
69
+ for i in range(batchs):
70
+ pad_res[i, :len(xs[i])] = xs[i]
71
+ return pad_res
72
+
73
+
74
+ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
75
+ ignore_label: int) -> torch.Tensor:
76
+ """Calculate accuracy.
77
+
78
+ Args:
79
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
80
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
81
+ ignore_label (int): Ignore label id.
82
+
83
+ Returns:
84
+ torch.Tensor: Accuracy value (0.0 - 1.0).
85
+
86
+ """
87
+ pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
88
+ pad_outputs.size(1)).argmax(2)
89
+ mask = pad_targets != ignore_label
90
+ numerator = torch.sum(
91
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
92
+ denominator = torch.sum(mask)
93
+ return (numerator / denominator).detach()
94
+
95
+
96
+ def get_padding(kernel_size, dilation=1):
97
+ return int((kernel_size * dilation - dilation) / 2)
98
+
99
+
100
+ def init_weights(m, mean=0.0, std=0.01):
101
+ classname = m.__class__.__name__
102
+ if classname.find("Conv") != -1:
103
+ m.weight.data.normal_(mean, std)
cosyvoice/utils/executor.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from contextlib import nullcontext
18
+ import os
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
24
+
25
+
26
+ class Executor:
27
+
28
+ def __init__(self):
29
+ self.step = 0
30
+ self.epoch = 0
31
+ self.rank = int(os.environ.get('RANK', 0))
32
+ self.device = torch.device('cuda:{}'.format(self.rank))
33
+
34
+ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
35
+ ''' Train one epoch
36
+ '''
37
+
38
+ lr = optimizer.param_groups[0]['lr']
39
+ logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
40
+ logging.info('using accumulate grad, new batch size is {} times'
41
+ ' larger than before'.format(info_dict['accum_grad']))
42
+ # A context manager to be used in conjunction with an instance of
43
+ # torch.nn.parallel.DistributedDataParallel to be able to train
44
+ # with uneven inputs across participating processes.
45
+ model.train()
46
+ model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
47
+ with model_context():
48
+ for batch_idx, batch_dict in enumerate(train_data_loader):
49
+ info_dict["tag"] = "TRAIN"
50
+ info_dict["step"] = self.step
51
+ info_dict["epoch"] = self.epoch
52
+ info_dict["batch_idx"] = batch_idx
53
+ if cosyvoice_join(group_join, info_dict):
54
+ break
55
+
56
+ # Disable gradient synchronizations across DDP processes.
57
+ # Within this context, gradients will be accumulated on module
58
+ # variables, which will later be synchronized.
59
+ if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
60
+ context = model.no_sync
61
+ # Used for single gpu training and DDP gradient synchronization
62
+ # processes.
63
+ else:
64
+ context = nullcontext
65
+
66
+ with context():
67
+ info_dict = batch_forward(model, batch_dict, info_dict)
68
+ info_dict = batch_backward(model, info_dict)
69
+
70
+ info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
71
+ log_per_step(writer, info_dict)
72
+ # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
73
+ if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
74
+ dist.barrier()
75
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
76
+ model.train()
77
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
78
+ self.step += 1
79
+ dist.barrier()
80
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
81
+
82
+ @torch.inference_mode()
83
+ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
84
+ ''' Cross validation on
85
+ '''
86
+ logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
87
+ model.eval()
88
+ total_num_utts, total_loss_dict = 0, {} # avoid division by 0
89
+ for batch_idx, batch_dict in enumerate(cv_data_loader):
90
+ info_dict["tag"] = "CV"
91
+ info_dict["step"] = self.step
92
+ info_dict["epoch"] = self.epoch
93
+ info_dict["batch_idx"] = batch_idx
94
+
95
+ num_utts = len(batch_dict["utts"])
96
+ total_num_utts += num_utts
97
+
98
+ info_dict = batch_forward(model, batch_dict, info_dict)
99
+
100
+ for k, v in info_dict['loss_dict'].items():
101
+ if k not in total_loss_dict:
102
+ total_loss_dict[k] = []
103
+ total_loss_dict[k].append(v.item() * num_utts)
104
+ log_per_step(None, info_dict)
105
+ for k, v in total_loss_dict.items():
106
+ total_loss_dict[k] = sum(v) / total_num_utts
107
+ info_dict['loss_dict'] = total_loss_dict
108
+ log_per_save(writer, info_dict)
109
+ model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
110
+ save_model(model, model_name, info_dict)
cosyvoice/utils/file_utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import torchaudio
18
+
19
+
20
+ def read_lists(list_file):
21
+ lists = []
22
+ with open(list_file, 'r', encoding='utf8') as fin:
23
+ for line in fin:
24
+ lists.append(line.strip())
25
+ return lists
26
+
27
+ def read_json_lists(list_file):
28
+ lists = read_lists(list_file)
29
+ results = {}
30
+ for fn in lists:
31
+ with open(fn, 'r', encoding='utf8') as fin:
32
+ results.update(json.load(fin))
33
+ return results
34
+
35
+ def load_wav(wav, target_sr):
36
+ speech, sample_rate = torchaudio.load(wav)
37
+ speech = speech.mean(dim=0, keepdim=True)
38
+ if sample_rate != target_sr:
39
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
40
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
41
+ return speech
cosyvoice/utils/frontend_utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
17
+
18
+ # whether contain chinese character
19
+ def contains_chinese(text):
20
+ return bool(chinese_char_pattern.search(text))
21
+
22
+
23
+ # replace special symbol
24
+ def replace_corner_mark(text):
25
+ text = text.replace('²', '平方')
26
+ text = text.replace('³', '立方')
27
+ return text
28
+
29
+
30
+ # remove meaningless symbol
31
+ def remove_bracket(text):
32
+ text = text.replace('(', '').replace(')', '')
33
+ text = text.replace('【', '').replace('】', '')
34
+ text = text.replace('`', '').replace('`', '')
35
+ text = text.replace("——", " ")
36
+ return text
37
+
38
+
39
+ # spell Arabic numerals
40
+ def spell_out_number(text: str, inflect_parser):
41
+ new_text = []
42
+ st = None
43
+ for i, c in enumerate(text):
44
+ if not c.isdigit():
45
+ if st is not None:
46
+ num_str = inflect_parser.number_to_words(text[st: i])
47
+ new_text.append(num_str)
48
+ st = None
49
+ new_text.append(c)
50
+ else:
51
+ if st is None:
52
+ st = i
53
+ if st is not None and st < len(text):
54
+ num_str = inflect_parser.number_to_words(text[st:])
55
+ new_text.append(num_str)
56
+ return ''.join(new_text)
57
+
58
+
59
+ # split paragrah logic:
60
+ # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
61
+ # 2. cal sentence len according to lang
62
+ # 3. split sentence according to puncatation
63
+ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
64
+ def calc_utt_length(_text: str):
65
+ if lang == "zh":
66
+ return len(_text)
67
+ else:
68
+ return len(tokenize(_text))
69
+
70
+ def should_merge(_text: str):
71
+ if lang == "zh":
72
+ return len(_text) < merge_len
73
+ else:
74
+ return len(tokenize(_text)) < merge_len
75
+
76
+ if lang == "zh":
77
+ pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
78
+ else:
79
+ pounc = ['.', '?', '!', ';', ':']
80
+ if comma_split:
81
+ pounc.extend([',', ','])
82
+ st = 0
83
+ utts = []
84
+ for i, c in enumerate(text):
85
+ if c in pounc:
86
+ if len(text[st: i]) > 0:
87
+ utts.append(text[st: i] + c)
88
+ if i + 1 < len(text) and text[i + 1] in ['"', '”']:
89
+ tmp = utts.pop(-1)
90
+ utts.append(tmp + text[i + 1])
91
+ st = i + 2
92
+ else:
93
+ st = i + 1
94
+ if len(utts) == 0:
95
+ if lang == "zh":
96
+ utts.append(text + '。')
97
+ else:
98
+ utts.append(text + '.')
99
+ final_utts = []
100
+ cur_utt = ""
101
+ for utt in utts:
102
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
103
+ final_utts.append(cur_utt)
104
+ cur_utt = ""
105
+ cur_utt = cur_utt + utt
106
+ if len(cur_utt) > 0:
107
+ if should_merge(cur_utt) and len(final_utts) != 0:
108
+ final_utts[-1] = final_utts[-1] + cur_utt
109
+ else:
110
+ final_utts.append(cur_utt)
111
+
112
+ return final_utts
113
+
114
+
115
+ # remove blank between chinese character
116
+ def replace_blank(text: str):
117
+ out_str = []
118
+ for i, c in enumerate(text):
119
+ if c == " ":
120
+ if ((text[i + 1].isascii() and text[i + 1] != " ") and
121
+ (text[i - 1].isascii() and text[i - 1] != " ")):
122
+ out_str.append(c)
123
+ else:
124
+ out_str.append(c)
125
+ return "".join(out_str)
cosyvoice/utils/mask.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ '''
19
+ def subsequent_mask(
20
+ size: int,
21
+ device: torch.device = torch.device("cpu"),
22
+ ) -> torch.Tensor:
23
+ """Create mask for subsequent steps (size, size).
24
+
25
+ This mask is used only in decoder which works in an auto-regressive mode.
26
+ This means the current step could only do attention with its left steps.
27
+
28
+ In encoder, fully attention is used when streaming is not necessary and
29
+ the sequence is not long. In this case, no attention mask is needed.
30
+
31
+ When streaming is need, chunk-based attention is used in encoder. See
32
+ subsequent_chunk_mask for the chunk-based attention mask.
33
+
34
+ Args:
35
+ size (int): size of mask
36
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
37
+ dtype (torch.device): result dtype
38
+
39
+ Returns:
40
+ torch.Tensor: mask
41
+
42
+ Examples:
43
+ >>> subsequent_mask(3)
44
+ [[1, 0, 0],
45
+ [1, 1, 0],
46
+ [1, 1, 1]]
47
+ """
48
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
49
+ return torch.tril(ret)
50
+ '''
51
+
52
+
53
+ def subsequent_mask(
54
+ size: int,
55
+ device: torch.device = torch.device("cpu"),
56
+ ) -> torch.Tensor:
57
+ """Create mask for subsequent steps (size, size).
58
+
59
+ This mask is used only in decoder which works in an auto-regressive mode.
60
+ This means the current step could only do attention with its left steps.
61
+
62
+ In encoder, fully attention is used when streaming is not necessary and
63
+ the sequence is not long. In this case, no attention mask is needed.
64
+
65
+ When streaming is need, chunk-based attention is used in encoder. See
66
+ subsequent_chunk_mask for the chunk-based attention mask.
67
+
68
+ Args:
69
+ size (int): size of mask
70
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
71
+ dtype (torch.device): result dtype
72
+
73
+ Returns:
74
+ torch.Tensor: mask
75
+
76
+ Examples:
77
+ >>> subsequent_mask(3)
78
+ [[1, 0, 0],
79
+ [1, 1, 0],
80
+ [1, 1, 1]]
81
+ """
82
+ arange = torch.arange(size, device=device)
83
+ mask = arange.expand(size, size)
84
+ arange = arange.unsqueeze(-1)
85
+ mask = mask <= arange
86
+ return mask
87
+
88
+
89
+ def subsequent_chunk_mask(
90
+ size: int,
91
+ chunk_size: int,
92
+ num_left_chunks: int = -1,
93
+ device: torch.device = torch.device("cpu"),
94
+ ) -> torch.Tensor:
95
+ """Create mask for subsequent steps (size, size) with chunk size,
96
+ this is for streaming encoder
97
+
98
+ Args:
99
+ size (int): size of mask
100
+ chunk_size (int): size of chunk
101
+ num_left_chunks (int): number of left chunks
102
+ <0: use full chunk
103
+ >=0: use num_left_chunks
104
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
105
+
106
+ Returns:
107
+ torch.Tensor: mask
108
+
109
+ Examples:
110
+ >>> subsequent_chunk_mask(4, 2)
111
+ [[1, 1, 0, 0],
112
+ [1, 1, 0, 0],
113
+ [1, 1, 1, 1],
114
+ [1, 1, 1, 1]]
115
+ """
116
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
117
+ for i in range(size):
118
+ if num_left_chunks < 0:
119
+ start = 0
120
+ else:
121
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
122
+ ending = min((i // chunk_size + 1) * chunk_size, size)
123
+ ret[i, start:ending] = True
124
+ return ret
125
+
126
+
127
+ def add_optional_chunk_mask(xs: torch.Tensor,
128
+ masks: torch.Tensor,
129
+ use_dynamic_chunk: bool,
130
+ use_dynamic_left_chunk: bool,
131
+ decoding_chunk_size: int,
132
+ static_chunk_size: int,
133
+ num_decoding_left_chunks: int,
134
+ enable_full_context: bool = True):
135
+ """ Apply optional mask for encoder.
136
+
137
+ Args:
138
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
139
+ mask (torch.Tensor): mask for xs, (B, 1, L)
140
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
141
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
142
+ training.
143
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
144
+ 0: default for training, use random dynamic chunk.
145
+ <0: for decoding, use full chunk.
146
+ >0: for decoding, use fixed chunk size as set.
147
+ static_chunk_size (int): chunk size for static chunk training/decoding
148
+ if it's greater than 0, if use_dynamic_chunk is true,
149
+ this parameter will be ignored
150
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
151
+ the chunk size is decoding_chunk_size.
152
+ >=0: use num_decoding_left_chunks
153
+ <0: use all left chunks
154
+ enable_full_context (bool):
155
+ True: chunk size is either [1, 25] or full context(max_len)
156
+ False: chunk size ~ U[1, 25]
157
+
158
+ Returns:
159
+ torch.Tensor: chunk mask of the input xs.
160
+ """
161
+ # Whether to use chunk mask or not
162
+ if use_dynamic_chunk:
163
+ max_len = xs.size(1)
164
+ if decoding_chunk_size < 0:
165
+ chunk_size = max_len
166
+ num_left_chunks = -1
167
+ elif decoding_chunk_size > 0:
168
+ chunk_size = decoding_chunk_size
169
+ num_left_chunks = num_decoding_left_chunks
170
+ else:
171
+ # chunk size is either [1, 25] or full context(max_len).
172
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
173
+ # delay, the maximum frame is 100 / 4 = 25.
174
+ chunk_size = torch.randint(1, max_len, (1, )).item()
175
+ num_left_chunks = -1
176
+ if chunk_size > max_len // 2 and enable_full_context:
177
+ chunk_size = max_len
178
+ else:
179
+ chunk_size = chunk_size % 25 + 1
180
+ if use_dynamic_left_chunk:
181
+ max_left_chunks = (max_len - 1) // chunk_size
182
+ num_left_chunks = torch.randint(0, max_left_chunks,
183
+ (1, )).item()
184
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
185
+ num_left_chunks,
186
+ xs.device) # (L, L)
187
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
188
+ chunk_masks = masks & chunk_masks # (B, L, L)
189
+ elif static_chunk_size > 0:
190
+ num_left_chunks = num_decoding_left_chunks
191
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
192
+ num_left_chunks,
193
+ xs.device) # (L, L)
194
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
195
+ chunk_masks = masks & chunk_masks # (B, L, L)
196
+ else:
197
+ chunk_masks = masks
198
+ return chunk_masks
199
+
200
+
201
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
202
+ """Make mask tensor containing indices of padded part.
203
+
204
+ See description of make_non_pad_mask.
205
+
206
+ Args:
207
+ lengths (torch.Tensor): Batch of lengths (B,).
208
+ Returns:
209
+ torch.Tensor: Mask tensor containing indices of padded part.
210
+
211
+ Examples:
212
+ >>> lengths = [5, 3, 2]
213
+ >>> make_pad_mask(lengths)
214
+ masks = [[0, 0, 0, 0 ,0],
215
+ [0, 0, 0, 1, 1],
216
+ [0, 0, 1, 1, 1]]
217
+ """
218
+ batch_size = lengths.size(0)
219
+ max_len = max_len if max_len > 0 else lengths.max().item()
220
+ seq_range = torch.arange(0,
221
+ max_len,
222
+ dtype=torch.int64,
223
+ device=lengths.device)
224
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
225
+ seq_length_expand = lengths.unsqueeze(-1)
226
+ mask = seq_range_expand >= seq_length_expand
227
+ return mask
cosyvoice/utils/scheduler.py ADDED
@@ -0,0 +1,739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2022 Ximalaya Inc (Yuguang Yang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ # NeMo(https://github.com/NVIDIA/NeMo)
18
+
19
+ from typing import Union
20
+
21
+ import math
22
+ import warnings
23
+ import torch
24
+ from torch.optim.lr_scheduler import _LRScheduler
25
+
26
+
27
+ class WarmupLR(_LRScheduler):
28
+ """The WarmupLR scheduler
29
+
30
+ This scheduler is almost same as NoamLR Scheduler except for following
31
+ difference:
32
+
33
+ NoamLR:
34
+ lr = optimizer.lr * model_size ** -0.5
35
+ * min(step ** -0.5, step * warmup_step ** -1.5)
36
+ WarmupLR:
37
+ lr = optimizer.lr * warmup_step ** 0.5
38
+ * min(step ** -0.5, step * warmup_step ** -1.5)
39
+
40
+ Note that the maximum lr equals to optimizer.lr in this scheduler.
41
+
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ optimizer: torch.optim.Optimizer,
47
+ warmup_steps: Union[int, float] = 25000,
48
+ last_epoch: int = -1,
49
+ ):
50
+ self.warmup_steps = warmup_steps
51
+
52
+ # __init__() must be invoked before setting field
53
+ # because step() is also invoked in __init__()
54
+ super().__init__(optimizer, last_epoch)
55
+
56
+ def __repr__(self):
57
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
58
+
59
+ def get_lr(self):
60
+ step_num = self.last_epoch + 1
61
+ if self.warmup_steps == 0:
62
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
63
+ else:
64
+ return [
65
+ lr * self.warmup_steps**0.5 *
66
+ min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
67
+ for lr in self.base_lrs
68
+ ]
69
+
70
+ def set_step(self, step: int):
71
+ self.last_epoch = step
72
+
73
+
74
+ class WarmupPolicy(_LRScheduler):
75
+ """Adds warmup kwargs and warmup logic to lr policy.
76
+ All arguments should be passed as kwargs for clarity,
77
+ Args:
78
+ warmup_steps: Number of training steps in warmup stage
79
+ warmup_ratio: Ratio of warmup steps to total steps
80
+ max_steps: Total number of steps while training or `None` for
81
+ infinite training
82
+ """
83
+
84
+ def __init__(self,
85
+ optimizer,
86
+ *,
87
+ warmup_steps=None,
88
+ warmup_ratio=None,
89
+ max_steps=None,
90
+ min_lr=0.0,
91
+ last_epoch=-1):
92
+ assert not (warmup_steps is not None and warmup_ratio is not None),\
93
+ "Either use particular number of step or ratio"
94
+ assert warmup_ratio is None or max_steps is not None, \
95
+ "If there is a ratio, there should be a total steps"
96
+
97
+ # It is necessary to assign all attributes *before* __init__,
98
+ # as class is wrapped by an inner class.
99
+ self.max_steps = max_steps
100
+ if warmup_steps is not None:
101
+ self.warmup_steps = warmup_steps
102
+ elif warmup_ratio is not None:
103
+ self.warmup_steps = int(warmup_ratio * max_steps)
104
+ else:
105
+ self.warmup_steps = 0
106
+
107
+ self.min_lr = min_lr
108
+ super().__init__(optimizer, last_epoch)
109
+
110
+ def get_lr(self):
111
+ if not self._get_lr_called_within_step:
112
+ warnings.warn(
113
+ "To get the last learning rate computed "
114
+ "by the scheduler, please use `get_last_lr()`.",
115
+ UserWarning,
116
+ stacklevel=2)
117
+
118
+ step = self.last_epoch
119
+
120
+ if step <= self.warmup_steps and self.warmup_steps > 0:
121
+ return self._get_warmup_lr(step)
122
+
123
+ if step > self.max_steps:
124
+ return [self.min_lr for _ in self.base_lrs]
125
+
126
+ return self._get_lr(step)
127
+
128
+ def _get_warmup_lr(self, step):
129
+ lr_val = (step + 1) / (self.warmup_steps + 1)
130
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
131
+
132
+ def _get_lr(self, step):
133
+ """Simple const lr policy"""
134
+ return self.base_lrs
135
+
136
+
137
+ class SquareRootConstantPolicy(_LRScheduler):
138
+ """Adds warmup kwargs and warmup logic to lr policy.
139
+ All arguments should be passed as kwargs for clarity,
140
+ Args:
141
+ warmup_steps: Number of training steps in warmup stage
142
+ warmup_ratio: Ratio of warmup steps to total steps
143
+ max_steps: Total number of steps while training or `None` for
144
+ infinite training
145
+ """
146
+
147
+ def __init__(self,
148
+ optimizer,
149
+ *,
150
+ constant_steps=None,
151
+ constant_ratio=None,
152
+ max_steps=None,
153
+ min_lr=0.0,
154
+ last_epoch=-1):
155
+ assert not (constant_steps is not None
156
+ and constant_ratio is not None), \
157
+ "Either use particular number of step or ratio"
158
+ assert constant_ratio is None or max_steps is not None, \
159
+ "If there is a ratio, there should be a total steps"
160
+
161
+ # It is necessary to assign all attributes *before* __init__,
162
+ # as class is wrapped by an inner class.
163
+ self.max_steps = max_steps
164
+ if constant_steps is not None:
165
+ self.constant_steps = constant_steps
166
+ elif constant_ratio is not None:
167
+ self.constant_steps = int(constant_ratio * max_steps)
168
+ else:
169
+ self.constant_steps = 0
170
+
171
+ self.constant_lr = 1 / (constant_steps**0.5)
172
+ self.min_lr = min_lr
173
+ super().__init__(optimizer, last_epoch)
174
+
175
+ def get_lr(self):
176
+ if not self._get_lr_called_within_step:
177
+ warnings.warn(
178
+ "To get the last learning rate computed "
179
+ "by the scheduler, please use `get_last_lr()`.",
180
+ UserWarning,
181
+ stacklevel=2)
182
+
183
+ step = self.last_epoch
184
+
185
+ if step <= self.constant_steps:
186
+ return [self.constant_lr for _ in self.base_lrs]
187
+
188
+ if step > self.max_steps:
189
+ return [self.min_lr for _ in self.base_lrs]
190
+
191
+ return self._get_lr(step)
192
+
193
+ def _get_lr(self, step):
194
+ """Simple const lr policy"""
195
+ return self.base_lrs
196
+
197
+
198
+ class WarmupHoldPolicy(WarmupPolicy):
199
+ """Variant of WarmupPolicy which maintains high
200
+ learning rate for a defined number of steps.
201
+ All arguments should be passed as kwargs for clarity,
202
+ Args:
203
+ warmup_steps: Number of training steps in warmup stage
204
+ warmup_ratio: Ratio of warmup steps to total steps
205
+ hold_steps: Number of training steps to
206
+ hold the learning rate after warm up
207
+ hold_ratio: Ratio of hold steps to total steps
208
+ max_steps: Total number of steps while training or `None` for
209
+ infinite training
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ optimizer,
215
+ *,
216
+ warmup_steps=None,
217
+ warmup_ratio=None,
218
+ hold_steps=None,
219
+ hold_ratio=None,
220
+ max_steps=None,
221
+ min_lr=0.0,
222
+ last_epoch=-1,
223
+ ):
224
+ assert not (hold_steps is not None and hold_ratio is not None), \
225
+ "Either use particular number of step or ratio"
226
+ assert hold_ratio is None or max_steps is not None, \
227
+ "If there is a ratio, there should be a total steps"
228
+
229
+ self.min_lr = min_lr
230
+ self._last_warmup_lr = 0.0
231
+
232
+ # Necessary to duplicate as class attributes are hidden in inner class
233
+ self.max_steps = max_steps
234
+ if warmup_steps is not None:
235
+ self.warmup_steps = warmup_steps
236
+ elif warmup_ratio is not None:
237
+ self.warmup_steps = int(warmup_ratio * max_steps)
238
+ else:
239
+ self.warmup_steps = 0
240
+
241
+ if hold_steps is not None:
242
+ self.hold_steps = hold_steps + self.warmup_steps
243
+ elif hold_ratio is not None:
244
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
245
+ else:
246
+ self.hold_steps = 0
247
+
248
+ super().__init__(
249
+ optimizer,
250
+ warmup_steps=warmup_steps,
251
+ warmup_ratio=warmup_ratio,
252
+ max_steps=max_steps,
253
+ last_epoch=last_epoch,
254
+ min_lr=min_lr,
255
+ )
256
+
257
+ def get_lr(self):
258
+ if not self._get_lr_called_within_step:
259
+ warnings.warn(
260
+ "To get the last learning rate computed by the scheduler,"
261
+ " "
262
+ "please use `get_last_lr()`.",
263
+ UserWarning,
264
+ stacklevel=2)
265
+
266
+ step = self.last_epoch
267
+
268
+ # Warmup phase
269
+ if step <= self.warmup_steps and self.warmup_steps > 0:
270
+ return self._get_warmup_lr(step)
271
+
272
+ # Hold phase
273
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
274
+ return self.base_lrs
275
+
276
+ if step > self.max_steps:
277
+ return [self.min_lr for _ in self.base_lrs]
278
+
279
+ return self._get_lr(step)
280
+
281
+
282
+ class WarmupAnnealHoldPolicy(_LRScheduler):
283
+ """Adds warmup kwargs and warmup logic to lr policy.
284
+ All arguments should be passed as kwargs for clarity,
285
+ Args:
286
+ warmup_steps: Number of training steps in warmup stage
287
+ warmup_ratio: Ratio of warmup steps to total steps
288
+ max_steps: Total number of steps while training or `None` for
289
+ infinite training
290
+ min_lr: Minimum lr to hold the learning rate after decay at.
291
+ constant_steps: Number of steps to keep lr constant at.
292
+ constant_ratio: Ratio of steps to keep lr constant.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ optimizer,
298
+ *,
299
+ warmup_steps=None,
300
+ warmup_ratio=None,
301
+ constant_steps=None,
302
+ constant_ratio=None,
303
+ max_steps=None,
304
+ min_lr=0.0,
305
+ last_epoch=-1,
306
+ ):
307
+ assert not (warmup_steps is not None
308
+ and warmup_ratio is not None), \
309
+ "Either use particular number of step or ratio"
310
+ assert not (constant_steps is not None
311
+ and constant_ratio is not None), \
312
+ "Either use constant_steps or constant_ratio"
313
+ assert warmup_ratio is None or max_steps is not None, \
314
+ "If there is a ratio, there should be a total steps"
315
+
316
+ # It is necessary to assign all attributes *before* __init__,
317
+ # as class is wrapped by an inner class.
318
+ self.max_steps = max_steps
319
+
320
+ if warmup_steps is not None:
321
+ self.warmup_steps = warmup_steps
322
+ elif warmup_ratio is not None:
323
+ self.warmup_steps = int(warmup_ratio * max_steps)
324
+ else:
325
+ self.warmup_steps = 0
326
+
327
+ if constant_steps is not None:
328
+ self.constant_steps = constant_steps
329
+ elif constant_ratio is not None:
330
+ self.constant_steps = int(constant_ratio * max_steps)
331
+ else:
332
+ self.constant_steps = 0
333
+
334
+ self.decay_steps = max_steps - (self.constant_steps +
335
+ self.warmup_steps)
336
+
337
+ self.min_lr = min_lr
338
+ super().__init__(optimizer, last_epoch)
339
+
340
+ def get_lr(self):
341
+ if not self._get_lr_called_within_step:
342
+ warnings.warn(
343
+ "To get the last learning rate computed "
344
+ "by the scheduler, please use `get_last_lr()`.",
345
+ UserWarning,
346
+ stacklevel=2)
347
+
348
+ step = self.last_epoch
349
+
350
+ # Warmup steps
351
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
352
+ return self._get_warmup_lr(step)
353
+
354
+ # Constant steps after warmup and decay
355
+ if self.constant_steps > 0 and (
356
+ self.warmup_steps + self.decay_steps) < step <= self.max_steps:
357
+ return self._get_constant_lr(step)
358
+
359
+ # Min lr after max steps of updates
360
+ if step > self.max_steps:
361
+ return [self.min_lr for _ in self.base_lrs]
362
+
363
+ return self._get_lr(step)
364
+
365
+ def _get_warmup_lr(self, step):
366
+ lr_val = (step + 1) / (self.warmup_steps + 1)
367
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
368
+
369
+ def _get_constant_lr(self, step):
370
+ return [self.min_lr for _ in self.base_lrs]
371
+
372
+ def _get_lr(self, step):
373
+ """Simple const lr policy"""
374
+ return self.base_lrs
375
+
376
+
377
+ def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
378
+ mult = ((max_steps - step) / max_steps)**0.5
379
+ out_lr = initial_lr * mult
380
+ out_lr = max(out_lr, min_lr)
381
+ return out_lr
382
+
383
+
384
+ def _square_annealing(initial_lr, step, max_steps, min_lr):
385
+ mult = ((max_steps - step) / max_steps)**2
386
+ out_lr = initial_lr * mult
387
+ out_lr = max(out_lr, min_lr)
388
+ return out_lr
389
+
390
+
391
+ def _cosine_annealing(initial_lr, step, max_steps, min_lr):
392
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
393
+ out_lr = (initial_lr - min_lr) * mult + min_lr
394
+ return out_lr
395
+
396
+
397
+ def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
398
+ decay_steps, min_lr):
399
+ assert max_lr > min_lr
400
+ # Use linear warmup for the initial part.
401
+ if warmup_steps > 0 and step <= warmup_steps:
402
+ return max_lr * float(step) / float(warmup_steps)
403
+
404
+ # For any steps larger than `decay_steps`, use `min_lr`.
405
+ if step > warmup_steps + decay_steps:
406
+ return min_lr
407
+
408
+ # If we are done with the warmup period, use the decay style.
409
+ num_steps_ = step - warmup_steps
410
+ decay_steps_ = decay_steps
411
+ decay_ratio = float(num_steps_) / float(decay_steps_)
412
+ assert decay_ratio >= 0.0
413
+ assert decay_ratio <= 1.0
414
+ delta_lr = max_lr - min_lr
415
+
416
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
417
+
418
+ return min_lr + coeff * delta_lr
419
+
420
+
421
+ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
422
+ if cycle:
423
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
424
+ decay_steps *= multiplier
425
+ else:
426
+ step = min(step, decay_steps)
427
+ p = step / decay_steps
428
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
429
+ lr += min_lr
430
+ return lr
431
+
432
+
433
+ def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
434
+ decay_rate, min_lr):
435
+ # hold_steps = total number of steps
436
+ # to hold the LR, not the warmup + hold steps.
437
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
438
+ T_hold_decay = max(1, (step - hold_steps)**decay_rate)
439
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
440
+ lr = max(lr, min_lr)
441
+ return lr
442
+
443
+
444
+ class SquareAnnealing(WarmupPolicy):
445
+
446
+ def __init__(self,
447
+ optimizer,
448
+ *,
449
+ max_steps,
450
+ min_lr=1e-5,
451
+ last_epoch=-1,
452
+ **kwargs):
453
+ super().__init__(optimizer=optimizer,
454
+ max_steps=max_steps,
455
+ last_epoch=last_epoch,
456
+ min_lr=min_lr,
457
+ **kwargs)
458
+
459
+ def _get_lr(self, step):
460
+ new_lrs = [
461
+ _square_annealing(
462
+ initial_lr=initial_lr,
463
+ step=step - self.warmup_steps,
464
+ max_steps=self.max_steps - self.warmup_steps,
465
+ min_lr=self.min_lr,
466
+ ) for initial_lr in self.base_lrs
467
+ ]
468
+ return new_lrs
469
+
470
+
471
+ class SquareRootAnnealing(WarmupPolicy):
472
+
473
+ def __init__(self,
474
+ optimizer,
475
+ *,
476
+ max_steps,
477
+ min_lr=0,
478
+ last_epoch=-1,
479
+ **kwargs):
480
+ super().__init__(optimizer=optimizer,
481
+ max_steps=max_steps,
482
+ last_epoch=last_epoch,
483
+ min_lr=min_lr,
484
+ **kwargs)
485
+
486
+ def _get_lr(self, step):
487
+ new_lrs = [
488
+ _squareroot_annealing(initial_lr=initial_lr,
489
+ step=step,
490
+ max_steps=self.max_steps,
491
+ min_lr=self.min_lr)
492
+ for initial_lr in self.base_lrs
493
+ ]
494
+ return new_lrs
495
+
496
+
497
+ class CosineAnnealing(WarmupAnnealHoldPolicy):
498
+
499
+ def __init__(self,
500
+ optimizer,
501
+ *,
502
+ max_steps,
503
+ min_lr=0,
504
+ last_epoch=-1,
505
+ **kwargs):
506
+ super().__init__(optimizer=optimizer,
507
+ max_steps=max_steps,
508
+ last_epoch=last_epoch,
509
+ min_lr=min_lr,
510
+ **kwargs)
511
+
512
+ def _get_lr(self, step):
513
+ for initial_lr in self.base_lrs:
514
+ if initial_lr < self.min_lr:
515
+ raise ValueError(
516
+ f"{self} received an initial learning rate "
517
+ f"that was lower than the minimum learning rate.")
518
+
519
+ if self.constant_steps is None or self.constant_steps == 0:
520
+ new_lrs = [
521
+ _cosine_annealing(
522
+ initial_lr=initial_lr,
523
+ step=step - self.warmup_steps,
524
+ max_steps=self.max_steps - self.warmup_steps,
525
+ min_lr=self.min_lr,
526
+ ) for initial_lr in self.base_lrs
527
+ ]
528
+ else:
529
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
530
+ return new_lrs
531
+
532
+ def _get_warmup_lr(self, step):
533
+ if self.constant_steps is None or self.constant_steps == 0:
534
+ return super()._get_warmup_lr(step)
535
+ else:
536
+ # Use linear warmup for the initial part.
537
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
538
+
539
+ def _get_constant_lr(self, step):
540
+ # Only called when `constant_steps` > 0.
541
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
542
+
543
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
544
+ # Cosine Schedule for Megatron LM,
545
+ # slightly different warmup schedule + constant LR at the end.
546
+ new_lrs = [
547
+ _linear_warmup_with_cosine_annealing(
548
+ max_lr=self.base_lrs[0],
549
+ warmup_steps=self.warmup_steps,
550
+ step=step,
551
+ decay_steps=self.decay_steps,
552
+ min_lr=self.min_lr,
553
+ ) for _ in self.base_lrs
554
+ ]
555
+ return new_lrs
556
+
557
+
558
+ class NoamAnnealing(_LRScheduler):
559
+
560
+ def __init__(self,
561
+ optimizer,
562
+ *,
563
+ d_model,
564
+ warmup_steps=None,
565
+ warmup_ratio=None,
566
+ max_steps=None,
567
+ min_lr=0.0,
568
+ last_epoch=-1):
569
+ self._normalize = d_model**(-0.5)
570
+ assert not (warmup_steps is not None
571
+ and warmup_ratio is not None), \
572
+ "Either use particular number of step or ratio"
573
+ assert warmup_ratio is None or max_steps is not None, \
574
+ "If there is a ratio, there should be a total steps"
575
+
576
+ # It is necessary to assign all attributes *before* __init__,
577
+ # as class is wrapped by an inner class.
578
+ self.max_steps = max_steps
579
+ if warmup_steps is not None:
580
+ self.warmup_steps = warmup_steps
581
+ elif warmup_ratio is not None:
582
+ self.warmup_steps = int(warmup_ratio * max_steps)
583
+ else:
584
+ self.warmup_steps = 0
585
+
586
+ self.min_lr = min_lr
587
+ super().__init__(optimizer, last_epoch)
588
+
589
+ def get_lr(self):
590
+ if not self._get_lr_called_within_step:
591
+ warnings.warn(
592
+ "To get the last learning rate computed "
593
+ "by the scheduler, please use `get_last_lr()`.",
594
+ UserWarning,
595
+ stacklevel=2)
596
+
597
+ step = max(1, self.last_epoch)
598
+
599
+ for initial_lr in self.base_lrs:
600
+ if initial_lr < self.min_lr:
601
+ raise ValueError(
602
+ f"{self} received an initial learning rate "
603
+ f"that was lower than the minimum learning rate.")
604
+
605
+ new_lrs = [
606
+ self._noam_annealing(initial_lr=initial_lr, step=step)
607
+ for initial_lr in self.base_lrs
608
+ ]
609
+ return new_lrs
610
+
611
+ def _noam_annealing(self, initial_lr, step):
612
+ if self.warmup_steps > 0:
613
+ mult = self._normalize * min(step**(-0.5),
614
+ step * (self.warmup_steps**(-1.5)))
615
+ else:
616
+ mult = self._normalize * step**(-0.5)
617
+
618
+ out_lr = initial_lr * mult
619
+ if step > self.warmup_steps:
620
+ out_lr = max(out_lr, self.min_lr)
621
+ return out_lr
622
+
623
+
624
+ class NoamHoldAnnealing(WarmupHoldPolicy):
625
+
626
+ def __init__(self,
627
+ optimizer,
628
+ *,
629
+ max_steps,
630
+ decay_rate=0.5,
631
+ min_lr=0.0,
632
+ last_epoch=-1,
633
+ **kwargs):
634
+ """
635
+ From Nemo:
636
+ Implementation of the Noam Hold Annealing policy
637
+ from the SqueezeFormer paper.
638
+
639
+ Unlike NoamAnnealing, the peak learning rate
640
+ can be explicitly set for this scheduler.
641
+ The schedule first performs linear warmup,
642
+ then holds the peak LR, then decays with some schedule for
643
+ the remainder of the steps.
644
+ Therefore the min-lr is still dependent
645
+ on the hyper parameters selected.
646
+
647
+ It's schedule is determined by three factors-
648
+
649
+ Warmup Steps: Initial stage, where linear warmup
650
+ occurs uptil the peak LR is reached. Unlike NoamAnnealing,
651
+ the peak LR is explicitly stated here instead of a scaling factor.
652
+
653
+ Hold Steps: Intermediate stage, where the peak LR
654
+ is maintained for some number of steps. In this region,
655
+ the high peak LR allows the model to converge faster
656
+ if training is stable. However the high LR
657
+ may also cause instability during training.
658
+ Should usually be a significant fraction of training
659
+ steps (around 30-40% of the entire training steps).
660
+
661
+ Decay Steps: Final stage, where the LR rapidly decays
662
+ with some scaling rate (set by decay rate).
663
+ To attain Noam decay, use 0.5,
664
+ for Squeezeformer recommended decay, use 1.0.
665
+ The fast decay after prolonged high LR during
666
+ hold phase allows for rapid convergence.
667
+
668
+ References:
669
+ - [Squeezeformer:
670
+ An Efficient Transformer for Automatic Speech Recognition]
671
+ (https://arxiv.org/abs/2206.00888)
672
+
673
+ Args:
674
+ optimizer: Pytorch compatible Optimizer object.
675
+ warmup_steps: Number of training steps in warmup stage
676
+ warmup_ratio: Ratio of warmup steps to total steps
677
+ hold_steps: Number of training steps to
678
+ hold the learning rate after warm up
679
+ hold_ratio: Ratio of hold steps to total steps
680
+ max_steps: Total number of steps while training or `None` for
681
+ infinite training
682
+ decay_rate: Float value describing the polynomial decay
683
+ after the hold period. Default value
684
+ of 0.5 corresponds to Noam decay.
685
+ min_lr: Minimum learning rate.
686
+ """
687
+ self.decay_rate = decay_rate
688
+ super().__init__(optimizer=optimizer,
689
+ max_steps=max_steps,
690
+ last_epoch=last_epoch,
691
+ min_lr=min_lr,
692
+ **kwargs)
693
+
694
+ def _get_lr(self, step):
695
+ if self.warmup_steps is None or self.warmup_steps == 0:
696
+ raise ValueError(
697
+ "Noam scheduler cannot be used without warmup steps")
698
+
699
+ if self.hold_steps > 0:
700
+ hold_steps = self.hold_steps - self.warmup_steps
701
+ else:
702
+ hold_steps = 0
703
+
704
+ new_lrs = [
705
+ _noam_hold_annealing(
706
+ initial_lr,
707
+ step=step,
708
+ warmup_steps=self.warmup_steps,
709
+ hold_steps=hold_steps,
710
+ decay_rate=self.decay_rate,
711
+ min_lr=self.min_lr,
712
+ ) for initial_lr in self.base_lrs
713
+ ]
714
+ return new_lrs
715
+
716
+ def set_step(self, step: int):
717
+ self.last_epoch = step
718
+
719
+
720
+ class ConstantLR(_LRScheduler):
721
+ """The ConstantLR scheduler
722
+
723
+ This scheduler keeps a constant lr
724
+
725
+ """
726
+
727
+ def __init__(
728
+ self,
729
+ optimizer: torch.optim.Optimizer,
730
+ ):
731
+ # __init__() must be invoked before setting field
732
+ # because step() is also invoked in __init__()
733
+ super().__init__(optimizer)
734
+
735
+ def get_lr(self):
736
+ return self.base_lrs
737
+
738
+ def set_step(self, step: int):
739
+ self.last_epoch = step
cosyvoice/utils/train_utils.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2023 Horizon Inc. (authors: Xingchen Song)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from contextlib import nullcontext
18
+ import logging
19
+ import os
20
+ import torch
21
+ import json
22
+ import re
23
+ import datetime
24
+ import yaml
25
+
26
+ import deepspeed
27
+ import torch.optim as optim
28
+ import torch.distributed as dist
29
+
30
+ from torch.utils.tensorboard import SummaryWriter
31
+ from torch.utils.data import DataLoader
32
+ from torch.nn.utils import clip_grad_norm_
33
+
34
+ from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
35
+
36
+ from cosyvoice.dataset.dataset import Dataset
37
+ from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
38
+
39
+
40
+ def init_distributed(args):
41
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
42
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
43
+ rank = int(os.environ.get('RANK', 0))
44
+ logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
45
+ ', rank {}, world_size {}'.format(rank, world_size))
46
+ if args.train_engine == 'torch_ddp':
47
+ torch.cuda.set_device(local_rank)
48
+ dist.init_process_group(args.dist_backend)
49
+ else:
50
+ deepspeed.init_distributed(dist_backend=args.dist_backend)
51
+ return world_size, local_rank, rank
52
+
53
+
54
+ def init_dataset_and_dataloader(args, configs):
55
+ train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
56
+ cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
57
+
58
+ # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
59
+ train_data_loader = DataLoader(train_dataset,
60
+ batch_size=None,
61
+ pin_memory=args.pin_memory,
62
+ num_workers=args.num_workers,
63
+ prefetch_factor=args.prefetch)
64
+ cv_data_loader = DataLoader(cv_dataset,
65
+ batch_size=None,
66
+ pin_memory=args.pin_memory,
67
+ num_workers=args.num_workers,
68
+ prefetch_factor=args.prefetch)
69
+ return train_dataset, cv_dataset, train_data_loader, cv_data_loader
70
+
71
+
72
+
73
+ def check_modify_and_save_config(args, configs):
74
+ if args.train_engine == "torch_ddp":
75
+ configs['train_conf']["dtype"] = 'fp32'
76
+ else:
77
+ with open(args.deepspeed_config, 'r') as fin:
78
+ ds_configs = json.load(fin)
79
+ if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
80
+ configs['train_conf']["dtype"] = "fp16"
81
+ elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
82
+ configs['train_conf']["dtype"] = "bf16"
83
+ else:
84
+ configs['train_conf']["dtype"] = "fp32"
85
+ assert ds_configs["train_micro_batch_size_per_gpu"] == 1
86
+ # if use deepspeed, override ddp config
87
+ configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
88
+ configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
89
+ configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
90
+ configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
91
+ return configs
92
+
93
+
94
+ def wrap_cuda_model(args, model):
95
+ local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
96
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
97
+ if args.train_engine == "torch_ddp": # native pytorch ddp
98
+ assert (torch.cuda.is_available())
99
+ model.cuda()
100
+ model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
101
+ else:
102
+ if int(os.environ.get('RANK', 0)) == 0:
103
+ logging.info("Estimating model states memory needs (zero2)...")
104
+ estimate_zero2_model_states_mem_needs_all_live(
105
+ model,
106
+ num_gpus_per_node=local_world_size,
107
+ num_nodes=world_size // local_world_size)
108
+ return model
109
+
110
+
111
+ def init_optimizer_and_scheduler(args, configs, model):
112
+ if configs['train_conf']['optim'] == 'adam':
113
+ optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
114
+ elif configs['train_conf']['optim'] == 'adamw':
115
+ optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
116
+ else:
117
+ raise ValueError("unknown optimizer: " + configs['train_conf'])
118
+
119
+ if configs['train_conf']['scheduler'] == 'warmuplr':
120
+ scheduler_type = WarmupLR
121
+ scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
122
+ elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
123
+ scheduler_type = NoamHoldAnnealing
124
+ scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
125
+ elif configs['train_conf']['scheduler'] == 'constantlr':
126
+ scheduler_type = ConstantLR
127
+ scheduler = ConstantLR(optimizer)
128
+ else:
129
+ raise ValueError("unknown scheduler: " + configs['train_conf'])
130
+
131
+ # use deepspeed optimizer for speedup
132
+ if args.train_engine == "deepspeed":
133
+ def scheduler(opt):
134
+ return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
135
+ model, optimizer, _, scheduler = deepspeed.initialize(
136
+ args=args,
137
+ model=model,
138
+ optimizer=None,
139
+ lr_scheduler=scheduler,
140
+ model_parameters=model.parameters())
141
+
142
+ return model, optimizer, scheduler
143
+
144
+
145
+ def init_summarywriter(args):
146
+ writer = None
147
+ if int(os.environ.get('RANK', 0)) == 0:
148
+ os.makedirs(args.model_dir, exist_ok=True)
149
+ writer = SummaryWriter(args.tensorboard_dir)
150
+ return writer
151
+
152
+
153
+ def save_model(model, model_name, info_dict):
154
+ rank = int(os.environ.get('RANK', 0))
155
+ model_dir = info_dict["model_dir"]
156
+ save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
157
+
158
+ if info_dict["train_engine"] == "torch_ddp":
159
+ if rank == 0:
160
+ torch.save(model.module.state_dict(), save_model_path)
161
+ else:
162
+ with torch.no_grad():
163
+ model.save_checkpoint(save_dir=model_dir,
164
+ tag=model_name,
165
+ client_state=info_dict)
166
+ if rank == 0:
167
+ info_path = re.sub('.pt$', '.yaml', save_model_path)
168
+ info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
169
+ with open(info_path, 'w') as fout:
170
+ data = yaml.dump(info_dict)
171
+ fout.write(data)
172
+ logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
173
+
174
+
175
+ def cosyvoice_join(group_join, info_dict):
176
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
177
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
178
+ rank = int(os.environ.get('RANK', 0))
179
+
180
+ if info_dict["batch_idx"] != 0:
181
+ # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
182
+ try:
183
+ dist.monitored_barrier(group=group_join,
184
+ timeout=group_join.options._timeout)
185
+ return False
186
+ except RuntimeError as e:
187
+ logging.info("Detected uneven workload distribution: {}\n".format(e) +
188
+ "Break current worker to manually join all workers, " +
189
+ "world_size {}, current rank {}, current local_rank {}\n".
190
+ format(world_size, rank, local_rank))
191
+ return True
192
+ else:
193
+ return False
194
+
195
+
196
+ def batch_forward(model, batch, info_dict):
197
+ device = int(os.environ.get('LOCAL_RANK', 0))
198
+
199
+ dtype = info_dict["dtype"]
200
+ if dtype == "fp16":
201
+ dtype = torch.float16
202
+ elif dtype == "bf16":
203
+ dtype = torch.bfloat16
204
+ else: # fp32
205
+ dtype = torch.float32
206
+
207
+ if info_dict['train_engine'] == 'torch_ddp':
208
+ autocast = nullcontext()
209
+ else:
210
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
211
+
212
+ with autocast:
213
+ info_dict['loss_dict'] = model(batch, device)
214
+ return info_dict
215
+
216
+
217
+ def batch_backward(model, info_dict):
218
+ if info_dict["train_engine"] == "deepspeed":
219
+ scaled_loss = model.backward(info_dict['loss_dict']['loss'])
220
+ else:
221
+ scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
222
+ scaled_loss.backward()
223
+
224
+ info_dict['loss_dict']['loss'] = scaled_loss
225
+ return info_dict
226
+
227
+
228
+ def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
229
+ grad_norm = 0.0
230
+ if info_dict['train_engine'] == "deepspeed":
231
+ info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
232
+ model.step()
233
+ grad_norm = model.get_global_grad_norm()
234
+ elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
235
+ grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
236
+ if torch.isfinite(grad_norm):
237
+ optimizer.step()
238
+ optimizer.zero_grad()
239
+ scheduler.step()
240
+ info_dict["lr"] = optimizer.param_groups[0]['lr']
241
+ info_dict["grad_norm"] = grad_norm
242
+ return info_dict
243
+
244
+
245
+ def log_per_step(writer, info_dict):
246
+ tag = info_dict["tag"]
247
+ epoch = info_dict.get('epoch', 0)
248
+ step = info_dict["step"]
249
+ batch_idx = info_dict["batch_idx"]
250
+ loss_dict = info_dict['loss_dict']
251
+ rank = int(os.environ.get('RANK', 0))
252
+
253
+ # only rank 0 write to tensorboard to avoid multi-process write
254
+ if writer is not None:
255
+ if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
256
+ (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
257
+ for k in ['epoch', 'lr', 'grad_norm']:
258
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
259
+ for k, v in loss_dict.items():
260
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
261
+
262
+ # TRAIN & CV, Shell log (stdout)
263
+ if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
264
+ log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
265
+ for name, value in loss_dict.items():
266
+ log_str += '{} {:.6f} '.format(name, value)
267
+ if tag == "TRAIN":
268
+ log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
269
+ info_dict["lr"], info_dict['grad_norm'])
270
+ log_str += ' rank {}'.format(rank)
271
+ logging.debug(log_str)
272
+
273
+
274
+ def log_per_save(writer, info_dict):
275
+ tag = info_dict["tag"]
276
+ epoch = info_dict["epoch"]
277
+ step = info_dict["step"]
278
+ loss_dict = info_dict["loss_dict"]
279
+ lr = info_dict['lr']
280
+ rank = int(os.environ.get('RANK', 0))
281
+ logging.info(
282
+ 'Epoch {} Step {} CV info lr {} {} rank {}'.format(
283
+ epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
284
+
285
+ if writer is not None:
286
+ for k in ['epoch', 'lr']:
287
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
288
+ for k, v in loss_dict.items():
289
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set random seed, so that you may reproduce your result.
2
+ __set_seed1: !apply:random.seed [1986]
3
+ __set_seed2: !apply:numpy.random.seed [1986]
4
+ __set_seed3: !apply:torch.manual_seed [1986]
5
+ __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6
+
7
+ # fixed params
8
+ sample_rate: 22050
9
+ text_encoder_input_size: 512
10
+ llm_input_size: 1024
11
+ llm_output_size: 1024
12
+ spk_embed_dim: 192
13
+
14
+ # model params
15
+ # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
16
+ # for system/third_party class/function, we do not require this.
17
+ llm: !new:cosyvoice.llm.llm.TransformerLM
18
+ text_encoder_input_size: !ref <text_encoder_input_size>
19
+ llm_input_size: !ref <llm_input_size>
20
+ llm_output_size: !ref <llm_output_size>
21
+ text_token_size: 51866
22
+ speech_token_size: 4096
23
+ length_normalized_loss: True
24
+ lsm_weight: 0
25
+ spk_embed_dim: !ref <spk_embed_dim>
26
+ text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
27
+ input_size: !ref <text_encoder_input_size>
28
+ output_size: 1024
29
+ attention_heads: 8
30
+ linear_units: 2048
31
+ num_blocks: 3
32
+ dropout_rate: 0.1
33
+ positional_dropout_rate: 0.1
34
+ attention_dropout_rate: 0
35
+ normalize_before: True
36
+ input_layer: 'linear'
37
+ pos_enc_layer_type: 'rel_pos_espnet'
38
+ selfattention_layer_type: 'rel_selfattn'
39
+ use_cnn_module: False
40
+ macaron_style: False
41
+ use_dynamic_chunk: False
42
+ use_dynamic_left_chunk: False
43
+ static_chunk_size: 1
44
+ llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
45
+ input_size: !ref <llm_input_size>
46
+ output_size: !ref <llm_output_size>
47
+ attention_heads: 8
48
+ linear_units: 2048
49
+ num_blocks: 7
50
+ dropout_rate: 0.1
51
+ positional_dropout_rate: 0.1
52
+ attention_dropout_rate: 0
53
+ input_layer: 'linear_legacy'
54
+ pos_enc_layer_type: 'rel_pos_espnet'
55
+ selfattention_layer_type: 'rel_selfattn'
56
+ static_chunk_size: 1
57
+
58
+ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
+ input_size: 512
60
+ output_size: 80
61
+ spk_embed_dim: !ref <spk_embed_dim>
62
+ output_type: 'mel'
63
+ vocab_size: 4096
64
+ input_frame_rate: 50
65
+ only_mask_loss: True
66
+ encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
67
+ output_size: 512
68
+ attention_heads: 8
69
+ linear_units: 2048
70
+ num_blocks: 6
71
+ dropout_rate: 0.1
72
+ positional_dropout_rate: 0.1
73
+ attention_dropout_rate: 0.1
74
+ normalize_before: True
75
+ input_layer: 'linear'
76
+ pos_enc_layer_type: 'rel_pos_espnet'
77
+ selfattention_layer_type: 'rel_selfattn'
78
+ input_size: 512
79
+ use_cnn_module: False
80
+ macaron_style: False
81
+ length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
82
+ channels: 80
83
+ sampling_ratios: [1, 1, 1, 1]
84
+ decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
85
+ in_channels: 240
86
+ n_spks: 1
87
+ spk_emb_dim: 80
88
+ cfm_params: !new:omegaconf.DictConfig
89
+ content:
90
+ sigma_min: 1e-06
91
+ solver: 'euler'
92
+ t_scheduler: 'cosine'
93
+ training_cfg_rate: 0.2
94
+ inference_cfg_rate: 0.7
95
+ reg_loss_type: 'l1'
96
+ estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
97
+ in_channels: 320
98
+ out_channels: 80
99
+ channels: [256, 256]
100
+ dropout: 0
101
+ attention_head_dim: 64
102
+ n_blocks: 4
103
+ num_mid_blocks: 12
104
+ num_heads: 8
105
+ act_fn: 'gelu'
106
+
107
+ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
108
+ in_channels: 80
109
+ base_channels: 512
110
+ nb_harmonics: 8
111
+ sampling_rate: !ref <sample_rate>
112
+ nsf_alpha: 0.1
113
+ nsf_sigma: 0.003
114
+ nsf_voiced_threshold: 10
115
+ upsample_rates: [8, 8]
116
+ upsample_kernel_sizes: [16, 16]
117
+ istft_params:
118
+ n_fft: 16
119
+ hop_len: 4
120
+ resblock_kernel_sizes: [3, 7, 11]
121
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
122
+ source_resblock_kernel_sizes: [7, 11]
123
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
124
+ lrelu_slope: 0.1
125
+ audio_limit: 0.99
126
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
127
+ num_class: 1
128
+ in_channels: 80
129
+ cond_channels: 512
130
+
131
+ # processor functions
132
+ parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
133
+ get_tokenizer: !name:whisper.tokenizer.get_tokenizer
134
+ multilingual: True
135
+ num_languages: 100
136
+ language: 'en'
137
+ task: 'transcribe'
138
+ allowed_special: 'all'
139
+ tokenize: !name:cosyvoice.dataset.processor.tokenize
140
+ get_tokenizer: !ref <get_tokenizer>
141
+ allowed_special: !ref <allowed_special>
142
+ filter: !name:cosyvoice.dataset.processor.filter
143
+ max_length: 40960
144
+ min_length: 0
145
+ token_max_length: 200
146
+ token_min_length: 1
147
+ resample: !name:cosyvoice.dataset.processor.resample
148
+ resample_rate: !ref <sample_rate>
149
+ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
150
+ n_fft: 1024
151
+ num_mels: 80
152
+ sampling_rate: !ref <sample_rate>
153
+ hop_size: 256
154
+ win_size: 1024
155
+ fmin: 0
156
+ fmax: 8000
157
+ center: False
158
+ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
159
+ feat_extractor: !ref <feat_extractor>
160
+ parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
161
+ normalize: True
162
+ shuffle: !name:cosyvoice.dataset.processor.shuffle
163
+ shuffle_size: 1000
164
+ sort: !name:cosyvoice.dataset.processor.sort
165
+ sort_size: 500 # sort_size should be less than shuffle_size
166
+ batch: !name:cosyvoice.dataset.processor.batch
167
+ batch_type: 'dynamic'
168
+ max_frames_in_batch: 12000
169
+ padding: !name:cosyvoice.dataset.processor.padding
170
+ use_spk_embedding: False # change to True during sft
171
+
172
+ # dataset processor pipeline
173
+ data_pipeline: [
174
+ !ref <parquet_opener>,
175
+ !ref <tokenize>,
176
+ !ref <filter>,
177
+ !ref <resample>,
178
+ !ref <compute_fbank>,
179
+ !ref <parse_embedding>,
180
+ !ref <shuffle>,
181
+ !ref <sort>,
182
+ !ref <batch>,
183
+ !ref <padding>,
184
+ ]
185
+
186
+ # train conf
187
+ train_conf:
188
+ optim: adam
189
+ optim_conf:
190
+ lr: 0.002 # change to 0.001 if you want to train flow from scratch
191
+ scheduler: warmuplr
192
+ scheduler_conf:
193
+ warmup_steps: 25000
194
+ max_epoch: 200
195
+ grad_clip: 5
196
+ accum_grad: 2
197
+ log_interval: 100
198
+ save_per_step: -1
examples/libritts/cosyvoice/conf/cosyvoice.yaml ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set random seed, so that you may reproduce your result.
2
+ __set_seed1: !apply:random.seed [1986]
3
+ __set_seed2: !apply:numpy.random.seed [1986]
4
+ __set_seed3: !apply:torch.manual_seed [1986]
5
+ __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6
+
7
+ # fixed params
8
+ sample_rate: 22050
9
+ text_encoder_input_size: 512
10
+ llm_input_size: 1024
11
+ llm_output_size: 1024
12
+ spk_embed_dim: 192
13
+
14
+ # model params
15
+ # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
16
+ # for system/third_party class/function, we do not require this.
17
+ llm: !new:cosyvoice.llm.llm.TransformerLM
18
+ text_encoder_input_size: !ref <text_encoder_input_size>
19
+ llm_input_size: !ref <llm_input_size>
20
+ llm_output_size: !ref <llm_output_size>
21
+ text_token_size: 51866
22
+ speech_token_size: 4096
23
+ length_normalized_loss: True
24
+ lsm_weight: 0
25
+ spk_embed_dim: !ref <spk_embed_dim>
26
+ text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
27
+ input_size: !ref <text_encoder_input_size>
28
+ output_size: 1024
29
+ attention_heads: 16
30
+ linear_units: 4096
31
+ num_blocks: 6
32
+ dropout_rate: 0.1
33
+ positional_dropout_rate: 0.1
34
+ attention_dropout_rate: 0
35
+ normalize_before: True
36
+ input_layer: 'linear'
37
+ pos_enc_layer_type: 'rel_pos_espnet'
38
+ selfattention_layer_type: 'rel_selfattn'
39
+ use_cnn_module: False
40
+ macaron_style: False
41
+ use_dynamic_chunk: False
42
+ use_dynamic_left_chunk: False
43
+ static_chunk_size: 1
44
+ llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
45
+ input_size: !ref <llm_input_size>
46
+ output_size: !ref <llm_output_size>
47
+ attention_heads: 16
48
+ linear_units: 4096
49
+ num_blocks: 14
50
+ dropout_rate: 0.1
51
+ positional_dropout_rate: 0.1
52
+ attention_dropout_rate: 0
53
+ input_layer: 'linear_legacy'
54
+ pos_enc_layer_type: 'rel_pos_espnet'
55
+ selfattention_layer_type: 'rel_selfattn'
56
+ static_chunk_size: 1
57
+
58
+ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
+ input_size: 512
60
+ output_size: 80
61
+ spk_embed_dim: !ref <spk_embed_dim>
62
+ output_type: 'mel'
63
+ vocab_size: 4096
64
+ input_frame_rate: 50
65
+ only_mask_loss: True
66
+ encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
67
+ output_size: 512
68
+ attention_heads: 8
69
+ linear_units: 2048
70
+ num_blocks: 6
71
+ dropout_rate: 0.1
72
+ positional_dropout_rate: 0.1
73
+ attention_dropout_rate: 0.1
74
+ normalize_before: True
75
+ input_layer: 'linear'
76
+ pos_enc_layer_type: 'rel_pos_espnet'
77
+ selfattention_layer_type: 'rel_selfattn'
78
+ input_size: 512
79
+ use_cnn_module: False
80
+ macaron_style: False
81
+ length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
82
+ channels: 80
83
+ sampling_ratios: [1, 1, 1, 1]
84
+ decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
85
+ in_channels: 240
86
+ n_spks: 1
87
+ spk_emb_dim: 80
88
+ cfm_params: !new:omegaconf.DictConfig
89
+ content:
90
+ sigma_min: 1e-06
91
+ solver: 'euler'
92
+ t_scheduler: 'cosine'
93
+ training_cfg_rate: 0.2
94
+ inference_cfg_rate: 0.7
95
+ reg_loss_type: 'l1'
96
+ estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
97
+ in_channels: 320
98
+ out_channels: 80
99
+ channels: [256, 256]
100
+ dropout: 0
101
+ attention_head_dim: 64
102
+ n_blocks: 4
103
+ num_mid_blocks: 12
104
+ num_heads: 8
105
+ act_fn: 'gelu'
106
+
107
+ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
108
+ in_channels: 80
109
+ base_channels: 512
110
+ nb_harmonics: 8
111
+ sampling_rate: !ref <sample_rate>
112
+ nsf_alpha: 0.1
113
+ nsf_sigma: 0.003
114
+ nsf_voiced_threshold: 10
115
+ upsample_rates: [8, 8]
116
+ upsample_kernel_sizes: [16, 16]
117
+ istft_params:
118
+ n_fft: 16
119
+ hop_len: 4
120
+ resblock_kernel_sizes: [3, 7, 11]
121
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
122
+ source_resblock_kernel_sizes: [7, 11]
123
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
124
+ lrelu_slope: 0.1
125
+ audio_limit: 0.99
126
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
127
+ num_class: 1
128
+ in_channels: 80
129
+ cond_channels: 512
130
+
131
+ # processor functions
132
+ parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
133
+ get_tokenizer: !name:whisper.tokenizer.get_tokenizer
134
+ multilingual: True
135
+ num_languages: 100
136
+ language: 'en'
137
+ task: 'transcribe'
138
+ allowed_special: 'all'
139
+ tokenize: !name:cosyvoice.dataset.processor.tokenize
140
+ get_tokenizer: !ref <get_tokenizer>
141
+ allowed_special: !ref <allowed_special>
142
+ filter: !name:cosyvoice.dataset.processor.filter
143
+ max_length: 40960
144
+ min_length: 0
145
+ token_max_length: 200
146
+ token_min_length: 1
147
+ resample: !name:cosyvoice.dataset.processor.resample
148
+ resample_rate: !ref <sample_rate>
149
+ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
150
+ n_fft: 1024
151
+ num_mels: 80
152
+ sampling_rate: !ref <sample_rate>
153
+ hop_size: 256
154
+ win_size: 1024
155
+ fmin: 0
156
+ fmax: 8000
157
+ center: False
158
+ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
159
+ feat_extractor: !ref <feat_extractor>
160
+ parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
161
+ normalize: True
162
+ shuffle: !name:cosyvoice.dataset.processor.shuffle
163
+ shuffle_size: 1000
164
+ sort: !name:cosyvoice.dataset.processor.sort
165
+ sort_size: 500 # sort_size should be less than shuffle_size
166
+ batch: !name:cosyvoice.dataset.processor.batch
167
+ batch_type: 'dynamic'
168
+ max_frames_in_batch: 2000
169
+ padding: !name:cosyvoice.dataset.processor.padding
170
+ use_spk_embedding: False # change to True during sft
171
+
172
+ # dataset processor pipeline
173
+ data_pipeline: [
174
+ !ref <parquet_opener>,
175
+ !ref <tokenize>,
176
+ !ref <filter>,
177
+ !ref <resample>,
178
+ !ref <compute_fbank>,
179
+ !ref <parse_embedding>,
180
+ !ref <shuffle>,
181
+ !ref <sort>,
182
+ !ref <batch>,
183
+ !ref <padding>,
184
+ ]
185
+
186
+ # train conf
187
+ train_conf:
188
+ optim: adam
189
+ optim_conf:
190
+ lr: 0.001 # change to 1e-5 during sft
191
+ scheduler: warmuplr # change to constantlr during sft
192
+ scheduler_conf:
193
+ warmup_steps: 2500
194
+ max_epoch: 200
195
+ grad_clip: 5
196
+ accum_grad: 2
197
+ log_interval: 100
198
+ save_per_step: -1
examples/libritts/cosyvoice/conf/ds_stage2.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_micro_batch_size_per_gpu": 1,
3
+ "gradient_accumulation_steps": 1,
4
+ "steps_per_print": 100,
5
+ "gradient_clipping": 5,
6
+ "fp16": {
7
+ "enabled": false,
8
+ "auto_cast": false,
9
+ "loss_scale": 0,
10
+ "initial_scale_power": 16,
11
+ "loss_scale_window": 256,
12
+ "hysteresis": 2,
13
+ "consecutive_hysteresis": false,
14
+ "min_loss_scale": 1
15
+ },
16
+ "bf16": {
17
+ "enabled": false
18
+ },
19
+ "zero_force_ds_cpu_optimizer": false,
20
+ "zero_optimization": {
21
+ "stage": 2,
22
+ "offload_optimizer": {
23
+ "device": "none",
24
+ "pin_memory": true
25
+ },
26
+ "allgather_partitions": true,
27
+ "allgather_bucket_size": 5e8,
28
+ "overlap_comm": false,
29
+ "reduce_scatter": true,
30
+ "reduce_bucket_size": 5e8,
31
+ "contiguous_gradients" : true
32
+ },
33
+ "optimizer": {
34
+ "type": "AdamW",
35
+ "params": {
36
+ "lr": 0.001,
37
+ "weight_decay": 0.0001,
38
+ "torch_adam": true,
39
+ "adam_w_mode": true
40
+ }
41
+ }
42
+ }