Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +13 -0
- Dockerfile.arm64 +13 -0
- LICENSE +201 -0
- LLM/__pycache__/chat.cpython-311.pyc +0 -0
- LLM/__pycache__/language_model.cpython-311.pyc +0 -0
- LLM/__pycache__/mlx_language_model.cpython-311.pyc +0 -0
- LLM/chat.py +25 -0
- LLM/language_model.py +144 -0
- LLM/mlx_language_model.py +107 -0
- README.md +244 -0
- STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc +0 -0
- STT/__pycache__/paraformer_handler.cpython-311.pyc +0 -0
- STT/__pycache__/whisper_stt_handler.cpython-311.pyc +0 -0
- STT/lightning_whisper_mlx_handler.py +85 -0
- STT/paraformer_handler.py +61 -0
- STT/whisper_stt_handler.py +140 -0
- TTS/__pycache__/chatTTS_handler.cpython-311.pyc +0 -0
- TTS/__pycache__/melo_handler.cpython-311.pyc +0 -0
- TTS/__pycache__/parler_handler.cpython-311.pyc +0 -0
- TTS/chatTTS_handler.py +82 -0
- TTS/melo_handler.py +109 -0
- TTS/parler_handler.py +191 -0
- VAD/__pycache__/vad_handler.cpython-311.pyc +0 -0
- VAD/__pycache__/vad_handler.cpython-312.pyc +0 -0
- VAD/__pycache__/vad_iterator.cpython-311.pyc +0 -0
- VAD/__pycache__/vad_iterator.cpython-312.pyc +0 -0
- VAD/vad_handler.py +92 -0
- VAD/vad_iterator.py +100 -0
- arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/module_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/vad_arguments.cpython-311.pyc +0 -0
- arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc +0 -0
- arguments_classes/chat_tts_arguments.py +21 -0
- arguments_classes/language_model_arguments.py +71 -0
- arguments_classes/melo_tts_arguments.py +23 -0
- arguments_classes/mlx_language_model_arguments.py +65 -0
- arguments_classes/module_arguments.py +46 -0
- arguments_classes/paraformer_stt_arguments.py +17 -0
- arguments_classes/parler_tts_arguments.py +62 -0
- arguments_classes/socket_receiver_arguments.py +24 -0
- arguments_classes/socket_sender_arguments.py +18 -0
- arguments_classes/vad_arguments.py +47 -0
- arguments_classes/whisper_stt_arguments.py +64 -0
Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel
|
2 |
+
|
3 |
+
ENV PYTHONUNBUFFERED 1
|
4 |
+
|
5 |
+
WORKDIR /usr/src/app
|
6 |
+
|
7 |
+
# Install packages
|
8 |
+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
9 |
+
|
10 |
+
COPY requirements.txt ./
|
11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
+
|
13 |
+
COPY . .
|
Dockerfile.arm64
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
|
2 |
+
|
3 |
+
ENV PYTHONUNBUFFERED 1
|
4 |
+
|
5 |
+
WORKDIR /usr/src/app
|
6 |
+
|
7 |
+
# Install packages
|
8 |
+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
9 |
+
|
10 |
+
COPY requirements.txt ./
|
11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
+
|
13 |
+
COPY . .
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [2024] [The HuggingFace Inc. team]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
LLM/__pycache__/chat.cpython-311.pyc
ADDED
Binary file (1.59 kB). View file
|
|
LLM/__pycache__/language_model.cpython-311.pyc
ADDED
Binary file (6.31 kB). View file
|
|
LLM/__pycache__/mlx_language_model.cpython-311.pyc
ADDED
Binary file (5 kB). View file
|
|
LLM/chat.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Chat:
|
2 |
+
"""
|
3 |
+
Handles the chat using to avoid OOM issues.
|
4 |
+
"""
|
5 |
+
|
6 |
+
def __init__(self, size):
|
7 |
+
self.size = size
|
8 |
+
self.init_chat_message = None
|
9 |
+
# maxlen is necessary pair, since a each new step we add an prompt and assitant answer
|
10 |
+
self.buffer = []
|
11 |
+
|
12 |
+
def append(self, item):
|
13 |
+
self.buffer.append(item)
|
14 |
+
if len(self.buffer) == 2 * (self.size + 1):
|
15 |
+
self.buffer.pop(0)
|
16 |
+
self.buffer.pop(0)
|
17 |
+
|
18 |
+
def init_chat(self, init_chat_message):
|
19 |
+
self.init_chat_message = init_chat_message
|
20 |
+
|
21 |
+
def to_list(self):
|
22 |
+
if self.init_chat_message:
|
23 |
+
return [self.init_chat_message] + self.buffer
|
24 |
+
else:
|
25 |
+
return self.buffer
|
LLM/language_model.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from threading import Thread
|
2 |
+
from transformers import (
|
3 |
+
AutoModelForCausalLM,
|
4 |
+
AutoTokenizer,
|
5 |
+
pipeline,
|
6 |
+
TextIteratorStreamer,
|
7 |
+
)
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from LLM.chat import Chat
|
11 |
+
from baseHandler import BaseHandler
|
12 |
+
from rich.console import Console
|
13 |
+
import logging
|
14 |
+
from nltk import sent_tokenize
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
console = Console()
|
19 |
+
|
20 |
+
|
21 |
+
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
|
22 |
+
"en": "english",
|
23 |
+
"fr": "french",
|
24 |
+
"es": "spanish",
|
25 |
+
"zh": "chinese",
|
26 |
+
"ja": "japanese",
|
27 |
+
"ko": "korean",
|
28 |
+
}
|
29 |
+
|
30 |
+
class LanguageModelHandler(BaseHandler):
|
31 |
+
"""
|
32 |
+
Handles the language model part.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def setup(
|
36 |
+
self,
|
37 |
+
model_name="microsoft/Phi-3-mini-4k-instruct",
|
38 |
+
device="cuda",
|
39 |
+
torch_dtype="float16",
|
40 |
+
gen_kwargs={},
|
41 |
+
user_role="user",
|
42 |
+
chat_size=1,
|
43 |
+
init_chat_role=None,
|
44 |
+
init_chat_prompt="You are a helpful AI assistant.",
|
45 |
+
):
|
46 |
+
self.device = device
|
47 |
+
self.torch_dtype = getattr(torch, torch_dtype)
|
48 |
+
|
49 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
50 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
51 |
+
model_name, torch_dtype=torch_dtype, trust_remote_code=True
|
52 |
+
).to(device)
|
53 |
+
self.pipe = pipeline(
|
54 |
+
"text-generation", model=self.model, tokenizer=self.tokenizer, device=device
|
55 |
+
)
|
56 |
+
self.streamer = TextIteratorStreamer(
|
57 |
+
self.tokenizer,
|
58 |
+
skip_prompt=True,
|
59 |
+
skip_special_tokens=True,
|
60 |
+
)
|
61 |
+
self.gen_kwargs = {
|
62 |
+
"streamer": self.streamer,
|
63 |
+
"return_full_text": False,
|
64 |
+
**gen_kwargs,
|
65 |
+
}
|
66 |
+
|
67 |
+
self.chat = Chat(chat_size)
|
68 |
+
if init_chat_role:
|
69 |
+
if not init_chat_prompt:
|
70 |
+
raise ValueError(
|
71 |
+
"An initial promt needs to be specified when setting init_chat_role."
|
72 |
+
)
|
73 |
+
self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
|
74 |
+
self.user_role = user_role
|
75 |
+
|
76 |
+
self.warmup()
|
77 |
+
|
78 |
+
def warmup(self):
|
79 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
80 |
+
|
81 |
+
dummy_input_text = "Repeat the word 'home'."
|
82 |
+
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
|
83 |
+
warmup_gen_kwargs = {
|
84 |
+
"min_new_tokens": self.gen_kwargs["min_new_tokens"],
|
85 |
+
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
|
86 |
+
**self.gen_kwargs,
|
87 |
+
}
|
88 |
+
|
89 |
+
n_steps = 2
|
90 |
+
|
91 |
+
if self.device == "cuda":
|
92 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
93 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
94 |
+
torch.cuda.synchronize()
|
95 |
+
start_event.record()
|
96 |
+
|
97 |
+
for _ in range(n_steps):
|
98 |
+
thread = Thread(
|
99 |
+
target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
|
100 |
+
)
|
101 |
+
thread.start()
|
102 |
+
for _ in self.streamer:
|
103 |
+
pass
|
104 |
+
|
105 |
+
if self.device == "cuda":
|
106 |
+
end_event.record()
|
107 |
+
torch.cuda.synchronize()
|
108 |
+
|
109 |
+
logger.info(
|
110 |
+
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
|
111 |
+
)
|
112 |
+
|
113 |
+
def process(self, prompt):
|
114 |
+
logger.debug("infering language model...")
|
115 |
+
language_code = None
|
116 |
+
if isinstance(prompt, tuple):
|
117 |
+
prompt, language_code = prompt
|
118 |
+
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
|
119 |
+
|
120 |
+
self.chat.append({"role": self.user_role, "content": prompt})
|
121 |
+
thread = Thread(
|
122 |
+
target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
|
123 |
+
)
|
124 |
+
thread.start()
|
125 |
+
if self.device == "mps":
|
126 |
+
generated_text = ""
|
127 |
+
for new_text in self.streamer:
|
128 |
+
generated_text += new_text
|
129 |
+
printable_text = generated_text
|
130 |
+
torch.mps.empty_cache()
|
131 |
+
else:
|
132 |
+
generated_text, printable_text = "", ""
|
133 |
+
for new_text in self.streamer:
|
134 |
+
generated_text += new_text
|
135 |
+
printable_text += new_text
|
136 |
+
sentences = sent_tokenize(printable_text)
|
137 |
+
if len(sentences) > 1:
|
138 |
+
yield (sentences[0], language_code)
|
139 |
+
printable_text = new_text
|
140 |
+
|
141 |
+
self.chat.append({"role": "assistant", "content": generated_text})
|
142 |
+
|
143 |
+
# don't forget last sentence
|
144 |
+
yield (printable_text, language_code)
|
LLM/mlx_language_model.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from LLM.chat import Chat
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
from mlx_lm import load, stream_generate, generate
|
5 |
+
from rich.console import Console
|
6 |
+
import torch
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
console = Console()
|
11 |
+
|
12 |
+
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
|
13 |
+
"en": "english",
|
14 |
+
"fr": "french",
|
15 |
+
"es": "spanish",
|
16 |
+
"zh": "chinese",
|
17 |
+
"ja": "japanese",
|
18 |
+
"ko": "korean",
|
19 |
+
}
|
20 |
+
|
21 |
+
class MLXLanguageModelHandler(BaseHandler):
|
22 |
+
"""
|
23 |
+
Handles the language model part.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def setup(
|
27 |
+
self,
|
28 |
+
model_name="microsoft/Phi-3-mini-4k-instruct",
|
29 |
+
device="mps",
|
30 |
+
torch_dtype="float16",
|
31 |
+
gen_kwargs={},
|
32 |
+
user_role="user",
|
33 |
+
chat_size=1,
|
34 |
+
init_chat_role=None,
|
35 |
+
init_chat_prompt="You are a helpful AI assistant.",
|
36 |
+
):
|
37 |
+
self.model_name = model_name
|
38 |
+
self.model, self.tokenizer = load(self.model_name)
|
39 |
+
self.gen_kwargs = gen_kwargs
|
40 |
+
|
41 |
+
self.chat = Chat(chat_size)
|
42 |
+
if init_chat_role:
|
43 |
+
if not init_chat_prompt:
|
44 |
+
raise ValueError(
|
45 |
+
"An initial promt needs to be specified when setting init_chat_role."
|
46 |
+
)
|
47 |
+
self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
|
48 |
+
self.user_role = user_role
|
49 |
+
|
50 |
+
self.warmup()
|
51 |
+
|
52 |
+
def warmup(self):
|
53 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
54 |
+
|
55 |
+
dummy_input_text = "Repeat the word 'home'."
|
56 |
+
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
|
57 |
+
|
58 |
+
n_steps = 2
|
59 |
+
|
60 |
+
for _ in range(n_steps):
|
61 |
+
prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
|
62 |
+
generate(
|
63 |
+
self.model,
|
64 |
+
self.tokenizer,
|
65 |
+
prompt=prompt,
|
66 |
+
max_tokens=self.gen_kwargs["max_new_tokens"],
|
67 |
+
verbose=False,
|
68 |
+
)
|
69 |
+
|
70 |
+
def process(self, prompt):
|
71 |
+
logger.debug("infering language model...")
|
72 |
+
language_code = None
|
73 |
+
|
74 |
+
if isinstance(prompt, tuple):
|
75 |
+
prompt, language_code = prompt
|
76 |
+
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
|
77 |
+
|
78 |
+
self.chat.append({"role": self.user_role, "content": prompt})
|
79 |
+
|
80 |
+
# Remove system messages if using a Gemma model
|
81 |
+
if "gemma" in self.model_name.lower():
|
82 |
+
chat_messages = [
|
83 |
+
msg for msg in self.chat.to_list() if msg["role"] != "system"
|
84 |
+
]
|
85 |
+
else:
|
86 |
+
chat_messages = self.chat.to_list()
|
87 |
+
|
88 |
+
prompt = self.tokenizer.apply_chat_template(
|
89 |
+
chat_messages, tokenize=False, add_generation_prompt=True
|
90 |
+
)
|
91 |
+
output = ""
|
92 |
+
curr_output = ""
|
93 |
+
for t in stream_generate(
|
94 |
+
self.model,
|
95 |
+
self.tokenizer,
|
96 |
+
prompt,
|
97 |
+
max_tokens=self.gen_kwargs["max_new_tokens"],
|
98 |
+
):
|
99 |
+
output += t
|
100 |
+
curr_output += t
|
101 |
+
if curr_output.endswith((".", "?", "!", "<|end|>")):
|
102 |
+
yield (curr_output.replace("<|end|>", ""), language_code)
|
103 |
+
curr_output = ""
|
104 |
+
generated_text = output.replace("<|end|>", "")
|
105 |
+
torch.mps.empty_cache()
|
106 |
+
|
107 |
+
self.chat.append({"role": "assistant", "content": generated_text})
|
README.md
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<div> </div>
|
3 |
+
<img src="logo.png" width="600"/>
|
4 |
+
</div>
|
5 |
+
|
6 |
+
# Speech To Speech: an effort for an open-sourced and modular GPT4-o
|
7 |
+
|
8 |
+
|
9 |
+
## 📖 Quick Index
|
10 |
+
* [Approach](#approach)
|
11 |
+
- [Structure](#structure)
|
12 |
+
- [Modularity](#modularity)
|
13 |
+
* [Setup](#setup)
|
14 |
+
* [Usage](#usage)
|
15 |
+
- [Docker Server approach](#docker-server)
|
16 |
+
- [Server/Client approach](#serverclient-approach)
|
17 |
+
- [Local approach](#local-approach-running-on-mac)
|
18 |
+
* [Command-line usage](#command-line-usage)
|
19 |
+
- [Model parameters](#model-parameters)
|
20 |
+
- [Generation parameters](#generation-parameters)
|
21 |
+
- [Notable parameters](#notable-parameters)
|
22 |
+
|
23 |
+
## Approach
|
24 |
+
|
25 |
+
### Structure
|
26 |
+
This repository implements a speech-to-speech cascaded pipeline with consecutive parts:
|
27 |
+
1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad)
|
28 |
+
2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper))
|
29 |
+
3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗
|
30 |
+
4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts)🤗
|
31 |
+
|
32 |
+
### Modularity
|
33 |
+
The pipeline aims to provide a fully open and modular approach, leveraging models available on the Transformers library via the Hugging Face hub. The level of modularity intended for each part is as follows:
|
34 |
+
- **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad).
|
35 |
+
- **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr).
|
36 |
+
- **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it.
|
37 |
+
- **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used.
|
38 |
+
|
39 |
+
The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs.
|
40 |
+
|
41 |
+
## Setup
|
42 |
+
|
43 |
+
Clone the repository:
|
44 |
+
```bash
|
45 |
+
git clone https://github.com/huggingface/speech-to-speech.git
|
46 |
+
cd speech-to-speech
|
47 |
+
```
|
48 |
+
|
49 |
+
Install the required dependencies using [uv](https://github.com/astral-sh/uv):
|
50 |
+
```bash
|
51 |
+
uv pip install -r requirements.txt
|
52 |
+
```
|
53 |
+
|
54 |
+
For Mac users, use the `requirements_mac.txt` file instead:
|
55 |
+
```bash
|
56 |
+
uv pip install -r requirements_mac.txt
|
57 |
+
```
|
58 |
+
|
59 |
+
If you want to use Melo TTS, you also need to run:
|
60 |
+
```bash
|
61 |
+
python -m unidic download
|
62 |
+
```
|
63 |
+
|
64 |
+
|
65 |
+
## Usage
|
66 |
+
|
67 |
+
The pipeline can be run in two ways:
|
68 |
+
- **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
|
69 |
+
- **Local approach**: Runs locally.
|
70 |
+
|
71 |
+
### Docker Server
|
72 |
+
|
73 |
+
#### Install the NVIDIA Container Toolkit
|
74 |
+
|
75 |
+
https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
76 |
+
|
77 |
+
#### Start the docker container
|
78 |
+
```docker compose up```
|
79 |
+
|
80 |
+
### Server/Client Approach
|
81 |
+
|
82 |
+
1. Run the pipeline on the server:
|
83 |
+
```bash
|
84 |
+
python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
|
85 |
+
```
|
86 |
+
|
87 |
+
2. Run the client locally to handle microphone input and receive generated audio:
|
88 |
+
```bash
|
89 |
+
python listen_and_play.py --host <IP address of your server>
|
90 |
+
```
|
91 |
+
|
92 |
+
### Local Approach (Mac)
|
93 |
+
|
94 |
+
1. For optimal settings on Mac:
|
95 |
+
```bash
|
96 |
+
python s2s_pipeline.py --local_mac_optimal_settings
|
97 |
+
```
|
98 |
+
|
99 |
+
This setting:
|
100 |
+
- Adds `--device mps` to use MPS for all models.
|
101 |
+
- Sets LightningWhisperMLX for STT
|
102 |
+
- Sets MLX LM for language model
|
103 |
+
- Sets MeloTTS for TTS
|
104 |
+
|
105 |
+
### Recommended usage with Cuda
|
106 |
+
|
107 |
+
Leverage Torch Compile for Whisper and Parler-TTS:
|
108 |
+
|
109 |
+
```bash
|
110 |
+
python s2s_pipeline.py \
|
111 |
+
--recv_host 0.0.0.0 \
|
112 |
+
--send_host 0.0.0.0 \
|
113 |
+
--lm_model_name microsoft/Phi-3-mini-4k-instruct \
|
114 |
+
--init_chat_role system \
|
115 |
+
--stt_compile_mode reduce-overhead \
|
116 |
+
--tts_compile_mode default
|
117 |
+
```
|
118 |
+
|
119 |
+
For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
|
120 |
+
|
121 |
+
|
122 |
+
### Multi-language Support
|
123 |
+
|
124 |
+
The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
|
125 |
+
|
126 |
+
#### With the server version:
|
127 |
+
|
128 |
+
|
129 |
+
For automatic language detection:
|
130 |
+
|
131 |
+
```bash
|
132 |
+
python s2s_pipeline.py \
|
133 |
+
--stt_model_name large-v3 \
|
134 |
+
--language zh \
|
135 |
+
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
|
136 |
+
```
|
137 |
+
|
138 |
+
Or for one language in particular, chinese in this example
|
139 |
+
|
140 |
+
```bash
|
141 |
+
python s2s_pipeline.py \
|
142 |
+
--stt_model_name large-v3 \
|
143 |
+
--language zh \
|
144 |
+
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
|
145 |
+
```
|
146 |
+
|
147 |
+
#### Local Mac Setup
|
148 |
+
|
149 |
+
For automatic language detection:
|
150 |
+
|
151 |
+
```bash
|
152 |
+
python s2s_pipeline.py \
|
153 |
+
--local_mac_optimal_settings \
|
154 |
+
--device mps \
|
155 |
+
--stt_model_name large-v3 \
|
156 |
+
--language zh \
|
157 |
+
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
|
158 |
+
```
|
159 |
+
|
160 |
+
Or for one language in particular, chinese in this example
|
161 |
+
|
162 |
+
```bash
|
163 |
+
python s2s_pipeline.py \
|
164 |
+
--local_mac_optimal_settings \
|
165 |
+
--device mps \
|
166 |
+
--stt_model_name large-v3 \
|
167 |
+
--language zh \
|
168 |
+
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
|
169 |
+
```
|
170 |
+
|
171 |
+
|
172 |
+
## Command-line Usage
|
173 |
+
|
174 |
+
### Model Parameters
|
175 |
+
|
176 |
+
`model_name`, `torch_dtype`, and `device` are exposed for each part leveraging the Transformers' implementations: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix:
|
177 |
+
- `stt` (Speech to Text)
|
178 |
+
- `lm` (Language Model)
|
179 |
+
- `tts` (Text to Speech)
|
180 |
+
|
181 |
+
For example:
|
182 |
+
```bash
|
183 |
+
--lm_model_name google/gemma-2b-it
|
184 |
+
```
|
185 |
+
|
186 |
+
### Generation Parameters
|
187 |
+
|
188 |
+
Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for example).
|
189 |
+
|
190 |
+
### Notable Parameters
|
191 |
+
|
192 |
+
#### VAD Parameters
|
193 |
+
- `--thresh`: Threshold value to trigger voice activity detection.
|
194 |
+
- `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech.
|
195 |
+
- `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction.
|
196 |
+
|
197 |
+
#### Language Model
|
198 |
+
- `--init_chat_role`: Defaults to `None`. Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`)
|
199 |
+
- `--init_chat_prompt`: Defaults to `"You are a helpful AI assistant."` Required when setting `--init_chat_role`.
|
200 |
+
|
201 |
+
#### Speech to Text
|
202 |
+
- `--description`: Sets the description for Parler-TTS generated voice. Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."`
|
203 |
+
|
204 |
+
- `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps.
|
205 |
+
|
206 |
+
## Citations
|
207 |
+
|
208 |
+
### Silero VAD
|
209 |
+
```bibtex
|
210 |
+
@misc{Silero VAD,
|
211 |
+
author = {Silero Team},
|
212 |
+
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
|
213 |
+
year = {2021},
|
214 |
+
publisher = {GitHub},
|
215 |
+
journal = {GitHub repository},
|
216 |
+
howpublished = {\url{https://github.com/snakers4/silero-vad}},
|
217 |
+
commit = {insert_some_commit_here},
|
218 |
+
email = {[email protected]}
|
219 |
+
}
|
220 |
+
```
|
221 |
+
|
222 |
+
### Distil-Whisper
|
223 |
+
```bibtex
|
224 |
+
@misc{gandhi2023distilwhisper,
|
225 |
+
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
|
226 |
+
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
|
227 |
+
year={2023},
|
228 |
+
eprint={2311.00430},
|
229 |
+
archivePrefix={arXiv},
|
230 |
+
primaryClass={cs.CL}
|
231 |
+
}
|
232 |
+
```
|
233 |
+
|
234 |
+
### Parler-TTS
|
235 |
+
```bibtex
|
236 |
+
@misc{lacombe-etal-2024-parler-tts,
|
237 |
+
author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
|
238 |
+
title = {Parler-TTS},
|
239 |
+
year = {2024},
|
240 |
+
publisher = {GitHub},
|
241 |
+
journal = {GitHub repository},
|
242 |
+
howpublished = {\url{https://github.com/huggingface/parler-tts}}
|
243 |
+
}
|
244 |
+
```
|
STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc
ADDED
Binary file (4.17 kB). View file
|
|
STT/__pycache__/paraformer_handler.cpython-311.pyc
ADDED
Binary file (3.59 kB). View file
|
|
STT/__pycache__/whisper_stt_handler.cpython-311.pyc
ADDED
Binary file (6.46 kB). View file
|
|
STT/lightning_whisper_mlx_handler.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from time import perf_counter
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
from lightning_whisper_mlx import LightningWhisperMLX
|
5 |
+
import numpy as np
|
6 |
+
from rich.console import Console
|
7 |
+
from copy import copy
|
8 |
+
import torch
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
console = Console()
|
13 |
+
|
14 |
+
SUPPORTED_LANGUAGES = [
|
15 |
+
"en",
|
16 |
+
"fr",
|
17 |
+
"es",
|
18 |
+
"zh",
|
19 |
+
"ja",
|
20 |
+
"ko",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
class LightningWhisperSTTHandler(BaseHandler):
|
25 |
+
"""
|
26 |
+
Handles the Speech To Text generation using a Whisper model.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def setup(
|
30 |
+
self,
|
31 |
+
model_name="distil-large-v3",
|
32 |
+
device="mps",
|
33 |
+
torch_dtype="float16",
|
34 |
+
compile_mode=None,
|
35 |
+
language=None,
|
36 |
+
gen_kwargs={},
|
37 |
+
):
|
38 |
+
if len(model_name.split("/")) > 1:
|
39 |
+
model_name = model_name.split("/")[-1]
|
40 |
+
self.device = device
|
41 |
+
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
|
42 |
+
self.start_language = language
|
43 |
+
self.last_language = language
|
44 |
+
|
45 |
+
self.warmup()
|
46 |
+
|
47 |
+
def warmup(self):
|
48 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
49 |
+
|
50 |
+
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
|
51 |
+
n_steps = 1
|
52 |
+
dummy_input = np.array([0] * 512)
|
53 |
+
|
54 |
+
for _ in range(n_steps):
|
55 |
+
_ = self.model.transcribe(dummy_input)["text"].strip()
|
56 |
+
|
57 |
+
def process(self, spoken_prompt):
|
58 |
+
logger.debug("infering whisper...")
|
59 |
+
|
60 |
+
global pipeline_start
|
61 |
+
pipeline_start = perf_counter()
|
62 |
+
|
63 |
+
if self.start_language != 'auto':
|
64 |
+
transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
|
65 |
+
else:
|
66 |
+
transcription_dict = self.model.transcribe(spoken_prompt)
|
67 |
+
language_code = transcription_dict["language"]
|
68 |
+
if language_code not in SUPPORTED_LANGUAGES:
|
69 |
+
logger.warning(f"Whisper detected unsupported language: {language_code}")
|
70 |
+
if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
|
71 |
+
transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
|
72 |
+
else:
|
73 |
+
transcription_dict = {"text": "", "language": "en"}
|
74 |
+
else:
|
75 |
+
self.last_language = language_code
|
76 |
+
|
77 |
+
pred_text = transcription_dict["text"].strip()
|
78 |
+
language_code = transcription_dict["language"]
|
79 |
+
torch.mps.empty_cache()
|
80 |
+
|
81 |
+
logger.debug("finished whisper inference")
|
82 |
+
console.print(f"[yellow]USER: {pred_text}")
|
83 |
+
logger.debug(f"Language Code Whisper: {language_code}")
|
84 |
+
|
85 |
+
yield (pred_text, language_code)
|
STT/paraformer_handler.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from time import perf_counter
|
3 |
+
|
4 |
+
from baseHandler import BaseHandler
|
5 |
+
from funasr import AutoModel
|
6 |
+
import numpy as np
|
7 |
+
from rich.console import Console
|
8 |
+
import torch
|
9 |
+
|
10 |
+
logging.basicConfig(
|
11 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
12 |
+
)
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
console = Console()
|
16 |
+
|
17 |
+
|
18 |
+
class ParaformerSTTHandler(BaseHandler):
|
19 |
+
"""
|
20 |
+
Handles the Speech To Text generation using a Paraformer model.
|
21 |
+
The default for this model is set to Chinese.
|
22 |
+
This model was contributed by @wuhongsheng.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def setup(
|
26 |
+
self,
|
27 |
+
model_name="paraformer-zh",
|
28 |
+
device="cuda",
|
29 |
+
gen_kwargs={},
|
30 |
+
):
|
31 |
+
print(model_name)
|
32 |
+
if len(model_name.split("/")) > 1:
|
33 |
+
model_name = model_name.split("/")[-1]
|
34 |
+
self.device = device
|
35 |
+
self.model = AutoModel(model=model_name, device=device)
|
36 |
+
self.warmup()
|
37 |
+
|
38 |
+
def warmup(self):
|
39 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
40 |
+
|
41 |
+
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
|
42 |
+
n_steps = 1
|
43 |
+
dummy_input = np.array([0] * 512, dtype=np.float32)
|
44 |
+
for _ in range(n_steps):
|
45 |
+
_ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
|
46 |
+
|
47 |
+
def process(self, spoken_prompt):
|
48 |
+
logger.debug("infering paraformer...")
|
49 |
+
|
50 |
+
global pipeline_start
|
51 |
+
pipeline_start = perf_counter()
|
52 |
+
|
53 |
+
pred_text = (
|
54 |
+
self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
|
55 |
+
)
|
56 |
+
torch.mps.empty_cache()
|
57 |
+
|
58 |
+
logger.debug("finished paraformer inference")
|
59 |
+
console.print(f"[yellow]USER: {pred_text}")
|
60 |
+
|
61 |
+
yield pred_text
|
STT/whisper_stt_handler.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import perf_counter
|
2 |
+
from transformers import (
|
3 |
+
AutoProcessor,
|
4 |
+
AutoModelForSpeechSeq2Seq
|
5 |
+
)
|
6 |
+
import torch
|
7 |
+
from copy import copy
|
8 |
+
from baseHandler import BaseHandler
|
9 |
+
from rich.console import Console
|
10 |
+
import logging
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
console = Console()
|
14 |
+
|
15 |
+
SUPPORTED_LANGUAGES = [
|
16 |
+
"en",
|
17 |
+
"fr",
|
18 |
+
"es",
|
19 |
+
"zh",
|
20 |
+
"ja",
|
21 |
+
"ko",
|
22 |
+
]
|
23 |
+
|
24 |
+
|
25 |
+
class WhisperSTTHandler(BaseHandler):
|
26 |
+
"""
|
27 |
+
Handles the Speech To Text generation using a Whisper model.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def setup(
|
31 |
+
self,
|
32 |
+
model_name="distil-whisper/distil-large-v3",
|
33 |
+
device="cuda",
|
34 |
+
torch_dtype="float16",
|
35 |
+
compile_mode=None,
|
36 |
+
language=None,
|
37 |
+
gen_kwargs={},
|
38 |
+
):
|
39 |
+
self.device = device
|
40 |
+
self.torch_dtype = getattr(torch, torch_dtype)
|
41 |
+
self.compile_mode = compile_mode
|
42 |
+
self.gen_kwargs = gen_kwargs
|
43 |
+
if language == 'auto':
|
44 |
+
language = None
|
45 |
+
self.last_language = language
|
46 |
+
if self.last_language is not None:
|
47 |
+
self.gen_kwargs["language"] = self.last_language
|
48 |
+
|
49 |
+
self.processor = AutoProcessor.from_pretrained(model_name)
|
50 |
+
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
51 |
+
model_name,
|
52 |
+
torch_dtype=self.torch_dtype,
|
53 |
+
).to(device)
|
54 |
+
|
55 |
+
# compile
|
56 |
+
if self.compile_mode:
|
57 |
+
self.model.generation_config.cache_implementation = "static"
|
58 |
+
self.model.forward = torch.compile(
|
59 |
+
self.model.forward, mode=self.compile_mode, fullgraph=True
|
60 |
+
)
|
61 |
+
self.warmup()
|
62 |
+
|
63 |
+
def prepare_model_inputs(self, spoken_prompt):
|
64 |
+
input_features = self.processor(
|
65 |
+
spoken_prompt, sampling_rate=16000, return_tensors="pt"
|
66 |
+
).input_features
|
67 |
+
input_features = input_features.to(self.device, dtype=self.torch_dtype)
|
68 |
+
|
69 |
+
return input_features
|
70 |
+
|
71 |
+
def warmup(self):
|
72 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
73 |
+
|
74 |
+
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
|
75 |
+
n_steps = 1 if self.compile_mode == "default" else 2
|
76 |
+
dummy_input = torch.randn(
|
77 |
+
(1, self.model.config.num_mel_bins, 3000),
|
78 |
+
dtype=self.torch_dtype,
|
79 |
+
device=self.device,
|
80 |
+
)
|
81 |
+
if self.compile_mode not in (None, "default"):
|
82 |
+
# generating more tokens than previously will trigger CUDA graphs capture
|
83 |
+
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
|
84 |
+
# hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
|
85 |
+
warmup_gen_kwargs = {
|
86 |
+
"min_new_tokens": self.gen_kwargs[
|
87 |
+
"max_new_tokens"
|
88 |
+
], # Yes, assign max_new_tokens to min_new_tokens
|
89 |
+
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
|
90 |
+
**self.gen_kwargs,
|
91 |
+
}
|
92 |
+
else:
|
93 |
+
warmup_gen_kwargs = self.gen_kwargs
|
94 |
+
|
95 |
+
if self.device == "cuda":
|
96 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
97 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
98 |
+
torch.cuda.synchronize()
|
99 |
+
start_event.record()
|
100 |
+
|
101 |
+
for _ in range(n_steps):
|
102 |
+
_ = self.model.generate(dummy_input, **warmup_gen_kwargs)
|
103 |
+
|
104 |
+
if self.device == "cuda":
|
105 |
+
end_event.record()
|
106 |
+
torch.cuda.synchronize()
|
107 |
+
|
108 |
+
logger.info(
|
109 |
+
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
|
110 |
+
)
|
111 |
+
|
112 |
+
def process(self, spoken_prompt):
|
113 |
+
logger.debug("infering whisper...")
|
114 |
+
|
115 |
+
global pipeline_start
|
116 |
+
pipeline_start = perf_counter()
|
117 |
+
|
118 |
+
input_features = self.prepare_model_inputs(spoken_prompt)
|
119 |
+
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
|
120 |
+
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
121 |
+
|
122 |
+
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
|
123 |
+
logger.warning("Whisper detected unsupported language:", language_code)
|
124 |
+
gen_kwargs = copy(self.gen_kwargs)
|
125 |
+
gen_kwargs['language'] = self.last_language
|
126 |
+
language_code = self.last_language
|
127 |
+
pred_ids = self.model.generate(input_features, **gen_kwargs)
|
128 |
+
else:
|
129 |
+
self.last_language = language_code
|
130 |
+
|
131 |
+
pred_text = self.processor.batch_decode(
|
132 |
+
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
|
133 |
+
)[0]
|
134 |
+
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
135 |
+
|
136 |
+
logger.debug("finished whisper inference")
|
137 |
+
console.print(f"[yellow]USER: {pred_text}")
|
138 |
+
logger.debug(f"Language Code Whisper: {language_code}")
|
139 |
+
|
140 |
+
yield (pred_text, language_code)
|
TTS/__pycache__/chatTTS_handler.cpython-311.pyc
ADDED
Binary file (4.78 kB). View file
|
|
TTS/__pycache__/melo_handler.cpython-311.pyc
ADDED
Binary file (4.98 kB). View file
|
|
TTS/__pycache__/parler_handler.cpython-311.pyc
ADDED
Binary file (9.7 kB). View file
|
|
TTS/chatTTS_handler.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ChatTTS
|
2 |
+
import logging
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
from rich.console import Console
|
7 |
+
import torch
|
8 |
+
|
9 |
+
logging.basicConfig(
|
10 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
11 |
+
)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
console = Console()
|
15 |
+
|
16 |
+
|
17 |
+
class ChatTTSHandler(BaseHandler):
|
18 |
+
def setup(
|
19 |
+
self,
|
20 |
+
should_listen,
|
21 |
+
device="cuda",
|
22 |
+
gen_kwargs={}, # Unused
|
23 |
+
stream=True,
|
24 |
+
chunk_size=512,
|
25 |
+
):
|
26 |
+
self.should_listen = should_listen
|
27 |
+
self.device = device
|
28 |
+
self.model = ChatTTS.Chat()
|
29 |
+
self.model.load(compile=False) # Doesn't work for me with True
|
30 |
+
self.chunk_size = chunk_size
|
31 |
+
self.stream = stream
|
32 |
+
rnd_spk_emb = self.model.sample_random_speaker()
|
33 |
+
self.params_infer_code = ChatTTS.Chat.InferCodeParams(
|
34 |
+
spk_emb=rnd_spk_emb,
|
35 |
+
)
|
36 |
+
self.warmup()
|
37 |
+
|
38 |
+
def warmup(self):
|
39 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
40 |
+
_ = self.model.infer("text")
|
41 |
+
|
42 |
+
def process(self, llm_sentence):
|
43 |
+
console.print(f"[green]ASSISTANT: {llm_sentence}")
|
44 |
+
if self.device == "mps":
|
45 |
+
import time
|
46 |
+
|
47 |
+
start = time.time()
|
48 |
+
torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
|
49 |
+
torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
|
50 |
+
_ = (
|
51 |
+
time.time() - start
|
52 |
+
) # Removing this line makes it fail more often. I'm looking into it.
|
53 |
+
|
54 |
+
wavs_gen = self.model.infer(
|
55 |
+
llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
|
56 |
+
)
|
57 |
+
|
58 |
+
if self.stream:
|
59 |
+
wavs = [np.array([])]
|
60 |
+
for gen in wavs_gen:
|
61 |
+
if gen[0] is None or len(gen[0]) == 0:
|
62 |
+
self.should_listen.set()
|
63 |
+
return
|
64 |
+
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
|
65 |
+
audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
|
66 |
+
while len(audio_chunk) > self.chunk_size:
|
67 |
+
yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据
|
68 |
+
audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据
|
69 |
+
yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
|
70 |
+
else:
|
71 |
+
wavs = wavs_gen
|
72 |
+
if len(wavs[0]) == 0:
|
73 |
+
self.should_listen.set()
|
74 |
+
return
|
75 |
+
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
|
76 |
+
audio_chunk = (audio_chunk * 32768).astype(np.int16)
|
77 |
+
for i in range(0, len(audio_chunk), self.chunk_size):
|
78 |
+
yield np.pad(
|
79 |
+
audio_chunk[i : i + self.chunk_size],
|
80 |
+
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
|
81 |
+
)
|
82 |
+
self.should_listen.set()
|
TTS/melo_handler.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from melo.api import TTS
|
2 |
+
import logging
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
from rich.console import Console
|
7 |
+
import torch
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
console = Console()
|
12 |
+
|
13 |
+
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
|
14 |
+
"en": "EN",
|
15 |
+
"fr": "FR",
|
16 |
+
"es": "ES",
|
17 |
+
"zh": "ZH",
|
18 |
+
"ja": "JP",
|
19 |
+
"ko": "KR",
|
20 |
+
}
|
21 |
+
|
22 |
+
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
|
23 |
+
"en": "EN-BR",
|
24 |
+
"fr": "FR",
|
25 |
+
"es": "ES",
|
26 |
+
"zh": "ZH",
|
27 |
+
"ja": "JP",
|
28 |
+
"ko": "KR",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class MeloTTSHandler(BaseHandler):
|
33 |
+
def setup(
|
34 |
+
self,
|
35 |
+
should_listen,
|
36 |
+
device="mps",
|
37 |
+
language="en",
|
38 |
+
speaker_to_id="en",
|
39 |
+
gen_kwargs={}, # Unused
|
40 |
+
blocksize=512,
|
41 |
+
):
|
42 |
+
self.should_listen = should_listen
|
43 |
+
self.device = device
|
44 |
+
self.language = language
|
45 |
+
self.model = TTS(
|
46 |
+
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
|
47 |
+
)
|
48 |
+
self.speaker_id = self.model.hps.data.spk2id[
|
49 |
+
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
|
50 |
+
]
|
51 |
+
self.blocksize = blocksize
|
52 |
+
self.warmup()
|
53 |
+
|
54 |
+
def warmup(self):
|
55 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
56 |
+
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
|
57 |
+
|
58 |
+
def process(self, llm_sentence):
|
59 |
+
language_code = None
|
60 |
+
|
61 |
+
if isinstance(llm_sentence, tuple):
|
62 |
+
llm_sentence, language_code = llm_sentence
|
63 |
+
|
64 |
+
console.print(f"[green]ASSISTANT: {llm_sentence}")
|
65 |
+
|
66 |
+
if language_code is not None and self.language != language_code:
|
67 |
+
try:
|
68 |
+
self.model = TTS(
|
69 |
+
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
|
70 |
+
device=self.device,
|
71 |
+
)
|
72 |
+
self.speaker_id = self.model.hps.data.spk2id[
|
73 |
+
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
|
74 |
+
]
|
75 |
+
self.language = language_code
|
76 |
+
except KeyError:
|
77 |
+
console.print(
|
78 |
+
f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
|
79 |
+
)
|
80 |
+
|
81 |
+
if self.device == "mps":
|
82 |
+
import time
|
83 |
+
|
84 |
+
start = time.time()
|
85 |
+
torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
|
86 |
+
torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
|
87 |
+
_ = (
|
88 |
+
time.time() - start
|
89 |
+
) # Removing this line makes it fail more often. I'm looking into it.
|
90 |
+
|
91 |
+
try:
|
92 |
+
audio_chunk = self.model.tts_to_file(
|
93 |
+
llm_sentence, self.speaker_id, quiet=True
|
94 |
+
)
|
95 |
+
except (AssertionError, RuntimeError) as e:
|
96 |
+
logger.error(f"Error in MeloTTSHandler: {e}")
|
97 |
+
audio_chunk = np.array([])
|
98 |
+
if len(audio_chunk) == 0:
|
99 |
+
self.should_listen.set()
|
100 |
+
return
|
101 |
+
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
|
102 |
+
audio_chunk = (audio_chunk * 32768).astype(np.int16)
|
103 |
+
for i in range(0, len(audio_chunk), self.blocksize):
|
104 |
+
yield np.pad(
|
105 |
+
audio_chunk[i : i + self.blocksize],
|
106 |
+
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
|
107 |
+
)
|
108 |
+
|
109 |
+
self.should_listen.set()
|
TTS/parler_handler.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from threading import Thread
|
2 |
+
from time import perf_counter
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import (
|
7 |
+
AutoTokenizer,
|
8 |
+
)
|
9 |
+
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
|
10 |
+
import librosa
|
11 |
+
import logging
|
12 |
+
from rich.console import Console
|
13 |
+
from utils.utils import next_power_of_2
|
14 |
+
from transformers.utils.import_utils import (
|
15 |
+
is_flash_attn_2_available,
|
16 |
+
)
|
17 |
+
|
18 |
+
torch._inductor.config.fx_graph_cache = True
|
19 |
+
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
|
20 |
+
torch._dynamo.config.cache_size_limit = 15
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
console = Console()
|
25 |
+
|
26 |
+
|
27 |
+
if not is_flash_attn_2_available() and torch.cuda.is_available():
|
28 |
+
logger.warn(
|
29 |
+
"""Parler TTS works best with flash attention 2, but is not installed
|
30 |
+
Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`"""
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class ParlerTTSHandler(BaseHandler):
|
35 |
+
def setup(
|
36 |
+
self,
|
37 |
+
should_listen,
|
38 |
+
model_name="ylacombe/parler-tts-mini-jenny-30H",
|
39 |
+
device="cuda",
|
40 |
+
torch_dtype="float16",
|
41 |
+
compile_mode=None,
|
42 |
+
gen_kwargs={},
|
43 |
+
max_prompt_pad_length=8,
|
44 |
+
description=(
|
45 |
+
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
|
46 |
+
"She speaks very fast."
|
47 |
+
),
|
48 |
+
play_steps_s=1,
|
49 |
+
blocksize=512,
|
50 |
+
):
|
51 |
+
self.should_listen = should_listen
|
52 |
+
self.device = device
|
53 |
+
self.torch_dtype = getattr(torch, torch_dtype)
|
54 |
+
self.gen_kwargs = gen_kwargs
|
55 |
+
self.compile_mode = compile_mode
|
56 |
+
self.max_prompt_pad_length = max_prompt_pad_length
|
57 |
+
self.description = description
|
58 |
+
|
59 |
+
self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
60 |
+
self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
61 |
+
self.model = ParlerTTSForConditionalGeneration.from_pretrained(
|
62 |
+
model_name, torch_dtype=self.torch_dtype
|
63 |
+
).to(device)
|
64 |
+
|
65 |
+
framerate = self.model.audio_encoder.config.frame_rate
|
66 |
+
self.play_steps = int(framerate * play_steps_s)
|
67 |
+
self.blocksize = blocksize
|
68 |
+
|
69 |
+
if self.compile_mode not in (None, "default"):
|
70 |
+
logger.warning(
|
71 |
+
"Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
|
72 |
+
)
|
73 |
+
self.compile_mode = "default"
|
74 |
+
|
75 |
+
if self.compile_mode:
|
76 |
+
self.model.generation_config.cache_implementation = "static"
|
77 |
+
self.model.forward = torch.compile(
|
78 |
+
self.model.forward, mode=self.compile_mode, fullgraph=True
|
79 |
+
)
|
80 |
+
|
81 |
+
self.warmup()
|
82 |
+
|
83 |
+
def prepare_model_inputs(
|
84 |
+
self,
|
85 |
+
prompt,
|
86 |
+
max_length_prompt=50,
|
87 |
+
pad=False,
|
88 |
+
):
|
89 |
+
pad_args_prompt = (
|
90 |
+
{"padding": "max_length", "max_length": max_length_prompt} if pad else {}
|
91 |
+
)
|
92 |
+
|
93 |
+
tokenized_description = self.description_tokenizer(
|
94 |
+
self.description, return_tensors="pt"
|
95 |
+
)
|
96 |
+
input_ids = tokenized_description.input_ids.to(self.device)
|
97 |
+
attention_mask = tokenized_description.attention_mask.to(self.device)
|
98 |
+
|
99 |
+
tokenized_prompt = self.prompt_tokenizer(
|
100 |
+
prompt, return_tensors="pt", **pad_args_prompt
|
101 |
+
)
|
102 |
+
prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
|
103 |
+
prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
|
104 |
+
|
105 |
+
gen_kwargs = {
|
106 |
+
"input_ids": input_ids,
|
107 |
+
"attention_mask": attention_mask,
|
108 |
+
"prompt_input_ids": prompt_input_ids,
|
109 |
+
"prompt_attention_mask": prompt_attention_mask,
|
110 |
+
**self.gen_kwargs,
|
111 |
+
}
|
112 |
+
|
113 |
+
return gen_kwargs
|
114 |
+
|
115 |
+
def warmup(self):
|
116 |
+
logger.info(f"Warming up {self.__class__.__name__}")
|
117 |
+
|
118 |
+
if self.device == "cuda":
|
119 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
120 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
121 |
+
|
122 |
+
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
|
123 |
+
n_steps = 1 if self.compile_mode == "default" else 2
|
124 |
+
|
125 |
+
if self.device == "cuda":
|
126 |
+
torch.cuda.synchronize()
|
127 |
+
start_event.record()
|
128 |
+
if self.compile_mode:
|
129 |
+
pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
|
130 |
+
for pad_length in pad_lengths[::-1]:
|
131 |
+
model_kwargs = self.prepare_model_inputs(
|
132 |
+
"dummy prompt", max_length_prompt=pad_length, pad=True
|
133 |
+
)
|
134 |
+
for _ in range(n_steps):
|
135 |
+
_ = self.model.generate(**model_kwargs)
|
136 |
+
logger.info(f"Warmed up length {pad_length} tokens!")
|
137 |
+
else:
|
138 |
+
model_kwargs = self.prepare_model_inputs("dummy prompt")
|
139 |
+
for _ in range(n_steps):
|
140 |
+
_ = self.model.generate(**model_kwargs)
|
141 |
+
|
142 |
+
if self.device == "cuda":
|
143 |
+
end_event.record()
|
144 |
+
torch.cuda.synchronize()
|
145 |
+
logger.info(
|
146 |
+
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
|
147 |
+
)
|
148 |
+
|
149 |
+
def process(self, llm_sentence):
|
150 |
+
if isinstance(llm_sentence, tuple):
|
151 |
+
llm_sentence, _ = llm_sentence
|
152 |
+
|
153 |
+
console.print(f"[green]ASSISTANT: {llm_sentence}")
|
154 |
+
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
|
155 |
+
|
156 |
+
pad_args = {}
|
157 |
+
if self.compile_mode:
|
158 |
+
# pad to closest upper power of two
|
159 |
+
pad_length = next_power_of_2(nb_tokens)
|
160 |
+
logger.debug(f"padding to {pad_length}")
|
161 |
+
pad_args["pad"] = True
|
162 |
+
pad_args["max_length_prompt"] = pad_length
|
163 |
+
|
164 |
+
tts_gen_kwargs = self.prepare_model_inputs(
|
165 |
+
llm_sentence,
|
166 |
+
**pad_args,
|
167 |
+
)
|
168 |
+
|
169 |
+
streamer = ParlerTTSStreamer(
|
170 |
+
self.model, device=self.device, play_steps=self.play_steps
|
171 |
+
)
|
172 |
+
tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
|
173 |
+
torch.manual_seed(0)
|
174 |
+
thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
|
175 |
+
thread.start()
|
176 |
+
|
177 |
+
for i, audio_chunk in enumerate(streamer):
|
178 |
+
global pipeline_start
|
179 |
+
if i == 0 and "pipeline_start" in globals():
|
180 |
+
logger.info(
|
181 |
+
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
|
182 |
+
)
|
183 |
+
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
|
184 |
+
audio_chunk = (audio_chunk * 32768).astype(np.int16)
|
185 |
+
for i in range(0, len(audio_chunk), self.blocksize):
|
186 |
+
yield np.pad(
|
187 |
+
audio_chunk[i : i + self.blocksize],
|
188 |
+
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
|
189 |
+
)
|
190 |
+
|
191 |
+
self.should_listen.set()
|
VAD/__pycache__/vad_handler.cpython-311.pyc
ADDED
Binary file (4.81 kB). View file
|
|
VAD/__pycache__/vad_handler.cpython-312.pyc
ADDED
Binary file (4.46 kB). View file
|
|
VAD/__pycache__/vad_iterator.cpython-311.pyc
ADDED
Binary file (4.4 kB). View file
|
|
VAD/__pycache__/vad_iterator.cpython-312.pyc
ADDED
Binary file (4.24 kB). View file
|
|
VAD/vad_handler.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
from VAD.vad_iterator import VADIterator
|
3 |
+
from baseHandler import BaseHandler
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from rich.console import Console
|
7 |
+
|
8 |
+
from utils.utils import int2float
|
9 |
+
from df.enhance import enhance, init_df
|
10 |
+
import logging
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
console = Console()
|
15 |
+
|
16 |
+
|
17 |
+
class VADHandler(BaseHandler):
|
18 |
+
"""
|
19 |
+
Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
|
20 |
+
to the following part.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def setup(
|
24 |
+
self,
|
25 |
+
should_listen,
|
26 |
+
thresh=0.3,
|
27 |
+
sample_rate=16000,
|
28 |
+
min_silence_ms=1000,
|
29 |
+
min_speech_ms=500,
|
30 |
+
max_speech_ms=float("inf"),
|
31 |
+
speech_pad_ms=30,
|
32 |
+
audio_enhancement=False,
|
33 |
+
):
|
34 |
+
self.should_listen = should_listen
|
35 |
+
self.sample_rate = sample_rate
|
36 |
+
self.min_silence_ms = min_silence_ms
|
37 |
+
self.min_speech_ms = min_speech_ms
|
38 |
+
self.max_speech_ms = max_speech_ms
|
39 |
+
self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
|
40 |
+
self.iterator = VADIterator(
|
41 |
+
self.model,
|
42 |
+
threshold=thresh,
|
43 |
+
sampling_rate=sample_rate,
|
44 |
+
min_silence_duration_ms=min_silence_ms,
|
45 |
+
speech_pad_ms=speech_pad_ms,
|
46 |
+
)
|
47 |
+
self.audio_enhancement = audio_enhancement
|
48 |
+
if audio_enhancement:
|
49 |
+
self.enhanced_model, self.df_state, _ = init_df()
|
50 |
+
|
51 |
+
def process(self, audio_chunk):
|
52 |
+
audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
|
53 |
+
audio_float32 = int2float(audio_int16)
|
54 |
+
vad_output = self.iterator(torch.from_numpy(audio_float32))
|
55 |
+
if vad_output is not None and len(vad_output) != 0:
|
56 |
+
logger.debug("VAD: end of speech detected")
|
57 |
+
array = torch.cat(vad_output).cpu().numpy()
|
58 |
+
duration_ms = len(array) / self.sample_rate * 1000
|
59 |
+
if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
|
60 |
+
logger.debug(
|
61 |
+
f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
self.should_listen.clear()
|
65 |
+
logger.debug("Stop listening")
|
66 |
+
if self.audio_enhancement:
|
67 |
+
if self.sample_rate != self.df_state.sr():
|
68 |
+
audio_float32 = torchaudio.functional.resample(
|
69 |
+
torch.from_numpy(array),
|
70 |
+
orig_freq=self.sample_rate,
|
71 |
+
new_freq=self.df_state.sr(),
|
72 |
+
)
|
73 |
+
enhanced = enhance(
|
74 |
+
self.enhanced_model,
|
75 |
+
self.df_state,
|
76 |
+
audio_float32.unsqueeze(0),
|
77 |
+
)
|
78 |
+
enhanced = torchaudio.functional.resample(
|
79 |
+
enhanced,
|
80 |
+
orig_freq=self.df_state.sr(),
|
81 |
+
new_freq=self.sample_rate,
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
enhanced = enhance(
|
85 |
+
self.enhanced_model, self.df_state, audio_float32
|
86 |
+
)
|
87 |
+
array = enhanced.numpy().squeeze()
|
88 |
+
yield array
|
89 |
+
|
90 |
+
@property
|
91 |
+
def min_time_to_debug(self):
|
92 |
+
return 0.00001
|
VAD/vad_iterator.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class VADIterator:
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
model,
|
8 |
+
threshold: float = 0.5,
|
9 |
+
sampling_rate: int = 16000,
|
10 |
+
min_silence_duration_ms: int = 100,
|
11 |
+
speech_pad_ms: int = 30,
|
12 |
+
):
|
13 |
+
"""
|
14 |
+
Mainly taken from https://github.com/snakers4/silero-vad
|
15 |
+
Class for stream imitation
|
16 |
+
|
17 |
+
Parameters
|
18 |
+
----------
|
19 |
+
model: preloaded .jit/.onnx silero VAD model
|
20 |
+
|
21 |
+
threshold: float (default - 0.5)
|
22 |
+
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
23 |
+
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
24 |
+
|
25 |
+
sampling_rate: int (default - 16000)
|
26 |
+
Currently silero VAD models support 8000 and 16000 sample rates
|
27 |
+
|
28 |
+
min_silence_duration_ms: int (default - 100 milliseconds)
|
29 |
+
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
30 |
+
|
31 |
+
speech_pad_ms: int (default - 30 milliseconds)
|
32 |
+
Final speech chunks are padded by speech_pad_ms each side
|
33 |
+
"""
|
34 |
+
|
35 |
+
self.model = model
|
36 |
+
self.threshold = threshold
|
37 |
+
self.sampling_rate = sampling_rate
|
38 |
+
self.is_speaking = False
|
39 |
+
self.buffer = []
|
40 |
+
|
41 |
+
if sampling_rate not in [8000, 16000]:
|
42 |
+
raise ValueError(
|
43 |
+
"VADIterator does not support sampling rates other than [8000, 16000]"
|
44 |
+
)
|
45 |
+
|
46 |
+
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
47 |
+
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
48 |
+
self.reset_states()
|
49 |
+
|
50 |
+
def reset_states(self):
|
51 |
+
self.model.reset_states()
|
52 |
+
self.triggered = False
|
53 |
+
self.temp_end = 0
|
54 |
+
self.current_sample = 0
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def __call__(self, x):
|
58 |
+
"""
|
59 |
+
x: torch.Tensor
|
60 |
+
audio chunk (see examples in repo)
|
61 |
+
|
62 |
+
return_seconds: bool (default - False)
|
63 |
+
whether return timestamps in seconds (default - samples)
|
64 |
+
"""
|
65 |
+
|
66 |
+
if not torch.is_tensor(x):
|
67 |
+
try:
|
68 |
+
x = torch.Tensor(x)
|
69 |
+
except Exception:
|
70 |
+
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
71 |
+
|
72 |
+
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
73 |
+
self.current_sample += window_size_samples
|
74 |
+
|
75 |
+
speech_prob = self.model(x, self.sampling_rate).item()
|
76 |
+
|
77 |
+
if (speech_prob >= self.threshold) and self.temp_end:
|
78 |
+
self.temp_end = 0
|
79 |
+
|
80 |
+
if (speech_prob >= self.threshold) and not self.triggered:
|
81 |
+
self.triggered = True
|
82 |
+
return None
|
83 |
+
|
84 |
+
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
85 |
+
if not self.temp_end:
|
86 |
+
self.temp_end = self.current_sample
|
87 |
+
if self.current_sample - self.temp_end < self.min_silence_samples:
|
88 |
+
return None
|
89 |
+
else:
|
90 |
+
# end of speak
|
91 |
+
self.temp_end = 0
|
92 |
+
self.triggered = False
|
93 |
+
spoken_utterance = self.buffer
|
94 |
+
self.buffer = []
|
95 |
+
return spoken_utterance
|
96 |
+
|
97 |
+
if self.triggered:
|
98 |
+
self.buffer.append(x)
|
99 |
+
|
100 |
+
return None
|
arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc
ADDED
Binary file (1.22 kB). View file
|
|
arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc
ADDED
Binary file (3.17 kB). View file
|
|
arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc
ADDED
Binary file (1.17 kB). View file
|
|
arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc
ADDED
Binary file (3.02 kB). View file
|
|
arguments_classes/__pycache__/module_arguments.cpython-311.pyc
ADDED
Binary file (2.11 kB). View file
|
|
arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc
ADDED
Binary file (1.1 kB). View file
|
|
arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc
ADDED
Binary file (2.92 kB). View file
|
|
arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc
ADDED
Binary file (1.27 kB). View file
|
|
arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc
ADDED
Binary file (1.06 kB). View file
|
|
arguments_classes/__pycache__/vad_arguments.cpython-311.pyc
ADDED
Binary file (2.35 kB). View file
|
|
arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc
ADDED
Binary file (2.9 kB). View file
|
|
arguments_classes/chat_tts_arguments.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class ChatTTSHandlerArguments:
|
6 |
+
chat_tts_stream: bool = field(
|
7 |
+
default=True,
|
8 |
+
metadata={"help": "The tts mode is stream Default is 'stream'."},
|
9 |
+
)
|
10 |
+
chat_tts_device: str = field(
|
11 |
+
default="cuda",
|
12 |
+
metadata={
|
13 |
+
"help": "The device to be used for speech synthesis. Default is 'cuda'."
|
14 |
+
},
|
15 |
+
)
|
16 |
+
chat_tts_chunk_size: int = field(
|
17 |
+
default=512,
|
18 |
+
metadata={
|
19 |
+
"help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512。."
|
20 |
+
},
|
21 |
+
)
|
arguments_classes/language_model_arguments.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class LanguageModelHandlerArguments:
|
6 |
+
lm_model_name: str = field(
|
7 |
+
default="HuggingFaceTB/SmolLM-360M-Instruct",
|
8 |
+
metadata={
|
9 |
+
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
|
10 |
+
},
|
11 |
+
)
|
12 |
+
lm_device: str = field(
|
13 |
+
default="cuda",
|
14 |
+
metadata={
|
15 |
+
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
lm_torch_dtype: str = field(
|
19 |
+
default="float16",
|
20 |
+
metadata={
|
21 |
+
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
22 |
+
},
|
23 |
+
)
|
24 |
+
user_role: str = field(
|
25 |
+
default="user",
|
26 |
+
metadata={
|
27 |
+
"help": "Role assigned to the user in the chat context. Default is 'user'."
|
28 |
+
},
|
29 |
+
)
|
30 |
+
init_chat_role: str = field(
|
31 |
+
default="system",
|
32 |
+
metadata={
|
33 |
+
"help": "Initial role for setting up the chat context. Default is 'system'."
|
34 |
+
},
|
35 |
+
)
|
36 |
+
init_chat_prompt: str = field(
|
37 |
+
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
|
38 |
+
metadata={
|
39 |
+
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
|
40 |
+
},
|
41 |
+
)
|
42 |
+
lm_gen_max_new_tokens: int = field(
|
43 |
+
default=128,
|
44 |
+
metadata={
|
45 |
+
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
|
46 |
+
},
|
47 |
+
)
|
48 |
+
lm_gen_min_new_tokens: int = field(
|
49 |
+
default=0,
|
50 |
+
metadata={
|
51 |
+
"help": "Minimum number of new tokens to generate in a single completion. Default is 0."
|
52 |
+
},
|
53 |
+
)
|
54 |
+
lm_gen_temperature: float = field(
|
55 |
+
default=0.0,
|
56 |
+
metadata={
|
57 |
+
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
|
58 |
+
},
|
59 |
+
)
|
60 |
+
lm_gen_do_sample: bool = field(
|
61 |
+
default=False,
|
62 |
+
metadata={
|
63 |
+
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
|
64 |
+
},
|
65 |
+
)
|
66 |
+
chat_size: int = field(
|
67 |
+
default=2,
|
68 |
+
metadata={
|
69 |
+
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
|
70 |
+
},
|
71 |
+
)
|
arguments_classes/melo_tts_arguments.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class MeloTTSHandlerArguments:
|
6 |
+
melo_language: str = field(
|
7 |
+
default="en",
|
8 |
+
metadata={
|
9 |
+
"help": "The language of the text to be synthesized. Default is 'EN_NEWEST'."
|
10 |
+
},
|
11 |
+
)
|
12 |
+
melo_device: str = field(
|
13 |
+
default="auto",
|
14 |
+
metadata={
|
15 |
+
"help": "The device to be used for speech synthesis. Default is 'auto'."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
melo_speaker_to_id: str = field(
|
19 |
+
default="en",
|
20 |
+
metadata={
|
21 |
+
"help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
|
22 |
+
},
|
23 |
+
)
|
arguments_classes/mlx_language_model_arguments.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class MLXLanguageModelHandlerArguments:
|
6 |
+
mlx_lm_model_name: str = field(
|
7 |
+
default="mlx-community/SmolLM-360M-Instruct",
|
8 |
+
metadata={
|
9 |
+
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
|
10 |
+
},
|
11 |
+
)
|
12 |
+
mlx_lm_device: str = field(
|
13 |
+
default="mps",
|
14 |
+
metadata={
|
15 |
+
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
mlx_lm_torch_dtype: str = field(
|
19 |
+
default="float16",
|
20 |
+
metadata={
|
21 |
+
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
22 |
+
},
|
23 |
+
)
|
24 |
+
mlx_lm_user_role: str = field(
|
25 |
+
default="user",
|
26 |
+
metadata={
|
27 |
+
"help": "Role assigned to the user in the chat context. Default is 'user'."
|
28 |
+
},
|
29 |
+
)
|
30 |
+
mlx_lm_init_chat_role: str = field(
|
31 |
+
default="system",
|
32 |
+
metadata={
|
33 |
+
"help": "Initial role for setting up the chat context. Default is 'system'."
|
34 |
+
},
|
35 |
+
)
|
36 |
+
mlx_lm_init_chat_prompt: str = field(
|
37 |
+
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
|
38 |
+
metadata={
|
39 |
+
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
|
40 |
+
},
|
41 |
+
)
|
42 |
+
mlx_lm_gen_max_new_tokens: int = field(
|
43 |
+
default=128,
|
44 |
+
metadata={
|
45 |
+
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
|
46 |
+
},
|
47 |
+
)
|
48 |
+
mlx_lm_gen_temperature: float = field(
|
49 |
+
default=0.0,
|
50 |
+
metadata={
|
51 |
+
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
|
52 |
+
},
|
53 |
+
)
|
54 |
+
mlx_lm_gen_do_sample: bool = field(
|
55 |
+
default=False,
|
56 |
+
metadata={
|
57 |
+
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
|
58 |
+
},
|
59 |
+
)
|
60 |
+
mlx_lm_chat_size: int = field(
|
61 |
+
default=2,
|
62 |
+
metadata={
|
63 |
+
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
|
64 |
+
},
|
65 |
+
)
|
arguments_classes/module_arguments.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class ModuleArguments:
|
7 |
+
device: Optional[str] = field(
|
8 |
+
default=None,
|
9 |
+
metadata={"help": "If specified, overrides the device for all handlers."},
|
10 |
+
)
|
11 |
+
mode: Optional[str] = field(
|
12 |
+
default="socket",
|
13 |
+
metadata={
|
14 |
+
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
|
15 |
+
},
|
16 |
+
)
|
17 |
+
local_mac_optimal_settings: bool = field(
|
18 |
+
default=False,
|
19 |
+
metadata={
|
20 |
+
"help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
|
21 |
+
},
|
22 |
+
)
|
23 |
+
stt: Optional[str] = field(
|
24 |
+
default="whisper",
|
25 |
+
metadata={
|
26 |
+
"help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'."
|
27 |
+
},
|
28 |
+
)
|
29 |
+
llm: Optional[str] = field(
|
30 |
+
default="transformers",
|
31 |
+
metadata={
|
32 |
+
"help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
|
33 |
+
},
|
34 |
+
)
|
35 |
+
tts: Optional[str] = field(
|
36 |
+
default="parler",
|
37 |
+
metadata={
|
38 |
+
"help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
|
39 |
+
},
|
40 |
+
)
|
41 |
+
log_level: str = field(
|
42 |
+
default="info",
|
43 |
+
metadata={
|
44 |
+
"help": "Provide logging level. Example --log_level debug, default=warning."
|
45 |
+
},
|
46 |
+
)
|
arguments_classes/paraformer_stt_arguments.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class ParaformerSTTHandlerArguments:
|
6 |
+
paraformer_stt_model_name: str = field(
|
7 |
+
default="paraformer-zh",
|
8 |
+
metadata={
|
9 |
+
"help": "The pretrained model to use. Default is 'paraformer-zh'. Can be choose from https://github.com/modelscope/FunASR"
|
10 |
+
},
|
11 |
+
)
|
12 |
+
paraformer_stt_device: str = field(
|
13 |
+
default="cuda",
|
14 |
+
metadata={
|
15 |
+
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
16 |
+
},
|
17 |
+
)
|
arguments_classes/parler_tts_arguments.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class ParlerTTSHandlerArguments:
|
6 |
+
tts_model_name: str = field(
|
7 |
+
default="ylacombe/parler-tts-mini-jenny-30H",
|
8 |
+
metadata={
|
9 |
+
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
|
10 |
+
},
|
11 |
+
)
|
12 |
+
tts_device: str = field(
|
13 |
+
default="cuda",
|
14 |
+
metadata={
|
15 |
+
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
tts_torch_dtype: str = field(
|
19 |
+
default="float16",
|
20 |
+
metadata={
|
21 |
+
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
22 |
+
},
|
23 |
+
)
|
24 |
+
tts_compile_mode: str = field(
|
25 |
+
default=None,
|
26 |
+
metadata={
|
27 |
+
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
|
28 |
+
},
|
29 |
+
)
|
30 |
+
tts_gen_min_new_tokens: int = field(
|
31 |
+
default=64,
|
32 |
+
metadata={
|
33 |
+
"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
|
34 |
+
},
|
35 |
+
)
|
36 |
+
tts_gen_max_new_tokens: int = field(
|
37 |
+
default=512,
|
38 |
+
metadata={
|
39 |
+
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
|
40 |
+
},
|
41 |
+
)
|
42 |
+
description: str = field(
|
43 |
+
default=(
|
44 |
+
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
|
45 |
+
"She speaks very fast."
|
46 |
+
),
|
47 |
+
metadata={
|
48 |
+
"help": "Description of the speaker's voice and speaking style to guide the TTS model."
|
49 |
+
},
|
50 |
+
)
|
51 |
+
play_steps_s: float = field(
|
52 |
+
default=1.0,
|
53 |
+
metadata={
|
54 |
+
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
|
55 |
+
},
|
56 |
+
)
|
57 |
+
max_prompt_pad_length: int = field(
|
58 |
+
default=8,
|
59 |
+
metadata={
|
60 |
+
"help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
|
61 |
+
},
|
62 |
+
)
|
arguments_classes/socket_receiver_arguments.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class SocketReceiverArguments:
|
6 |
+
recv_host: str = field(
|
7 |
+
default="localhost",
|
8 |
+
metadata={
|
9 |
+
"help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
|
10 |
+
"available interfaces on the host machine."
|
11 |
+
},
|
12 |
+
)
|
13 |
+
recv_port: int = field(
|
14 |
+
default=12345,
|
15 |
+
metadata={
|
16 |
+
"help": "The port number on which the socket server listens. Default is 12346."
|
17 |
+
},
|
18 |
+
)
|
19 |
+
chunk_size: int = field(
|
20 |
+
default=1024,
|
21 |
+
metadata={
|
22 |
+
"help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
|
23 |
+
},
|
24 |
+
)
|
arguments_classes/socket_sender_arguments.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class SocketSenderArguments:
|
6 |
+
send_host: str = field(
|
7 |
+
default="localhost",
|
8 |
+
metadata={
|
9 |
+
"help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
|
10 |
+
"available interfaces on the host machine."
|
11 |
+
},
|
12 |
+
)
|
13 |
+
send_port: int = field(
|
14 |
+
default=12346,
|
15 |
+
metadata={
|
16 |
+
"help": "The port number on which the socket server listens. Default is 12346."
|
17 |
+
},
|
18 |
+
)
|
arguments_classes/vad_arguments.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class VADHandlerArguments:
|
6 |
+
thresh: float = field(
|
7 |
+
default=0.3,
|
8 |
+
metadata={
|
9 |
+
"help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
|
10 |
+
},
|
11 |
+
)
|
12 |
+
sample_rate: int = field(
|
13 |
+
default=16000,
|
14 |
+
metadata={
|
15 |
+
"help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
|
16 |
+
},
|
17 |
+
)
|
18 |
+
min_silence_ms: int = field(
|
19 |
+
default=250,
|
20 |
+
metadata={
|
21 |
+
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
|
22 |
+
},
|
23 |
+
)
|
24 |
+
min_speech_ms: int = field(
|
25 |
+
default=500,
|
26 |
+
metadata={
|
27 |
+
"help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
|
28 |
+
},
|
29 |
+
)
|
30 |
+
max_speech_ms: float = field(
|
31 |
+
default=float("inf"),
|
32 |
+
metadata={
|
33 |
+
"help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
|
34 |
+
},
|
35 |
+
)
|
36 |
+
speech_pad_ms: int = field(
|
37 |
+
default=500,
|
38 |
+
metadata={
|
39 |
+
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
|
40 |
+
},
|
41 |
+
)
|
42 |
+
audio_enhancement: bool = field(
|
43 |
+
default=False,
|
44 |
+
metadata={
|
45 |
+
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
|
46 |
+
},
|
47 |
+
)
|
arguments_classes/whisper_stt_arguments.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class WhisperSTTHandlerArguments:
|
7 |
+
stt_model_name: str = field(
|
8 |
+
default="distil-whisper/distil-large-v3",
|
9 |
+
metadata={
|
10 |
+
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
|
11 |
+
},
|
12 |
+
)
|
13 |
+
stt_device: str = field(
|
14 |
+
default="cuda",
|
15 |
+
metadata={
|
16 |
+
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
17 |
+
},
|
18 |
+
)
|
19 |
+
stt_torch_dtype: str = field(
|
20 |
+
default="float16",
|
21 |
+
metadata={
|
22 |
+
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
23 |
+
},
|
24 |
+
)
|
25 |
+
stt_compile_mode: str = field(
|
26 |
+
default=None,
|
27 |
+
metadata={
|
28 |
+
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
|
29 |
+
},
|
30 |
+
)
|
31 |
+
stt_gen_max_new_tokens: int = field(
|
32 |
+
default=128,
|
33 |
+
metadata={
|
34 |
+
"help": "The maximum number of new tokens to generate. Default is 128."
|
35 |
+
},
|
36 |
+
)
|
37 |
+
stt_gen_num_beams: int = field(
|
38 |
+
default=1,
|
39 |
+
metadata={
|
40 |
+
"help": "The number of beams for beam search. Default is 1, implying greedy decoding."
|
41 |
+
},
|
42 |
+
)
|
43 |
+
stt_gen_return_timestamps: bool = field(
|
44 |
+
default=False,
|
45 |
+
metadata={
|
46 |
+
"help": "Whether to return timestamps with transcriptions. Default is False."
|
47 |
+
},
|
48 |
+
)
|
49 |
+
stt_gen_task: str = field(
|
50 |
+
default="transcribe",
|
51 |
+
metadata={
|
52 |
+
"help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
|
53 |
+
},
|
54 |
+
)
|
55 |
+
language: Optional[str] = field(
|
56 |
+
default='en',
|
57 |
+
metadata={
|
58 |
+
"help": """The language for the conversation.
|
59 |
+
Choose between 'en' (english), 'fr' (french), 'es' (spanish),
|
60 |
+
'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'.
|
61 |
+
If using 'auto', the language is automatically detected and can
|
62 |
+
change during the conversation. Default is 'en'."""
|
63 |
+
},
|
64 |
+
)
|