Benedict King committed on
Commit
2ec384d
·
1 Parent(s): d4d650a

feat: add TextToSpeechRequest model and implement audio speech endpoint with processing logic

Browse files
Files changed (5) hide show
  1. main.py +17 -3
  2. models.py +9 -1
  3. request.py +28 -1
  4. response.py +4 -1
  5. utils.py +1 -0
main.py CHANGED
@@ -15,7 +15,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from fastapi.exceptions import RequestValidationError
17
 
18
- from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
@@ -360,6 +360,9 @@ class StatsMiddleware(BaseHTTPMiddleware):
360
  moderated_content = request_model.get_last_text_message()
361
  elif request_model.request_type == "image":
362
  moderated_content = request_model.prompt
 
 
 
363
  if moderated_content:
364
  current_info["text"] = moderated_content
365
 
@@ -521,6 +524,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
521
  engine = "moderation"
522
  request.stream = False
523
 
 
 
 
 
524
  if provider.get("engine"):
525
  engine = provider["engine"]
526
 
@@ -662,7 +669,7 @@ class ModelRequestHandler:
662
  logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
663
  return provider_list
664
 
665
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
666
  config = app.state.config
667
  # api_keys_db = app.state.api_keys_db
668
  api_list = app.state.api_list
@@ -705,7 +712,7 @@ class ModelRequestHandler:
705
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
706
 
707
  # 在 try_all_providers 函数中处理失败的情况
708
- async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
709
  status_code = 500
710
  error_message = None
711
  num_providers = len(providers)
@@ -866,6 +873,13 @@ async def images_generations(
866
  ):
867
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
868
 
 
 
 
 
 
 
 
869
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
870
  async def moderations(
871
  request: ModerationRequest,
 
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from fastapi.exceptions import RequestValidationError
17
 
18
+ from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
 
360
  moderated_content = request_model.get_last_text_message()
361
  elif request_model.request_type == "image":
362
  moderated_content = request_model.prompt
363
+ elif model.startswith("tts"):
364
+ moderated_content = request_model.input
365
+
366
  if moderated_content:
367
  current_info["text"] = moderated_content
368
 
 
524
  engine = "moderation"
525
  request.stream = False
526
 
527
+ if endpoint == "/v1/audio/speech":
528
+ engine = "tts"
529
+ request.stream = False
530
+
531
  if provider.get("engine"):
532
  engine = provider["engine"]
533
 
 
669
  logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
670
  return provider_list
671
 
672
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], token: str, endpoint=None):
673
  config = app.state.config
674
  # api_keys_db = app.state.api_keys_db
675
  api_list = app.state.api_list
 
712
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
713
 
714
  # 在 try_all_providers 函数中处理失败的情况
715
+ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
716
  status_code = 500
717
  error_message = None
718
  num_providers = len(providers)
 
873
  ):
874
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
875
 
876
+ @app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
877
+ async def audio_speech(
878
+ request: TextToSpeechRequest,
879
+ token: str = Depends(verify_api_key)
880
+ ):
881
+ return await model_handler.request_model(request, token, endpoint="/v1/audio/speech")
882
+
883
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
884
  async def moderations(
885
  request: ModerationRequest,
models.py CHANGED
@@ -134,4 +134,12 @@ class UnifiedRequest(BaseModel):
134
  values["data"].request_type = "moderation"
135
  else:
136
  raise ValueError("无法确定请求类型")
137
- return values
 
 
 
 
 
 
 
 
 
134
  values["data"].request_type = "moderation"
135
  else:
136
  raise ValueError("无法确定请求类型")
137
+ return values
138
+
139
+ class TextToSpeechRequest(BaseRequest):
140
+ model: str
141
+ input: str
142
+ voice: str
143
+ response_format: Optional[str] = "mp3"
144
+ speed: Optional[float] = 1.0
145
+ stream: Optional[bool] = False # Add this line
request.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import re
3
  import json
 
4
  import httpx
5
  import base64
6
  import urllib.parse
@@ -1134,7 +1135,33 @@ async def get_payload(request: RequestModel, engine, provider):
1134
  return await get_dalle_payload(request, engine, provider)
1135
  elif engine == "whisper":
1136
  return await get_whisper_payload(request, engine, provider)
 
 
1137
  elif engine == "moderation":
1138
  return await get_moderation_payload(request, engine, provider)
1139
  else:
1140
- raise ValueError("Unknown payload")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
4
+ from venv import logger
5
  import httpx
6
  import base64
7
  import urllib.parse
 
1135
  return await get_dalle_payload(request, engine, provider)
1136
  elif engine == "whisper":
1137
  return await get_whisper_payload(request, engine, provider)
1138
+ elif engine == "tts":
1139
+ return await get_tts_payload(request, engine, provider)
1140
  elif engine == "moderation":
1141
  return await get_moderation_payload(request, engine, provider)
1142
  else:
1143
+ raise ValueError("Unknown payload")
1144
+
1145
+ async def get_tts_payload(request, engine, provider):
1146
+ headers = {
1147
+ "Content-Type": "application/json",
1148
+ }
1149
+ if provider.get("api"):
1150
+ headers['Authorization'] = f"Bearer {provider['api'].next()}"
1151
+ url = provider['base_url']
1152
+ url = BaseAPI(url).audio_speech
1153
+
1154
+ payload = {
1155
+ "model": provider['model'][request.model],
1156
+ "input": request.input,
1157
+ "voice": request.voice,
1158
+ }
1159
+
1160
+ if request.response_format:
1161
+ payload["response_format"] = request.response_format
1162
+ if request.speed:
1163
+ payload["speed"] = request.speed
1164
+ if request.stream is not None:
1165
+ payload["stream"] = request.stream
1166
+
1167
+ return url, headers, payload
response.py CHANGED
@@ -285,7 +285,10 @@ async def fetch_response(client, url, headers, payload):
285
  if error_message:
286
  yield error_message
287
  return
288
- yield response.json()
 
 
 
289
 
290
  async def fetch_response_stream(client, url, headers, payload, engine, model):
291
  try:
 
285
  if error_message:
286
  yield error_message
287
  return
288
+ if url.endswith("/v1/audio/speech"):
289
+ yield response.read()
290
+ else:
291
+ yield response.json()
292
 
293
  async def fetch_response_stream(client, url, headers, payload, engine, model):
294
  try:
utils.py CHANGED
@@ -313,6 +313,7 @@ class BaseAPI:
313
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
314
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
315
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
 
316
 
317
  def safe_get(data, *keys, default=None):
318
  for key in keys:
 
313
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
314
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
315
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
316
+ self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/speech",) + ("",) * 3)
317
 
318
  def safe_get(data, *keys, default=None):
319
  for key in keys: