✨ Feature: Add support for v1/audio/transcriptions endpoint
Browse files- main.py +44 -6
- models.py +14 -1
- request.py +28 -0
- response.py +9 -2
main.py
CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
-
from models import RequestModel, ImageGenerationRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
@@ -191,7 +191,7 @@ app.add_middleware(
|
|
191 |
app.add_middleware(StatsMiddleware)
|
192 |
|
193 |
# 在 process_request 函数中更新成功和失败计数
|
194 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None, token=None):
|
195 |
url = provider['base_url']
|
196 |
parsed_url = urlparse(url)
|
197 |
# print("parsed_url", parsed_url)
|
@@ -233,6 +233,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
|
|
233 |
engine = "dalle"
|
234 |
request.stream = False
|
235 |
|
|
|
|
|
|
|
|
|
236 |
if provider.get("engine"):
|
237 |
engine = provider["engine"]
|
238 |
|
@@ -241,7 +245,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
|
|
241 |
url, headers, payload = await get_payload(request, engine, provider)
|
242 |
if is_debug:
|
243 |
logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
|
244 |
-
|
|
|
|
|
|
|
245 |
try:
|
246 |
if request.stream:
|
247 |
model = provider['model'][request.model]
|
@@ -356,7 +363,7 @@ class ModelRequestHandler:
|
|
356 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
357 |
return provider_list
|
358 |
|
359 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest], token: str, endpoint=None):
|
360 |
config = app.state.config
|
361 |
# api_keys_db = app.state.api_keys_db
|
362 |
api_list = app.state.api_list
|
@@ -399,7 +406,7 @@ class ModelRequestHandler:
|
|
399 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
400 |
|
401 |
# 在 try_all_providers 函数中处理失败的情况
|
402 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
403 |
status_code = 500
|
404 |
error_message = None
|
405 |
num_providers = len(providers)
|
@@ -421,6 +428,9 @@ class ModelRequestHandler:
|
|
421 |
raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
|
422 |
except (Exception, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
423 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
|
|
|
|
|
|
424 |
error_message = str(e)
|
425 |
if auto_retry:
|
426 |
continue
|
@@ -523,7 +533,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
|
|
523 |
return token
|
524 |
|
525 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
526 |
-
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
527 |
# logger.info(f"Request received: {request}")
|
528 |
return await model_handler.request_model(request, token)
|
529 |
|
@@ -546,6 +556,34 @@ async def images_generations(
|
|
546 |
):
|
547 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
548 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
549 |
@app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
|
550 |
def generate_api_key():
|
551 |
api_key = "sk-" + secrets.token_urlsafe(36)
|
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
191 |
app.add_middleware(StatsMiddleware)
|
192 |
|
193 |
# 在 process_request 函数中更新成功和失败计数
|
194 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], provider: Dict, endpoint=None, token=None):
|
195 |
url = provider['base_url']
|
196 |
parsed_url = urlparse(url)
|
197 |
# print("parsed_url", parsed_url)
|
|
|
233 |
engine = "dalle"
|
234 |
request.stream = False
|
235 |
|
236 |
+
if endpoint == "/v1/audio/transcriptions":
|
237 |
+
engine = "whisper"
|
238 |
+
request.stream = False
|
239 |
+
|
240 |
if provider.get("engine"):
|
241 |
engine = provider["engine"]
|
242 |
|
|
|
245 |
url, headers, payload = await get_payload(request, engine, provider)
|
246 |
if is_debug:
|
247 |
logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
|
248 |
+
if payload.get("file"):
|
249 |
+
pass
|
250 |
+
else:
|
251 |
+
logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
|
252 |
try:
|
253 |
if request.stream:
|
254 |
model = provider['model'][request.model]
|
|
|
363 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
364 |
return provider_list
|
365 |
|
366 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str, endpoint=None):
|
367 |
config = app.state.config
|
368 |
# api_keys_db = app.state.api_keys_db
|
369 |
api_list = app.state.api_list
|
|
|
406 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
407 |
|
408 |
# 在 try_all_providers 函数中处理失败的情况
|
409 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
410 |
status_code = 500
|
411 |
error_message = None
|
412 |
num_providers = len(providers)
|
|
|
428 |
raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
|
429 |
except (Exception, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
430 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
431 |
+
if is_debug:
|
432 |
+
import traceback
|
433 |
+
traceback.print_exc()
|
434 |
error_message = str(e)
|
435 |
if auto_retry:
|
436 |
continue
|
|
|
533 |
return token
|
534 |
|
535 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
536 |
+
async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str = Depends(verify_api_key)):
|
537 |
# logger.info(f"Request received: {request}")
|
538 |
return await model_handler.request_model(request, token)
|
539 |
|
|
|
556 |
):
|
557 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
558 |
|
559 |
+
from fastapi import UploadFile, File, Form, HTTPException
|
560 |
+
import io
|
561 |
+
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
async def audio_transcriptions(
    file: UploadFile = File(...),
    model: str = Form(...),
    token: str = Depends(verify_api_key),
    language: Union[str, None] = Form(None),
    prompt: Union[str, None] = Form(None),
    response_format: Union[str, None] = Form(None),
    temperature: Union[float, None] = Form(None),
):
    """Handle a multipart audio transcription request.

    Reads the uploaded file into memory, wraps it together with the form
    fields in an ``AudioTranscriptionRequest`` and hands it to
    ``model_handler.request_model`` with the transcription endpoint name.

    Args:
        file: The uploaded audio file (multipart form field ``file``).
        model: Name of the model to use (form field ``model``).
        token: API key resolved by ``verify_api_key``.
        language: Optional form field forwarded on the request object.
        prompt: Optional form field forwarded on the request object.
        response_format: Optional form field forwarded on the request object.
        temperature: Optional form field forwarded on the request object.

    Returns:
        Whatever ``model_handler.request_model`` returns.

    Raises:
        HTTPException: Re-raised unchanged when the handler raises one;
            400 on ``UnicodeDecodeError``; 500 for any other failure.
    """
    try:
        # Read the uploaded file content into an in-memory buffer.
        content = await file.read()
        file_obj = io.BytesIO(content)

        # Build the AudioTranscriptionRequest object.
        request = AudioTranscriptionRequest(
            file=(file.filename, file_obj, file.content_type),
            model=model,
            language=language,
            prompt=prompt,
            response_format=response_format,
            temperature=temperature,
        )

        return await model_handler.request_model(request, token, endpoint="/v1/audio/transcriptions")
    except HTTPException:
        # Keep the status code and detail chosen by the handler instead of
        # letting the generic clause below rewrap it as a 500.
        raise
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid audio file encoding")
    except Exception as e:
        if is_debug:
            import traceback
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
|
586 |
+
|
587 |
@app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
|
588 |
def generate_api_key():
|
589 |
api_key = "sk-" + secrets.token_urlsafe(36)
|
models.py
CHANGED
@@ -1,5 +1,6 @@
|
|
|
|
1 |
from pydantic import BaseModel, Field
|
2 |
-
from typing import List, Dict, Optional, Union
|
3 |
|
4 |
class ImageGenerationRequest(BaseModel):
|
5 |
model: str
|
@@ -8,6 +9,18 @@ class ImageGenerationRequest(BaseModel):
|
|
8 |
size: str
|
9 |
stream: bool = False
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
class FunctionParameter(BaseModel):
|
12 |
type: str
|
13 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
|
|
1 |
+
from io import IOBase
|
2 |
from pydantic import BaseModel, Field
|
3 |
+
from typing import List, Dict, Optional, Union, Tuple
|
4 |
|
5 |
class ImageGenerationRequest(BaseModel):
|
6 |
model: str
|
|
|
9 |
size: str
|
10 |
stream: bool = False
|
11 |
|
12 |
+
class AudioTranscriptionRequest(BaseModel):
    """Request model for an audio transcription call.

    Only ``file`` and ``model`` are required; the remaining fields default
    to ``None`` (or ``False`` for ``stream``).
    """

    # (filename, file object, content type). The filename and content type
    # come from the uploaded part and may be None (for example when the part
    # carries no Content-Type header), so both are optional.
    file: Tuple[Optional[str], IOBase, Optional[str]]
    model: str
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: Optional[float] = None
    stream: bool = False

    class Config:
        # IOBase is not a type pydantic knows how to validate natively.
        arbitrary_types_allowed = True
|
23 |
+
|
24 |
class FunctionParameter(BaseModel):
|
25 |
type: str
|
26 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
request.py
CHANGED
@@ -1040,6 +1040,32 @@ async def get_dalle_payload(request, engine, provider):
|
|
1040 |
|
1041 |
return url, headers, payload
|
1042 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1043 |
async def get_payload(request: RequestModel, engine, provider):
|
1044 |
if engine == "gemini":
|
1045 |
return await get_gemini_payload(request, engine, provider)
|
@@ -1061,5 +1087,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
1061 |
return await get_cohere_payload(request, engine, provider)
|
1062 |
elif engine == "dalle":
|
1063 |
return await get_dalle_payload(request, engine, provider)
|
|
|
|
|
1064 |
else:
|
1065 |
raise ValueError("Unknown payload")
|
|
|
1040 |
|
1041 |
return url, headers, payload
|
1042 |
|
1043 |
+
async def get_whisper_payload(request, engine, provider):
    """Build the URL, headers and payload for an audio transcription request.

    Args:
        request: Request object exposing ``model``, ``file``, ``prompt``,
            ``response_format``, ``temperature`` and ``language``.
        engine: Engine name; not used by this builder.
        provider: Provider configuration; ``provider['model']`` maps the
            requested model name to the provider's model name,
            ``provider['base_url']`` is the base URL, and the optional
            ``provider['api']`` supplies the bearer token via ``.next()``.

    Returns:
        A ``(url, headers, payload)`` tuple.
    """
    model = provider['model'][request.model]
    # fetch_response pops "Content-Type" unconditionally before a file
    # upload, so the key must be present here.
    headers = {
        "Content-Type": "application/json",
    }
    if provider.get("api"):
        headers['Authorization'] = f"Bearer {provider['api'].next()}"
    url = provider['base_url']
    url = BaseAPI(url).audio_transcriptions

    payload = {
        "model": model,
        "file": request.file,
    }

    if request.prompt:
        payload["prompt"] = request.prompt
    if request.response_format:
        payload["response_format"] = request.response_format
    # Compare against None so an explicit temperature of 0 is still sent.
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.language:
        payload["language"] = request.language

    return url, headers, payload
|
1068 |
+
|
1069 |
async def get_payload(request: RequestModel, engine, provider):
|
1070 |
if engine == "gemini":
|
1071 |
return await get_gemini_payload(request, engine, provider)
|
|
|
1087 |
return await get_cohere_payload(request, engine, provider)
|
1088 |
elif engine == "dalle":
|
1089 |
return await get_dalle_payload(request, engine, provider)
|
1090 |
+
elif engine == "whisper":
|
1091 |
+
return await get_whisper_payload(request, engine, provider)
|
1092 |
else:
|
1093 |
raise ValueError("Unknown payload")
|
response.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import json
|
2 |
import httpx
|
3 |
from datetime import datetime
|
|
|
4 |
|
5 |
from log_config import logger
|
6 |
|
@@ -41,7 +42,7 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
41 |
return sse_response
|
42 |
|
43 |
async def check_response(response, error_log):
|
44 |
-
if response.status_code != 200:
|
45 |
error_message = await response.aread()
|
46 |
error_str = error_message.decode('utf-8', errors='replace')
|
47 |
try:
|
@@ -269,7 +270,13 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
269 |
yield "data: [DONE]\n\r\n"
|
270 |
|
271 |
async def fetch_response(client, url, headers, payload):
|
272 |
-
response =
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
error_message = await check_response(response, "fetch_response")
|
274 |
if error_message:
|
275 |
yield error_message
|
|
|
1 |
import json
|
2 |
import httpx
|
3 |
from datetime import datetime
|
4 |
+
from io import BytesIO
|
5 |
|
6 |
from log_config import logger
|
7 |
|
|
|
42 |
return sse_response
|
43 |
|
44 |
async def check_response(response, error_log):
|
45 |
+
if response and response.status_code != 200:
|
46 |
error_message = await response.aread()
|
47 |
error_str = error_message.decode('utf-8', errors='replace')
|
48 |
try:
|
|
|
270 |
yield "data: [DONE]\n\r\n"
|
271 |
|
272 |
async def fetch_response(client, url, headers, payload):
|
273 |
+
response = None
|
274 |
+
if payload.get("file"):
|
275 |
+
file = payload.pop("file")
|
276 |
+
headers.pop("Content-Type")
|
277 |
+
response = await client.post(url, headers=headers, data=payload, files={"file": file})
|
278 |
+
else:
|
279 |
+
response = await client.post(url, headers=headers, json=payload)
|
280 |
error_message = await check_response(response, "fetch_response")
|
281 |
if error_message:
|
282 |
yield error_message
|