✨ Feature: Add support for embeddings model
Browse files
main.py
CHANGED
@@ -16,7 +16,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
|
16 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
17 |
from fastapi.exceptions import RequestValidationError
|
18 |
|
19 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
20 |
from request import get_payload
|
21 |
from response import fetch_response, fetch_response_stream
|
22 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
@@ -478,7 +478,7 @@ async def ensure_config(request: Request, call_next):
|
|
478 |
return await call_next(request)
|
479 |
|
480 |
# Update the success and failure counts in the process_request function
|
481 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
482 |
url = provider['base_url']
|
483 |
parsed_url = urlparse(url)
|
484 |
# print("parsed_url", parsed_url)
|
@@ -529,6 +529,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
529 |
engine = "moderation"
|
530 |
request.stream = False
|
531 |
|
|
|
|
|
|
|
|
|
532 |
if provider.get("engine"):
|
533 |
engine = provider["engine"]
|
534 |
|
@@ -700,7 +704,7 @@ class ModelRequestHandler:
|
|
700 |
# print("provider_list", provider_list)
|
701 |
return provider_list
|
702 |
|
703 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
704 |
config = app.state.config
|
705 |
api_list = app.state.api_list
|
706 |
api_index = api_list.index(token)
|
@@ -904,6 +908,13 @@ async def images_generations(
|
|
904 |
):
|
905 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
906 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
907 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
908 |
async def moderations(
|
909 |
request: ModerationRequest,
|
|
|
16 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
17 |
from fastapi.exceptions import RequestValidationError
|
18 |
|
19 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
|
20 |
from request import get_payload
|
21 |
from response import fetch_response, fetch_response_stream
|
22 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
|
|
478 |
return await call_next(request)
|
479 |
|
480 |
# Update the success and failure counts in the process_request function
|
481 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
|
482 |
url = provider['base_url']
|
483 |
parsed_url = urlparse(url)
|
484 |
# print("parsed_url", parsed_url)
|
|
|
529 |
engine = "moderation"
|
530 |
request.stream = False
|
531 |
|
532 |
+
if endpoint == "/v1/embeddings":
|
533 |
+
engine = "embedding"
|
534 |
+
request.stream = False
|
535 |
+
|
536 |
if provider.get("engine"):
|
537 |
engine = provider["engine"]
|
538 |
|
|
|
704 |
# print("provider_list", provider_list)
|
705 |
return provider_list
|
706 |
|
707 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
708 |
config = app.state.config
|
709 |
api_list = app.state.api_list
|
710 |
api_index = api_list.index(token)
|
|
|
908 |
):
|
909 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
910 |
|
911 |
+
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
async def embeddings(
    request: EmbeddingRequest,
    token: str = Depends(verify_api_key),
):
    """Handle POST /v1/embeddings.

    The validated request body and the API token resolved by
    ``verify_api_key`` are handed to ``model_handler.request_model`` with
    the endpoint tag "/v1/embeddings"; whatever that call returns is
    returned to the client unchanged.
    """
    response = await model_handler.request_model(
        request,
        token,
        endpoint="/v1/embeddings",
    )
    return response
|
917 |
+
|
918 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
919 |
async def moderations(
|
920 |
request: ModerationRequest,
|
models.py
CHANGED
@@ -111,6 +111,12 @@ class ImageGenerationRequest(BaseRequest):
|
|
111 |
size: Optional[str] = "1024x1024"
|
112 |
stream: bool = False
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
class AudioTranscriptionRequest(BaseRequest):
|
115 |
file: Tuple[str, IOBase, str]
|
116 |
model: str
|
@@ -129,7 +135,7 @@ class ModerationRequest(BaseRequest):
|
|
129 |
stream: bool = False
|
130 |
|
131 |
class UnifiedRequest(BaseModel):
|
132 |
-
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest]
|
133 |
|
134 |
@model_validator(mode='before')
|
135 |
@classmethod
|
@@ -147,6 +153,9 @@ class UnifiedRequest(BaseModel):
|
|
147 |
elif "input" in values:
|
148 |
values["data"] = ModerationRequest(**values)
|
149 |
values["data"].request_type = "moderation"
|
|
|
|
|
|
|
150 |
else:
|
151 |
raise ValueError("无法确定请求类型")
|
152 |
return values
|
|
|
111 |
size: Optional[str] = "1024x1024"
|
112 |
stream: bool = False
|
113 |
|
114 |
+
class EmbeddingRequest(BaseRequest):
    # Request body for the /v1/embeddings endpoint.

    # Text to embed. Accepts either a single string or a list, so that
    # batched inputs (a list of strings or token arrays) are not rejected
    # at validation time. The value is forwarded to the upstream provider
    # unchanged, so element validation is left to the provider.
    input: Union[str, list]
    # Model name as requested by the client.
    model: str
    # Encoding of the returned vectors; defaults to "float".
    encoding_format: Optional[str] = "float"
    # Declared for parity with the other request models; defaults to False.
    stream: bool = False
|
119 |
+
|
120 |
class AudioTranscriptionRequest(BaseRequest):
|
121 |
file: Tuple[str, IOBase, str]
|
122 |
model: str
|
|
|
135 |
stream: bool = False
|
136 |
|
137 |
class UnifiedRequest(BaseModel):
|
138 |
+
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
|
139 |
|
140 |
@model_validator(mode='before')
|
141 |
@classmethod
|
|
|
153 |
elif "input" in values:
|
154 |
values["data"] = ModerationRequest(**values)
|
155 |
values["data"].request_type = "moderation"
|
156 |
+
elif "input" in values:
|
157 |
+
values["data"] = EmbeddingRequest(**values)
|
158 |
+
values["data"].request_type = "embedding"
|
159 |
else:
|
160 |
raise ValueError("无法确定请求类型")
|
161 |
return values
|
request.py
CHANGED
@@ -1125,6 +1125,27 @@ async def get_moderation_payload(request, engine, provider):
|
|
1125 |
|
1126 |
return url, headers, payload
|
1127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1128 |
async def get_payload(request: RequestModel, engine, provider):
|
1129 |
if engine == "gemini":
|
1130 |
return await get_gemini_payload(request, engine, provider)
|
@@ -1150,5 +1171,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
1150 |
return await get_whisper_payload(request, engine, provider)
|
1151 |
elif engine == "moderation":
|
1152 |
return await get_moderation_payload(request, engine, provider)
|
|
|
|
|
1153 |
else:
|
1154 |
raise ValueError("Unknown payload")
|
|
|
1125 |
|
1126 |
return url, headers, payload
|
1127 |
|
1128 |
+
async def get_embedding_payload(request, engine, provider):
    """Assemble the URL, headers and JSON body for an embeddings call.

    Args:
        request: Object exposing ``model``, ``input`` and
            ``encoding_format`` attributes.
        engine: Not used by this function; kept so the signature matches
            the other ``get_*_payload`` builders.
        provider: Provider configuration mapping. ``base_url`` and (when an
            ``api`` entry is present) ``provider`` are read from it.

    Returns:
        A ``(url, headers, payload)`` tuple.

    Raises:
        KeyError: If ``request.model`` is missing from the mapping returned
            by ``get_model_dict(provider)``.
    """
    # Resolve the requested model through the provider's model mapping.
    upstream_model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if provider.get("api"):
        # NOTE(review): presumably rotates through the provider's configured
        # API keys; confirm against provider_api_circular_list.
        api_key = await provider_api_circular_list[provider['provider']].next()
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).embeddings

    payload = {"input": request.input, "model": upstream_model}
    # Only forward encoding_format when it is set to a truthy value.
    if request.encoding_format:
        payload["encoding_format"] = request.encoding_format

    return url, headers, payload
|
1148 |
+
|
1149 |
async def get_payload(request: RequestModel, engine, provider):
|
1150 |
if engine == "gemini":
|
1151 |
return await get_gemini_payload(request, engine, provider)
|
|
|
1171 |
return await get_whisper_payload(request, engine, provider)
|
1172 |
elif engine == "moderation":
|
1173 |
return await get_moderation_payload(request, engine, provider)
|
1174 |
+
elif engine == "embedding":
|
1175 |
+
return await get_embedding_payload(request, engine, provider)
|
1176 |
else:
|
1177 |
raise ValueError("Unknown payload")
|
utils.py
CHANGED
@@ -377,6 +377,7 @@ class BaseAPI:
|
|
377 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
378 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
379 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
|
|
380 |
|
381 |
def safe_get(data, *keys, default=None):
|
382 |
for key in keys:
|
|
|
377 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
378 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
379 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
380 |
+
self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/embeddings",) + ("",) * 3)
|
381 |
|
382 |
def safe_get(data, *keys, default=None):
|
383 |
for key in keys:
|