yym68686 commited on
Commit
c50b8cc
·
1 Parent(s): d91f3fa

✨ Feature: Add support for embeddings model

Browse files
Files changed (4) hide show
  1. main.py +14 -3
  2. models.py +10 -1
  3. request.py +23 -0
  4. utils.py +1 -0
main.py CHANGED
@@ -16,7 +16,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
16
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
17
  from fastapi.exceptions import RequestValidationError
18
 
19
- from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
20
  from request import get_payload
21
  from response import fetch_response, fetch_response_stream
22
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
@@ -478,7 +478,7 @@ async def ensure_config(request: Request, call_next):
478
  return await call_next(request)
479
 
480
  # 在 process_request 函数中更新成功和失败计数
481
- async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
482
  url = provider['base_url']
483
  parsed_url = urlparse(url)
484
  # print("parsed_url", parsed_url)
@@ -529,6 +529,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
529
  engine = "moderation"
530
  request.stream = False
531
 
 
 
 
 
532
  if provider.get("engine"):
533
  engine = provider["engine"]
534
 
@@ -700,7 +704,7 @@ class ModelRequestHandler:
700
  # print("provider_list", provider_list)
701
  return provider_list
702
 
703
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
704
  config = app.state.config
705
  api_list = app.state.api_list
706
  api_index = api_list.index(token)
@@ -904,6 +908,13 @@ async def images_generations(
904
  ):
905
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
906
 
 
 
 
 
 
 
 
907
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
908
  async def moderations(
909
  request: ModerationRequest,
 
16
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
17
  from fastapi.exceptions import RequestValidationError
18
 
19
+ from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
20
  from request import get_payload
21
  from response import fetch_response, fetch_response_stream
22
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
 
478
  return await call_next(request)
479
 
480
  # 在 process_request 函数中更新成功和失败计数
481
+ async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
482
  url = provider['base_url']
483
  parsed_url = urlparse(url)
484
  # print("parsed_url", parsed_url)
 
529
  engine = "moderation"
530
  request.stream = False
531
 
532
+ if endpoint == "/v1/embeddings":
533
+ engine = "embedding"
534
+ request.stream = False
535
+
536
  if provider.get("engine"):
537
  engine = provider["engine"]
538
 
 
704
  # print("provider_list", provider_list)
705
  return provider_list
706
 
707
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
708
  config = app.state.config
709
  api_list = app.state.api_list
710
  api_index = api_list.index(token)
 
908
  ):
909
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
910
 
911
+ @app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
912
+ async def embeddings(
913
+ request: EmbeddingRequest,
914
+ token: str = Depends(verify_api_key)
915
+ ):
916
+ return await model_handler.request_model(request, token, endpoint="/v1/embeddings")
917
+
918
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
919
  async def moderations(
920
  request: ModerationRequest,
models.py CHANGED
@@ -111,6 +111,12 @@ class ImageGenerationRequest(BaseRequest):
111
  size: Optional[str] = "1024x1024"
112
  stream: bool = False
113
 
 
 
 
 
 
 
114
  class AudioTranscriptionRequest(BaseRequest):
115
  file: Tuple[str, IOBase, str]
116
  model: str
@@ -129,7 +135,7 @@ class ModerationRequest(BaseRequest):
129
  stream: bool = False
130
 
131
  class UnifiedRequest(BaseModel):
132
- data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest]
133
 
134
  @model_validator(mode='before')
135
  @classmethod
@@ -147,6 +153,9 @@ class UnifiedRequest(BaseModel):
147
  elif "input" in values:
148
  values["data"] = ModerationRequest(**values)
149
  values["data"].request_type = "moderation"
 
 
 
150
  else:
151
  raise ValueError("无法确定请求类型")
152
  return values
 
111
  size: Optional[str] = "1024x1024"
112
  stream: bool = False
113
 
114
+ class EmbeddingRequest(BaseRequest):
115
+ input: str
116
+ model: str
117
+ encoding_format: Optional[str] = "float"
118
+ stream: bool = False
119
+
120
  class AudioTranscriptionRequest(BaseRequest):
121
  file: Tuple[str, IOBase, str]
122
  model: str
 
135
  stream: bool = False
136
 
137
  class UnifiedRequest(BaseModel):
138
+ data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
139
 
140
  @model_validator(mode='before')
141
  @classmethod
 
153
  elif "input" in values:
154
  values["data"] = ModerationRequest(**values)
155
  values["data"].request_type = "moderation"
156
+ elif "input" in values:
157
+ values["data"] = EmbeddingRequest(**values)
158
+ values["data"].request_type = "embedding"
159
  else:
160
  raise ValueError("无法确定请求类型")
161
  return values
request.py CHANGED
@@ -1125,6 +1125,27 @@ async def get_moderation_payload(request, engine, provider):
1125
 
1126
  return url, headers, payload
1127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1128
  async def get_payload(request: RequestModel, engine, provider):
1129
  if engine == "gemini":
1130
  return await get_gemini_payload(request, engine, provider)
@@ -1150,5 +1171,7 @@ async def get_payload(request: RequestModel, engine, provider):
1150
  return await get_whisper_payload(request, engine, provider)
1151
  elif engine == "moderation":
1152
  return await get_moderation_payload(request, engine, provider)
 
 
1153
  else:
1154
  raise ValueError("Unknown payload")
 
1125
 
1126
  return url, headers, payload
1127
 
1128
+ async def get_embedding_payload(request, engine, provider):
1129
+ model_dict = get_model_dict(provider)
1130
+ model = model_dict[request.model]
1131
+ headers = {
1132
+ "Content-Type": "application/json",
1133
+ }
1134
+ if provider.get("api"):
1135
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
1136
+ url = provider['base_url']
1137
+ url = BaseAPI(url).embeddings
1138
+
1139
+ payload = {
1140
+ "input": request.input,
1141
+ "model": model,
1142
+ }
1143
+
1144
+ if request.encoding_format:
1145
+ payload["encoding_format"] = request.encoding_format
1146
+
1147
+ return url, headers, payload
1148
+
1149
  async def get_payload(request: RequestModel, engine, provider):
1150
  if engine == "gemini":
1151
  return await get_gemini_payload(request, engine, provider)
 
1171
  return await get_whisper_payload(request, engine, provider)
1172
  elif engine == "moderation":
1173
  return await get_moderation_payload(request, engine, provider)
1174
+ elif engine == "embedding":
1175
+ return await get_embedding_payload(request, engine, provider)
1176
  else:
1177
  raise ValueError("Unknown payload")
utils.py CHANGED
@@ -377,6 +377,7 @@ class BaseAPI:
377
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
378
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
379
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
 
380
 
381
  def safe_get(data, *keys, default=None):
382
  for key in keys:
 
377
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
378
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
379
  self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
380
+ self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/embeddings",) + ("",) * 3)
381
 
382
  def safe_get(data, *keys, default=None):
383
  for key in keys: