Benedict King
commited on
Commit
·
2ec384d
1
Parent(s):
d4d650a
feat: add TextToSpeechRequest model and implement audio speech endpoint with processing logic
Browse files- main.py +17 -3
- models.py +9 -1
- request.py +28 -1
- response.py +4 -1
- utils.py +1 -0
main.py
CHANGED
@@ -15,7 +15,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
|
15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
16 |
from fastapi.exceptions import RequestValidationError
|
17 |
|
18 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
@@ -360,6 +360,9 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
360 |
moderated_content = request_model.get_last_text_message()
|
361 |
elif request_model.request_type == "image":
|
362 |
moderated_content = request_model.prompt
|
|
|
|
|
|
|
363 |
if moderated_content:
|
364 |
current_info["text"] = moderated_content
|
365 |
|
@@ -521,6 +524,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
521 |
engine = "moderation"
|
522 |
request.stream = False
|
523 |
|
|
|
|
|
|
|
|
|
524 |
if provider.get("engine"):
|
525 |
engine = provider["engine"]
|
526 |
|
@@ -662,7 +669,7 @@ class ModelRequestHandler:
|
|
662 |
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
663 |
return provider_list
|
664 |
|
665 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
666 |
config = app.state.config
|
667 |
# api_keys_db = app.state.api_keys_db
|
668 |
api_list = app.state.api_list
|
@@ -705,7 +712,7 @@ class ModelRequestHandler:
|
|
705 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
706 |
|
707 |
# 在 try_all_providers 函数中处理失败的情况
|
708 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
709 |
status_code = 500
|
710 |
error_message = None
|
711 |
num_providers = len(providers)
|
@@ -866,6 +873,13 @@ async def images_generations(
|
|
866 |
):
|
867 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
869 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
870 |
async def moderations(
|
871 |
request: ModerationRequest,
|
|
|
15 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
16 |
from fastapi.exceptions import RequestValidationError
|
17 |
|
18 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
360 |
moderated_content = request_model.get_last_text_message()
|
361 |
elif request_model.request_type == "image":
|
362 |
moderated_content = request_model.prompt
|
363 |
+
elif model.startswith("tts"):
|
364 |
+
moderated_content = request_model.input
|
365 |
+
|
366 |
if moderated_content:
|
367 |
current_info["text"] = moderated_content
|
368 |
|
|
|
524 |
engine = "moderation"
|
525 |
request.stream = False
|
526 |
|
527 |
+
if endpoint == "/v1/audio/speech":
|
528 |
+
engine = "tts"
|
529 |
+
request.stream = False
|
530 |
+
|
531 |
if provider.get("engine"):
|
532 |
engine = provider["engine"]
|
533 |
|
|
|
669 |
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
670 |
return provider_list
|
671 |
|
672 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], token: str, endpoint=None):
|
673 |
config = app.state.config
|
674 |
# api_keys_db = app.state.api_keys_db
|
675 |
api_list = app.state.api_list
|
|
|
712 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
713 |
|
714 |
# 在 try_all_providers 函数中处理失败的情况
|
715 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
716 |
status_code = 500
|
717 |
error_message = None
|
718 |
num_providers = len(providers)
|
|
|
873 |
):
|
874 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
875 |
|
876 |
+
@app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
|
877 |
+
async def audio_speech(
|
878 |
+
request: TextToSpeechRequest,
|
879 |
+
token: str = Depends(verify_api_key)
|
880 |
+
):
|
881 |
+
return await model_handler.request_model(request, token, endpoint="/v1/audio/speech")
|
882 |
+
|
883 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
884 |
async def moderations(
|
885 |
request: ModerationRequest,
|
models.py
CHANGED
@@ -134,4 +134,12 @@ class UnifiedRequest(BaseModel):
|
|
134 |
values["data"].request_type = "moderation"
|
135 |
else:
|
136 |
raise ValueError("无法确定请求类型")
|
137 |
-
return values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
values["data"].request_type = "moderation"
|
135 |
else:
|
136 |
raise ValueError("无法确定请求类型")
|
137 |
+
return values
|
138 |
+
|
139 |
+
class TextToSpeechRequest(BaseRequest):
|
140 |
+
model: str
|
141 |
+
input: str
|
142 |
+
voice: str
|
143 |
+
response_format: Optional[str] = "mp3"
|
144 |
+
speed: Optional[float] = 1.0
|
145 |
+
stream: Optional[bool] = False # Add this line
|
request.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import re
|
3 |
import json
|
|
|
4 |
import httpx
|
5 |
import base64
|
6 |
import urllib.parse
|
@@ -1134,7 +1135,33 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
1134 |
return await get_dalle_payload(request, engine, provider)
|
1135 |
elif engine == "whisper":
|
1136 |
return await get_whisper_payload(request, engine, provider)
|
|
|
|
|
1137 |
elif engine == "moderation":
|
1138 |
return await get_moderation_payload(request, engine, provider)
|
1139 |
else:
|
1140 |
-
raise ValueError("Unknown payload")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
import json
|
4 |
+
from venv import logger
|
5 |
import httpx
|
6 |
import base64
|
7 |
import urllib.parse
|
|
|
1135 |
return await get_dalle_payload(request, engine, provider)
|
1136 |
elif engine == "whisper":
|
1137 |
return await get_whisper_payload(request, engine, provider)
|
1138 |
+
elif engine == "tts":
|
1139 |
+
return await get_tts_payload(request, engine, provider)
|
1140 |
elif engine == "moderation":
|
1141 |
return await get_moderation_payload(request, engine, provider)
|
1142 |
else:
|
1143 |
+
raise ValueError("Unknown payload")
|
1144 |
+
|
1145 |
+
async def get_tts_payload(request, engine, provider):
|
1146 |
+
headers = {
|
1147 |
+
"Content-Type": "application/json",
|
1148 |
+
}
|
1149 |
+
if provider.get("api"):
|
1150 |
+
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
1151 |
+
url = provider['base_url']
|
1152 |
+
url = BaseAPI(url).audio_speech
|
1153 |
+
|
1154 |
+
payload = {
|
1155 |
+
"model": provider['model'][request.model],
|
1156 |
+
"input": request.input,
|
1157 |
+
"voice": request.voice,
|
1158 |
+
}
|
1159 |
+
|
1160 |
+
if request.response_format:
|
1161 |
+
payload["response_format"] = request.response_format
|
1162 |
+
if request.speed:
|
1163 |
+
payload["speed"] = request.speed
|
1164 |
+
if request.stream is not None:
|
1165 |
+
payload["stream"] = request.stream
|
1166 |
+
|
1167 |
+
return url, headers, payload
|
response.py
CHANGED
@@ -285,7 +285,10 @@ async def fetch_response(client, url, headers, payload):
|
|
285 |
if error_message:
|
286 |
yield error_message
|
287 |
return
|
288 |
-
|
|
|
|
|
|
|
289 |
|
290 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
291 |
try:
|
|
|
285 |
if error_message:
|
286 |
yield error_message
|
287 |
return
|
288 |
+
if url.endswith("/v1/audio/speech"):
|
289 |
+
yield response.read()
|
290 |
+
else:
|
291 |
+
yield response.json()
|
292 |
|
293 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
294 |
try:
|
utils.py
CHANGED
@@ -313,6 +313,7 @@ class BaseAPI:
|
|
313 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
314 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
315 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
|
|
316 |
|
317 |
def safe_get(data, *keys, default=None):
|
318 |
for key in keys:
|
|
|
313 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
314 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
315 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
316 |
+
self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/speech",) + ("",) * 3)
|
317 |
|
318 |
def safe_get(data, *keys, default=None):
|
319 |
for key in keys:
|