yym68686 committed on
Commit
17409c4
·
1 Parent(s): 79bd233

✨ Feature: Add support for v1/audio/transcriptions endpoint

Browse files
Files changed (4) hide show
  1. main.py +44 -6
  2. models.py +14 -1
  3. request.py +28 -0
  4. response.py +9 -2
main.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from fastapi.exceptions import RequestValidationError
14
 
15
- from models import RequestModel, ImageGenerationRequest
16
  from request import get_payload
17
  from response import fetch_response, fetch_response_stream
18
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
@@ -191,7 +191,7 @@ app.add_middleware(
191
  app.add_middleware(StatsMiddleware)
192
 
193
  # 在 process_request 函数中更新成功和失败计数
194
- async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None, token=None):
195
  url = provider['base_url']
196
  parsed_url = urlparse(url)
197
  # print("parsed_url", parsed_url)
@@ -233,6 +233,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
233
  engine = "dalle"
234
  request.stream = False
235
 
 
 
 
 
236
  if provider.get("engine"):
237
  engine = provider["engine"]
238
 
@@ -241,7 +245,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
241
  url, headers, payload = await get_payload(request, engine, provider)
242
  if is_debug:
243
  logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
244
- logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
 
 
 
245
  try:
246
  if request.stream:
247
  model = provider['model'][request.model]
@@ -356,7 +363,7 @@ class ModelRequestHandler:
356
  print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
357
  return provider_list
358
 
359
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest], token: str, endpoint=None):
360
  config = app.state.config
361
  # api_keys_db = app.state.api_keys_db
362
  api_list = app.state.api_list
@@ -399,7 +406,7 @@ class ModelRequestHandler:
399
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
400
 
401
  # 在 try_all_providers 函数中处理失败的情况
402
- async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
403
  status_code = 500
404
  error_message = None
405
  num_providers = len(providers)
@@ -421,6 +428,9 @@ class ModelRequestHandler:
421
  raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
422
  except (Exception, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
423
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
 
 
 
424
  error_message = str(e)
425
  if auto_retry:
426
  continue
@@ -523,7 +533,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
523
  return token
524
 
525
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
526
- async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
527
  # logger.info(f"Request received: {request}")
528
  return await model_handler.request_model(request, token)
529
 
@@ -546,6 +556,34 @@ async def images_generations(
546
  ):
547
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  @app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
550
  def generate_api_key():
551
  api_key = "sk-" + secrets.token_urlsafe(36)
 
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from fastapi.exceptions import RequestValidationError
14
 
15
+ from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest
16
  from request import get_payload
17
  from response import fetch_response, fetch_response_stream
18
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
 
191
  app.add_middleware(StatsMiddleware)
192
 
193
  # 在 process_request 函数中更新成功和失败计数
194
+ async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], provider: Dict, endpoint=None, token=None):
195
  url = provider['base_url']
196
  parsed_url = urlparse(url)
197
  # print("parsed_url", parsed_url)
 
233
  engine = "dalle"
234
  request.stream = False
235
 
236
+ if endpoint == "/v1/audio/transcriptions":
237
+ engine = "whisper"
238
+ request.stream = False
239
+
240
  if provider.get("engine"):
241
  engine = provider["engine"]
242
 
 
245
  url, headers, payload = await get_payload(request, engine, provider)
246
  if is_debug:
247
  logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
248
+ if payload.get("file"):
249
+ pass
250
+ else:
251
+ logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
252
  try:
253
  if request.stream:
254
  model = provider['model'][request.model]
 
363
  print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
364
  return provider_list
365
 
366
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str, endpoint=None):
367
  config = app.state.config
368
  # api_keys_db = app.state.api_keys_db
369
  api_list = app.state.api_list
 
406
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
407
 
408
  # 在 try_all_providers 函数中处理失败的情况
409
+ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
410
  status_code = 500
411
  error_message = None
412
  num_providers = len(providers)
 
428
  raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
429
  except (Exception, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
430
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
431
+ if is_debug:
432
+ import traceback
433
+ traceback.print_exc()
434
  error_message = str(e)
435
  if auto_retry:
436
  continue
 
533
  return token
534
 
535
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
536
+ async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str = Depends(verify_api_key)):
537
  # logger.info(f"Request received: {request}")
538
  return await model_handler.request_model(request, token)
539
 
 
556
  ):
557
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
558
 
559
+ from fastapi import UploadFile, File, Form, HTTPException
560
+ import io
561
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
async def audio_transcriptions(
    file: UploadFile = File(...),
    model: str = Form(...),
    token: str = Depends(verify_api_key)
):
    """Handle a multipart audio transcription request.

    Reads the uploaded file fully into memory, wraps it in an
    AudioTranscriptionRequest and hands it to the model handler with the
    "/v1/audio/transcriptions" endpoint tag.

    Args:
        file: The uploaded audio file (multipart form field "file").
        model: Name of the requested model (multipart form field "model").
        token: API key resolved by verify_api_key.

    Returns:
        Whatever model_handler.request_model returns for this request.

    Raises:
        HTTPException: Propagated unchanged when raised by the model handler;
            400 on UnicodeDecodeError; 500 for any other unexpected error.
    """
    try:
        # Read the uploaded content and expose it as a seekable in-memory file.
        content = await file.read()
        file_obj = io.BytesIO(content)

        # Build the request object; the file travels as a
        # (filename, file object, content type) tuple.
        request = AudioTranscriptionRequest(
            file=(file.filename, file_obj, file.content_type),
            model=model
        )

        return await model_handler.request_model(request, token, endpoint="/v1/audio/transcriptions")
    except HTTPException:
        # The provider-dispatch path raises HTTPException itself; let it pass
        # through instead of re-wrapping it as a generic 500 below.
        raise
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid audio file encoding")
    except Exception as e:
        if is_debug:
            import traceback
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}") from e
586
+
587
  @app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
588
  def generate_api_key():
589
  api_key = "sk-" + secrets.token_urlsafe(36)
models.py CHANGED
@@ -1,5 +1,6 @@
 
1
  from pydantic import BaseModel, Field
2
- from typing import List, Dict, Optional, Union
3
 
4
  class ImageGenerationRequest(BaseModel):
5
  model: str
@@ -8,6 +9,18 @@ class ImageGenerationRequest(BaseModel):
8
  size: str
9
  stream: bool = False
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class FunctionParameter(BaseModel):
12
  type: str
13
  properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
 
1
+ from io import IOBase
2
  from pydantic import BaseModel, Field
3
+ from typing import List, Dict, Optional, Union, Tuple
4
 
5
  class ImageGenerationRequest(BaseModel):
6
  model: str
 
9
  size: str
10
  stream: bool = False
11
 
12
class AudioTranscriptionRequest(BaseModel):
    """Request body for the audio transcription endpoint.

    Only ``file`` and ``model`` are required; every other field is optional
    and defaults to ``None`` (``stream`` defaults to ``False``).
    """

    # Upload as a (filename, file object, content type) tuple.
    file: Tuple[str, IOBase, str]
    # Name of the requested model.
    model: str
    language: Optional[str] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    response_format: Optional[str] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    stream: bool = Field(default=False)

    class Config:
        # IOBase is not a pydantic-native type, so arbitrary types must be
        # allowed for the ``file`` field to be accepted.
        arbitrary_types_allowed = True
23
+
24
  class FunctionParameter(BaseModel):
25
  type: str
26
  properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
request.py CHANGED
@@ -1040,6 +1040,32 @@ async def get_dalle_payload(request, engine, provider):
1040
 
1041
  return url, headers, payload
1042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1043
  async def get_payload(request: RequestModel, engine, provider):
1044
  if engine == "gemini":
1045
  return await get_gemini_payload(request, engine, provider)
@@ -1061,5 +1087,7 @@ async def get_payload(request: RequestModel, engine, provider):
1061
  return await get_cohere_payload(request, engine, provider)
1062
  elif engine == "dalle":
1063
  return await get_dalle_payload(request, engine, provider)
 
 
1064
  else:
1065
  raise ValueError("Unknown payload")
 
1040
 
1041
  return url, headers, payload
1042
 
1043
async def get_whisper_payload(request, engine, provider):
    """Build the URL, headers and payload for an audio transcription call.

    Args:
        request: Transcription request exposing ``model``, ``file``,
            ``prompt``, ``response_format``, ``temperature`` and ``language``.
        engine: Engine name selected by the caller (not used here).
        provider: Provider configuration; reads ``model``, ``base_url`` and,
            when present, ``api``.

    Returns:
        A ``(url, headers, payload)`` tuple. ``payload`` always carries
        ``model`` and ``file``; optional fields are added only when set.

    Raises:
        KeyError: If ``request.model`` is not in ``provider['model']``.
    """
    model = provider['model'][request.model]
    headers = {
        # Kept deliberately: fetch_response pops this key unconditionally
        # before sending a file upload, so it must be present.
        "Content-Type": "application/json",
    }
    if provider.get("api"):
        headers['Authorization'] = f"Bearer {provider['api'].next()}"
    url = provider['base_url']
    url = BaseAPI(url).audio_transcriptions

    payload = {
        "model": model,
        "file": request.file,
    }

    if request.prompt:
        payload["prompt"] = request.prompt
    if request.response_format:
        payload["response_format"] = request.response_format
    # Compare against None rather than truthiness: 0.0 is a valid
    # temperature and must be forwarded when the caller sets it explicitly.
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.language:
        payload["language"] = request.language

    return url, headers, payload
1068
+
1069
  async def get_payload(request: RequestModel, engine, provider):
1070
  if engine == "gemini":
1071
  return await get_gemini_payload(request, engine, provider)
 
1087
  return await get_cohere_payload(request, engine, provider)
1088
  elif engine == "dalle":
1089
  return await get_dalle_payload(request, engine, provider)
1090
+ elif engine == "whisper":
1091
+ return await get_whisper_payload(request, engine, provider)
1092
  else:
1093
  raise ValueError("Unknown payload")
response.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  import httpx
3
  from datetime import datetime
 
4
 
5
  from log_config import logger
6
 
@@ -41,7 +42,7 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
41
  return sse_response
42
 
43
  async def check_response(response, error_log):
44
- if response.status_code != 200:
45
  error_message = await response.aread()
46
  error_str = error_message.decode('utf-8', errors='replace')
47
  try:
@@ -269,7 +270,13 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
269
  yield "data: [DONE]\n\r\n"
270
 
271
  async def fetch_response(client, url, headers, payload):
272
- response = await client.post(url, headers=headers, json=payload)
 
 
 
 
 
 
273
  error_message = await check_response(response, "fetch_response")
274
  if error_message:
275
  yield error_message
 
1
  import json
2
  import httpx
3
  from datetime import datetime
4
+ from io import BytesIO
5
 
6
  from log_config import logger
7
 
 
42
  return sse_response
43
 
44
  async def check_response(response, error_log):
45
+ if response and response.status_code != 200:
46
  error_message = await response.aread()
47
  error_str = error_message.decode('utf-8', errors='replace')
48
  try:
 
270
  yield "data: [DONE]\n\r\n"
271
 
272
  async def fetch_response(client, url, headers, payload):
273
+ response = None
274
+ if payload.get("file"):
275
+ file = payload.pop("file")
276
+ headers.pop("Content-Type")
277
+ response = await client.post(url, headers=headers, data=payload, files={"file": file})
278
+ else:
279
+ response = await client.post(url, headers=headers, json=payload)
280
  error_message = await check_response(response, "fetch_response")
281
  if error_message:
282
  yield error_message