yym68686 committed on
Commit
483c524
·
2 Parent(s): 5962f8a 9bd1487

✨ Feature: Add support for text-to-speech endpoint /v1/audio/speech

Browse files
Files changed (5) hide show
  1. main.py +26 -20
  2. models.py +14 -2
  3. request.py +29 -0
  4. response.py +7 -2
  5. utils.py +23 -1
main.py CHANGED
@@ -15,7 +15,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from fastapi.exceptions import RequestValidationError
17
 
18
- from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
  from utils import (
@@ -392,6 +392,9 @@ class LoggingStreamingResponse(Response):
392
  async for chunk in self.body_iterator:
393
  if isinstance(chunk, str):
394
  chunk = chunk.encode('utf-8')
 
 
 
395
  line = chunk.decode('utf-8')
396
  if is_debug:
397
  logger.info(f"{line.encode('utf-8').decode('unicode_escape')}")
@@ -504,6 +507,8 @@ class StatsMiddleware(BaseHTTPMiddleware):
504
  moderated_content = request_model.get_last_text_message()
505
  elif request_model.request_type == "image":
506
  moderated_content = request_model.prompt
 
 
507
  elif request_model.request_type == "moderation":
508
  pass
509
  elif request_model.request_type == "embedding":
@@ -817,6 +822,9 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
817
 
818
  if endpoint == "/v1/embeddings":
819
  engine = "embedding"
 
 
 
820
  request.stream = False
821
 
822
  if provider.get("engine"):
@@ -848,19 +856,6 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
848
 
849
  try:
850
  async with app.state.client_manager.get_client(timeout_value, url, proxy) as client:
851
- # 打印client配置信息
852
- # logger.info(f"Client config - Timeout: {client.timeout}")
853
- # logger.info(f"Client config - Headers: {client.headers}")
854
- # if hasattr(client, '_transport'):
855
- # if hasattr(client._transport, 'proxy_url'):
856
- # logger.info(f"Client config - Proxy: {client._transport.proxy_url}")
857
- # elif hasattr(client._transport, 'proxies'):
858
- # logger.info(f"Client config - Proxies: {client._transport.proxies}")
859
- # else:
860
- # logger.info("Client config - No proxy configured")
861
- # else:
862
- # logger.info("Client config - No transport configured")
863
- # logger.info(f"Client config - Follow Redirects: {client.follow_redirects}")
864
  if request.stream:
865
  generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
866
  wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
@@ -868,12 +863,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
868
  else:
869
  generator = fetch_response(client, url, headers, payload, engine, original_model)
870
  wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
871
- first_element = await anext(wrapped_generator)
872
- first_element = first_element.lstrip("data: ")
873
- # print("first_element", first_element)
874
- first_element = json.loads(first_element)
875
- response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")
876
- # response = JSONResponse(first_element)
 
 
 
 
877
 
878
  # 更新成功计数和首次响应时间
879
  await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=True)
@@ -1269,6 +1268,13 @@ async def embeddings(
1269
  ):
1270
  return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
1271
 
 
 
 
 
 
 
 
1272
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
1273
  async def moderations(
1274
  request: ModerationRequest,
 
15
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16
  from fastapi.exceptions import RequestValidationError
17
 
18
+ from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest, EmbeddingRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
  from utils import (
 
392
  async for chunk in self.body_iterator:
393
  if isinstance(chunk, str):
394
  chunk = chunk.encode('utf-8')
395
+ if isinstance(chunk, bytes):
396
+ yield chunk
397
+ continue
398
  line = chunk.decode('utf-8')
399
  if is_debug:
400
  logger.info(f"{line.encode('utf-8').decode('unicode_escape')}")
 
507
  moderated_content = request_model.get_last_text_message()
508
  elif request_model.request_type == "image":
509
  moderated_content = request_model.prompt
510
+ elif request_model.request_type == "tts":
511
+ moderated_content = request_model.input
512
  elif request_model.request_type == "moderation":
513
  pass
514
  elif request_model.request_type == "embedding":
 
822
 
823
  if endpoint == "/v1/embeddings":
824
  engine = "embedding"
825
+
826
+ if endpoint == "/v1/audio/speech":
827
+ engine = "tts"
828
  request.stream = False
829
 
830
  if provider.get("engine"):
 
856
 
857
  try:
858
  async with app.state.client_manager.get_client(timeout_value, url, proxy) as client:
 
 
 
 
 
 
 
 
 
 
 
 
 
859
  if request.stream:
860
  generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
861
  wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
 
863
  else:
864
  generator = fetch_response(client, url, headers, payload, engine, original_model)
865
  wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
866
+
867
+ # 处理音频和其他二进制响应
868
+ if endpoint == "/v1/audio/speech":
869
+ if isinstance(wrapped_generator, bytes):
870
+ response = Response(content=wrapped_generator, media_type="audio/mpeg")
871
+ else:
872
+ first_element = await anext(wrapped_generator)
873
+ first_element = first_element.lstrip("data: ")
874
+ first_element = json.loads(first_element)
875
+ response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")
876
 
877
  # 更新成功计数和首次响应时间
878
  await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=True)
 
1268
  ):
1269
  return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
1270
 
1271
@app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
async def audio_speech(
    request: TextToSpeechRequest,
    api_index: str = Depends(verify_api_key),
):
    """Proxy an OpenAI-compatible text-to-speech request to a backend provider.

    Thin endpoint wrapper: authentication is resolved by the
    `verify_api_key` dependency and the actual routing/dispatch is
    delegated to `model_handler.request_model`.
    """
    return await model_handler.request_model(
        request, api_index, endpoint="/v1/audio/speech"
    )
1277
+
1278
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
1279
  async def moderations(
1280
  request: ModerationRequest,
models.py CHANGED
@@ -136,8 +136,16 @@ class ModerationRequest(BaseRequest):
136
  model: Optional[str] = "text-moderation-latest"
137
  stream: bool = False
138
 
 
 
 
 
 
 
 
 
139
  class UnifiedRequest(BaseModel):
140
- data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
141
 
142
  @model_validator(mode='before')
143
  @classmethod
@@ -152,6 +160,10 @@ class UnifiedRequest(BaseModel):
152
  elif "file" in values:
153
  values["data"] = AudioTranscriptionRequest(**values)
154
  values["data"].request_type = "audio"
 
 
 
 
155
  elif "text-embedding" in values.get("model", ""):
156
  values["data"] = EmbeddingRequest(**values)
157
  values["data"].request_type = "embedding"
@@ -160,4 +172,4 @@ class UnifiedRequest(BaseModel):
160
  values["data"].request_type = "moderation"
161
  else:
162
  raise ValueError("无法确定请求类型")
163
- return values
 
136
  model: Optional[str] = "text-moderation-latest"
137
  stream: bool = False
138
 
139
+ class TextToSpeechRequest(BaseRequest):
140
+ model: str
141
+ input: str
142
+ voice: str
143
+ response_format: Optional[str] = "mp3"
144
+ speed: Optional[float] = 1.0
145
+ stream: Optional[bool] = False # Add this line
146
+
147
  class UnifiedRequest(BaseModel):
148
+ data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest, TextToSpeechRequest]
149
 
150
  @model_validator(mode='before')
151
  @classmethod
 
160
  elif "file" in values:
161
  values["data"] = AudioTranscriptionRequest(**values)
162
  values["data"].request_type = "audio"
163
+ elif "tts" in values.get("model", ""):
164
+ logger.info(f"TextToSpeechRequest: {values}")
165
+ values["data"] = TextToSpeechRequest(**values)
166
+ values["data"].request_type = "tts"
167
  elif "text-embedding" in values.get("model", ""):
168
  values["data"] = EmbeddingRequest(**values)
169
  values["data"].request_type = "embedding"
 
172
  values["data"].request_type = "moderation"
173
  else:
174
  raise ValueError("无法确定请求类型")
175
+ return values
request.py CHANGED
@@ -1145,6 +1145,33 @@ async def get_embedding_payload(request, engine, provider):
1145
 
1146
  return url, headers, payload
1147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1148
  async def get_payload(request: RequestModel, engine, provider):
1149
  if engine == "gemini":
1150
  return await get_gemini_payload(request, engine, provider)
@@ -1168,6 +1195,8 @@ async def get_payload(request: RequestModel, engine, provider):
1168
  return await get_dalle_payload(request, engine, provider)
1169
  elif engine == "whisper":
1170
  return await get_whisper_payload(request, engine, provider)
 
 
1171
  elif engine == "moderation":
1172
  return await get_moderation_payload(request, engine, provider)
1173
  elif engine == "embedding":
 
1145
 
1146
  return url, headers, payload
1147
 
1148
async def get_tts_payload(request, engine, provider):
    """Build the upstream URL, headers and JSON payload for a text-to-speech call.

    Args:
        request: TextToSpeechRequest carrying model/input/voice plus optional
            response_format/speed/stream fields.
        engine: engine identifier (unused here; kept so all get_*_payload
            helpers share the same signature).
        provider: provider config dict with at least "base_url", and
            optionally "api" (key list) and "provider" (name).

    Returns:
        Tuple of (url, headers, payload) ready to POST upstream.
    """
    model_dict = get_model_dict(provider)
    model = model_dict[request.model]

    headers = {
        "Content-Type": "application/json",
    }
    if provider.get("api"):
        # Rotate through the provider's configured API keys.
        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

    url = provider['base_url']
    url = BaseAPI(url).audio_speech

    payload = {
        "model": model,
        "input": request.input,
        "voice": request.voice,
    }

    if request.response_format:
        payload["response_format"] = request.response_format
    # Explicit None check so a caller-supplied falsy speed (0/0.0) is still
    # forwarded, matching the `stream` handling below.
    if request.speed is not None:
        payload["speed"] = request.speed
    if request.stream is not None:
        payload["stream"] = request.stream

    return url, headers, payload
1173
+
1174
+
1175
  async def get_payload(request: RequestModel, engine, provider):
1176
  if engine == "gemini":
1177
  return await get_gemini_payload(request, engine, provider)
 
1195
  return await get_dalle_payload(request, engine, provider)
1196
  elif engine == "whisper":
1197
  return await get_whisper_payload(request, engine, provider)
1198
+ elif engine == "tts":
1199
+ return await get_tts_payload(request, engine, provider)
1200
  elif engine == "moderation":
1201
  return await get_moderation_payload(request, engine, provider)
1202
  elif engine == "embedding":
response.py CHANGED
@@ -326,8 +326,12 @@ async def fetch_response(client, url, headers, payload, engine, model):
326
  if error_message:
327
  yield error_message
328
  return
329
- response_json = response.json()
330
- if engine == "gemini" or engine == "vertex-gemini":
 
 
 
 
331
 
332
  if isinstance(response_json, str):
333
  import ast
@@ -361,6 +365,7 @@ async def fetch_response(client, url, headers, payload, engine, model):
361
  yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
362
 
363
  else:
 
364
  yield response_json
365
 
366
  async def fetch_response_stream(client, url, headers, payload, engine, model):
 
326
  if error_message:
327
  yield error_message
328
  return
329
+
330
+ if engine == "tts":
331
+ yield response.read()
332
+
333
+ elif engine == "gemini" or engine == "vertex-gemini":
334
+ response_json = response.json()
335
 
336
  if isinstance(response_json, str):
337
  import ast
 
365
  yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
366
 
367
  else:
368
+ response_json = response.json()
369
  yield response_json
370
 
371
  async def fetch_response_stream(client, url, headers, payload, engine, model):
utils.py CHANGED
@@ -416,6 +416,24 @@ def ensure_string(item):
416
  else:
417
  return str(item)
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  import asyncio
420
  import time as time_module
421
  async def error_handling_wrapper(generator, channel_id):
@@ -426,7 +444,10 @@ async def error_handling_wrapper(generator, channel_id):
426
  first_item_str = first_item
427
  # logger.info("first_item_str: %s", first_item_str)
428
  if isinstance(first_item_str, (bytes, bytearray)):
429
- first_item_str = first_item_str.decode("utf-8")
 
 
 
430
  if isinstance(first_item_str, str):
431
  if first_item_str.startswith("data:"):
432
  first_item_str = first_item_str.lstrip("data: ")
@@ -598,6 +619,7 @@ class BaseAPI:
598
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
599
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
600
  self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
 
601
 
602
  def safe_get(data, *keys, default=None):
603
  for key in keys:
 
416
  else:
417
  return str(item)
418
 
419
def identify_audio_format(file_bytes):
    """Best-effort sniff of an audio stream's format from its magic bytes.

    Returns a short label ("MP3", "OPUS", "FLAC", "WAV", ...) or
    "Unknown/PCM" when no known signature matches. Callers use these labels
    to decide whether a response body is binary audio rather than text/JSON.
    """
    if file_bytes.startswith(b'\xFF\xFB') or file_bytes.startswith(b'\xFF\xF3'):
        return "MP3"                      # bare MPEG audio frame sync
    elif file_bytes.startswith(b'ID3'):
        return "MP3 with ID3"             # MP3 preceded by an ID3v2 tag
    elif file_bytes.startswith(b'OpusHead') or (
            file_bytes.startswith(b'OggS') and file_bytes[28:36] == b'OpusHead'):
        # Opus normally ships in an Ogg container: the stream starts with an
        # "OggS" page header and the OpusHead packet begins at offset 28 (RFC
        # 7845), so checking for a leading "OpusHead" alone misses real files.
        return "OPUS"
    elif file_bytes.startswith(b'ADIF'):
        return "AAC (ADIF)"
    elif file_bytes.startswith(b'\xFF\xF1') or file_bytes.startswith(b'\xFF\xF9'):
        return "AAC (ADTS)"
    elif file_bytes.startswith(b'fLaC'):
        return "FLAC"
    elif file_bytes.startswith(b'RIFF') and file_bytes[8:12] == b'WAVE':
        return "WAV"
    return "Unknown/PCM"
436
+
437
  import asyncio
438
  import time as time_module
439
  async def error_handling_wrapper(generator, channel_id):
 
444
  first_item_str = first_item
445
  # logger.info("first_item_str: %s", first_item_str)
446
  if isinstance(first_item_str, (bytes, bytearray)):
447
+ if identify_audio_format(first_item_str) in ["MP3", "MP3 with ID3", "OPUS", "AAC (ADIF)", "AAC (ADTS)", "FLAC", "WAV"]:
448
+ return first_item, first_response_time
449
+ else:
450
+ first_item_str = first_item_str.decode("utf-8")
451
  if isinstance(first_item_str, str):
452
  if first_item_str.startswith("data:"):
453
  first_item_str = first_item_str.lstrip("data: ")
 
619
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
620
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
621
  self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
622
+ self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/speech",) + ("",) * 3)
623
 
624
  def safe_get(data, *keys, default=None):
625
  for key in keys: