✨ Feature: Add feature: Support for text-to-speech endpoint /v1/audio/speech
Browse files- main.py +26 -20
- models.py +14 -2
- request.py +29 -0
- response.py +7 -2
- utils.py +23 -1
main.py
CHANGED
@@ -15,7 +15,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
|
15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
16 |
from fastapi.exceptions import RequestValidationError
|
17 |
|
18 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
from utils import (
|
@@ -392,6 +392,9 @@ class LoggingStreamingResponse(Response):
|
|
392 |
async for chunk in self.body_iterator:
|
393 |
if isinstance(chunk, str):
|
394 |
chunk = chunk.encode('utf-8')
|
|
|
|
|
|
|
395 |
line = chunk.decode('utf-8')
|
396 |
if is_debug:
|
397 |
logger.info(f"{line.encode('utf-8').decode('unicode_escape')}")
|
@@ -504,6 +507,8 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
504 |
moderated_content = request_model.get_last_text_message()
|
505 |
elif request_model.request_type == "image":
|
506 |
moderated_content = request_model.prompt
|
|
|
|
|
507 |
elif request_model.request_type == "moderation":
|
508 |
pass
|
509 |
elif request_model.request_type == "embedding":
|
@@ -817,6 +822,9 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
817 |
|
818 |
if endpoint == "/v1/embeddings":
|
819 |
engine = "embedding"
|
|
|
|
|
|
|
820 |
request.stream = False
|
821 |
|
822 |
if provider.get("engine"):
|
@@ -848,19 +856,6 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
848 |
|
849 |
try:
|
850 |
async with app.state.client_manager.get_client(timeout_value, url, proxy) as client:
|
851 |
-
# 打印client配置信息
|
852 |
-
# logger.info(f"Client config - Timeout: {client.timeout}")
|
853 |
-
# logger.info(f"Client config - Headers: {client.headers}")
|
854 |
-
# if hasattr(client, '_transport'):
|
855 |
-
# if hasattr(client._transport, 'proxy_url'):
|
856 |
-
# logger.info(f"Client config - Proxy: {client._transport.proxy_url}")
|
857 |
-
# elif hasattr(client._transport, 'proxies'):
|
858 |
-
# logger.info(f"Client config - Proxies: {client._transport.proxies}")
|
859 |
-
# else:
|
860 |
-
# logger.info("Client config - No proxy configured")
|
861 |
-
# else:
|
862 |
-
# logger.info("Client config - No transport configured")
|
863 |
-
# logger.info(f"Client config - Follow Redirects: {client.follow_redirects}")
|
864 |
if request.stream:
|
865 |
generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
|
866 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
|
@@ -868,12 +863,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
868 |
else:
|
869 |
generator = fetch_response(client, url, headers, payload, engine, original_model)
|
870 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
|
|
|
|
|
|
|
|
877 |
|
878 |
# 更新成功计数和首次响应时间
|
879 |
await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=True)
|
@@ -1269,6 +1268,13 @@ async def embeddings(
|
|
1269 |
):
|
1270 |
return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
|
1271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1272 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
1273 |
async def moderations(
|
1274 |
request: ModerationRequest,
|
|
|
15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
16 |
from fastapi.exceptions import RequestValidationError
|
17 |
|
18 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest, EmbeddingRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
from utils import (
|
|
|
392 |
async for chunk in self.body_iterator:
|
393 |
if isinstance(chunk, str):
|
394 |
chunk = chunk.encode('utf-8')
|
395 |
+
if isinstance(chunk, bytes):
|
396 |
+
yield chunk
|
397 |
+
continue
|
398 |
line = chunk.decode('utf-8')
|
399 |
if is_debug:
|
400 |
logger.info(f"{line.encode('utf-8').decode('unicode_escape')}")
|
|
|
507 |
moderated_content = request_model.get_last_text_message()
|
508 |
elif request_model.request_type == "image":
|
509 |
moderated_content = request_model.prompt
|
510 |
+
elif request_model.request_type == "tts":
|
511 |
+
moderated_content = request_model.input
|
512 |
elif request_model.request_type == "moderation":
|
513 |
pass
|
514 |
elif request_model.request_type == "embedding":
|
|
|
822 |
|
823 |
if endpoint == "/v1/embeddings":
|
824 |
engine = "embedding"
|
825 |
+
|
826 |
+
if endpoint == "/v1/audio/speech":
|
827 |
+
engine = "tts"
|
828 |
request.stream = False
|
829 |
|
830 |
if provider.get("engine"):
|
|
|
856 |
|
857 |
try:
|
858 |
async with app.state.client_manager.get_client(timeout_value, url, proxy) as client:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
859 |
if request.stream:
|
860 |
generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
|
861 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
|
|
|
863 |
else:
|
864 |
generator = fetch_response(client, url, headers, payload, engine, original_model)
|
865 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
|
866 |
+
|
867 |
+
# 处理音频和其他二进制响应
|
868 |
+
if endpoint == "/v1/audio/speech":
|
869 |
+
if isinstance(wrapped_generator, bytes):
|
870 |
+
response = Response(content=wrapped_generator, media_type="audio/mpeg")
|
871 |
+
else:
|
872 |
+
first_element = await anext(wrapped_generator)
|
873 |
+
first_element = first_element.lstrip("data: ")
|
874 |
+
first_element = json.loads(first_element)
|
875 |
+
response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")
|
876 |
|
877 |
# 更新成功计数和首次响应时间
|
878 |
await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=True)
|
|
|
1268 |
):
|
1269 |
return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
|
1270 |
|
1271 |
+
@app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
async def audio_speech(
    request: TextToSpeechRequest,
    api_index: str = Depends(verify_api_key)
):
    """OpenAI-compatible text-to-speech endpoint.

    Thin wrapper: auth and rate limiting are handled by the dependencies;
    routing/retry logic lives in ``model_handler.request_model``.
    """
    return await model_handler.request_model(request, api_index, endpoint="/v1/audio/speech")
|
1277 |
+
|
1278 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
1279 |
async def moderations(
|
1280 |
request: ModerationRequest,
|
models.py
CHANGED
@@ -136,8 +136,16 @@ class ModerationRequest(BaseRequest):
|
|
136 |
model: Optional[str] = "text-moderation-latest"
|
137 |
stream: bool = False
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
class UnifiedRequest(BaseModel):
|
140 |
-
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
|
141 |
|
142 |
@model_validator(mode='before')
|
143 |
@classmethod
|
@@ -152,6 +160,10 @@ class UnifiedRequest(BaseModel):
|
|
152 |
elif "file" in values:
|
153 |
values["data"] = AudioTranscriptionRequest(**values)
|
154 |
values["data"].request_type = "audio"
|
|
|
|
|
|
|
|
|
155 |
elif "text-embedding" in values.get("model", ""):
|
156 |
values["data"] = EmbeddingRequest(**values)
|
157 |
values["data"].request_type = "embedding"
|
@@ -160,4 +172,4 @@ class UnifiedRequest(BaseModel):
|
|
160 |
values["data"].request_type = "moderation"
|
161 |
else:
|
162 |
raise ValueError("无法确定请求类型")
|
163 |
-
return values
|
|
|
136 |
model: Optional[str] = "text-moderation-latest"
|
137 |
stream: bool = False
|
138 |
|
139 |
+
class TextToSpeechRequest(BaseRequest):
    # OpenAI-compatible request body for POST /v1/audio/speech.
    model: str
    # Text to synthesize into audio.
    input: str
    voice: str
    # Output container/codec; defaults to "mp3" (presumably also accepts
    # "opus"/"aac"/"flac"/"wav" — TODO confirm against provider support).
    response_format: Optional[str] = "mp3"
    speed: Optional[float] = 1.0
    # Forwarded to the upstream payload when set; not every provider honors it.
    stream: Optional[bool] = False
|
146 |
+
|
147 |
class UnifiedRequest(BaseModel):
|
148 |
+
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest, TextToSpeechRequest]
|
149 |
|
150 |
@model_validator(mode='before')
|
151 |
@classmethod
|
|
|
160 |
elif "file" in values:
|
161 |
values["data"] = AudioTranscriptionRequest(**values)
|
162 |
values["data"].request_type = "audio"
|
163 |
+
elif "tts" in values.get("model", ""):
|
164 |
+
logger.info(f"TextToSpeechRequest: {values}")
|
165 |
+
values["data"] = TextToSpeechRequest(**values)
|
166 |
+
values["data"].request_type = "tts"
|
167 |
elif "text-embedding" in values.get("model", ""):
|
168 |
values["data"] = EmbeddingRequest(**values)
|
169 |
values["data"].request_type = "embedding"
|
|
|
172 |
values["data"].request_type = "moderation"
|
173 |
else:
|
174 |
raise ValueError("无法确定请求类型")
|
175 |
+
return values
|
request.py
CHANGED
@@ -1145,6 +1145,33 @@ async def get_embedding_payload(request, engine, provider):
|
|
1145 |
|
1146 |
return url, headers, payload
|
1147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1148 |
async def get_payload(request: RequestModel, engine, provider):
|
1149 |
if engine == "gemini":
|
1150 |
return await get_gemini_payload(request, engine, provider)
|
@@ -1168,6 +1195,8 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
1168 |
return await get_dalle_payload(request, engine, provider)
|
1169 |
elif engine == "whisper":
|
1170 |
return await get_whisper_payload(request, engine, provider)
|
|
|
|
|
1171 |
elif engine == "moderation":
|
1172 |
return await get_moderation_payload(request, engine, provider)
|
1173 |
elif engine == "embedding":
|
|
|
1145 |
|
1146 |
return url, headers, payload
|
1147 |
|
1148 |
+
async def get_tts_payload(request, engine, provider):
    """Build the upstream URL, headers and JSON payload for a TTS request.

    Mirrors the other ``get_*_payload`` helpers: resolves the provider-side
    model name, attaches the next API key from the circular list (when the
    provider has one), and forwards the optional TTS knobs.
    """
    model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if provider.get("api"):
        api_key = await provider_api_circular_list[provider['provider']].next(model)
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).audio_speech

    payload = {
        "model": model,
        "input": request.input,
        "voice": request.voice,
    }

    # Optional parameters are only forwarded when they carry a value.
    if request.response_format:
        payload["response_format"] = request.response_format
    if request.speed:
        payload["speed"] = request.speed
    if request.stream is not None:
        payload["stream"] = request.stream

    return url, headers, payload
|
1173 |
+
|
1174 |
+
|
1175 |
async def get_payload(request: RequestModel, engine, provider):
|
1176 |
if engine == "gemini":
|
1177 |
return await get_gemini_payload(request, engine, provider)
|
|
|
1195 |
return await get_dalle_payload(request, engine, provider)
|
1196 |
elif engine == "whisper":
|
1197 |
return await get_whisper_payload(request, engine, provider)
|
1198 |
+
elif engine == "tts":
|
1199 |
+
return await get_tts_payload(request, engine, provider)
|
1200 |
elif engine == "moderation":
|
1201 |
return await get_moderation_payload(request, engine, provider)
|
1202 |
elif engine == "embedding":
|
response.py
CHANGED
@@ -326,8 +326,12 @@ async def fetch_response(client, url, headers, payload, engine, model):
|
|
326 |
if error_message:
|
327 |
yield error_message
|
328 |
return
|
329 |
-
|
330 |
-
if engine == "
|
|
|
|
|
|
|
|
|
331 |
|
332 |
if isinstance(response_json, str):
|
333 |
import ast
|
@@ -361,6 +365,7 @@ async def fetch_response(client, url, headers, payload, engine, model):
|
|
361 |
yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
|
362 |
|
363 |
else:
|
|
|
364 |
yield response_json
|
365 |
|
366 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
|
|
326 |
if error_message:
|
327 |
yield error_message
|
328 |
return
|
329 |
+
|
330 |
+
if engine == "tts":
|
331 |
+
yield response.read()
|
332 |
+
|
333 |
+
elif engine == "gemini" or engine == "vertex-gemini":
|
334 |
+
response_json = response.json()
|
335 |
|
336 |
if isinstance(response_json, str):
|
337 |
import ast
|
|
|
365 |
yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
|
366 |
|
367 |
else:
|
368 |
+
response_json = response.json()
|
369 |
yield response_json
|
370 |
|
371 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
utils.py
CHANGED
@@ -416,6 +416,24 @@ def ensure_string(item):
|
|
416 |
else:
|
417 |
return str(item)
|
418 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
import asyncio
|
420 |
import time as time_module
|
421 |
async def error_handling_wrapper(generator, channel_id):
|
@@ -426,7 +444,10 @@ async def error_handling_wrapper(generator, channel_id):
|
|
426 |
first_item_str = first_item
|
427 |
# logger.info("first_item_str: %s", first_item_str)
|
428 |
if isinstance(first_item_str, (bytes, bytearray)):
|
429 |
-
first_item_str
|
|
|
|
|
|
|
430 |
if isinstance(first_item_str, str):
|
431 |
if first_item_str.startswith("data:"):
|
432 |
first_item_str = first_item_str.lstrip("data: ")
|
@@ -598,6 +619,7 @@ class BaseAPI:
|
|
598 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
|
599 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
|
600 |
self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
|
|
|
601 |
|
602 |
def safe_get(data, *keys, default=None):
|
603 |
for key in keys:
|
|
|
416 |
else:
|
417 |
return str(item)
|
418 |
|
419 |
+
def identify_audio_format(file_bytes):
    """Best-effort sniff of an audio format from its leading magic bytes.

    Args:
        file_bytes: raw bytes of (at least the start of) the audio payload.

    Returns:
        A short format label ("MP3", "MP3 with ID3", "OPUS", "AAC (ADIF)",
        "AAC (ADTS)", "FLAC", "WAV"), or "Unknown/PCM" when nothing matches.
    """
    # MPEG audio frame sync for Layer III: 0xFFFB/0xFFFA (MPEG-1, without/with
    # CRC) and 0xFFF3/0xFFF2 (MPEG-2). The original only matched 0xFFFB/0xFFF3,
    # missing CRC-protected frames.
    if file_bytes.startswith((b'\xFF\xFB', b'\xFF\xFA', b'\xFF\xF3', b'\xFF\xF2')):
        return "MP3"
    elif file_bytes.startswith(b'ID3'):
        return "MP3 with ID3"
    # Opus is carried in an Ogg container: the stream starts with the Ogg page
    # magic "OggS", and the "OpusHead" identification header sits inside the
    # first page — a bare "OpusHead" prefix never occurs at offset 0 in real
    # files (RFC 7845). The old prefix check is kept for backward compatibility.
    elif file_bytes.startswith(b'OpusHead') or (
        file_bytes.startswith(b'OggS') and b'OpusHead' in file_bytes[:64]
    ):
        return "OPUS"
    elif file_bytes.startswith(b'ADIF'):
        return "AAC (ADIF)"
    elif file_bytes.startswith((b'\xFF\xF1', b'\xFF\xF9')):
        return "AAC (ADTS)"
    elif file_bytes.startswith(b'fLaC'):
        return "FLAC"
    elif file_bytes.startswith(b'RIFF') and file_bytes[8:12] == b'WAVE':
        return "WAV"
    return "Unknown/PCM"
|
436 |
+
|
437 |
import asyncio
|
438 |
import time as time_module
|
439 |
async def error_handling_wrapper(generator, channel_id):
|
|
|
444 |
first_item_str = first_item
|
445 |
# logger.info("first_item_str: %s", first_item_str)
|
446 |
if isinstance(first_item_str, (bytes, bytearray)):
|
447 |
+
if identify_audio_format(first_item_str) in ["MP3", "MP3 with ID3", "OPUS", "AAC (ADIF)", "AAC (ADTS)", "FLAC", "WAV"]:
|
448 |
+
return first_item, first_response_time
|
449 |
+
else:
|
450 |
+
first_item_str = first_item_str.decode("utf-8")
|
451 |
if isinstance(first_item_str, str):
|
452 |
if first_item_str.startswith("data:"):
|
453 |
first_item_str = first_item_str.lstrip("data: ")
|
|
|
619 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
|
620 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
|
621 |
self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
|
622 |
+
self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/speech",) + ("",) * 3)
|
623 |
|
624 |
def safe_get(data, *keys, default=None):
|
625 |
for key in keys:
|