yym68686 commited on
Commit
9d7f082
·
1 Parent(s): f85160d

Fix polling algorithm error

Browse files
Files changed (2) hide show
  1. .dockerignore +3 -1
  2. main.py +19 -19
.dockerignore CHANGED
@@ -1 +1,3 @@
1
- api.yaml
 
 
 
1
+ api.yaml
2
+ test
3
+ json_str
main.py CHANGED
@@ -62,6 +62,10 @@ config, api_keys_db, api_list = load_config()
62
  async def error_handling_wrapper(generator, status_code=200):
63
  try:
64
  first_item = await generator.__anext__()
 
 
 
 
65
  if isinstance(first_item, dict) and "error" in first_item:
66
  # 如果第一个 yield 的项是错误信息,抛出 HTTPException
67
  raise HTTPException(status_code=status_code, detail=first_item)
@@ -112,15 +116,15 @@ async def process_request(request: RequestModel, provider: Dict):
112
 
113
  if request.stream:
114
  model = provider['model'][request.model]
115
- try:
116
- generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
117
- wrapped_generator = await error_handling_wrapper(generator, status_code=500)
118
- return StreamingResponse(wrapped_generator, media_type="text/event-stream")
119
- except HTTPException as e:
120
- return JSONResponse(status_code=e.status_code, content={"error": str(e.detail)})
121
- except Exception as e:
122
- # 处理其他异常
123
- return JSONResponse(status_code=500, content={"error": str(e)})
124
  else:
125
  return await fetch_response(app.state.client, url, headers, payload)
126
 
@@ -170,14 +174,8 @@ class ModelRequestHandler:
170
  async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool):
171
  num_providers = len(providers)
172
 
173
- for i in range(num_providers):
174
- if use_round_robin:
175
- # 始终从第一个提供者开始轮询
176
- self.last_provider_index = i % num_providers
177
- else:
178
- # 非轮询模式,按顺序尝试
179
- self.last_provider_index = i
180
-
181
  provider = providers[self.last_provider_index]
182
  try:
183
  response = await process_request(request, provider)
@@ -185,9 +183,11 @@ class ModelRequestHandler:
185
  except Exception as e:
186
  print('\033[31m')
187
  print(f"Error with provider {provider['provider']}: {str(e)}")
188
- traceback.print_exc()
189
  print('\033[0m')
190
- continue
 
 
 
191
 
192
  raise HTTPException(status_code=500, detail="All providers failed")
193
 
 
62
  async def error_handling_wrapper(generator, status_code=200):
63
  try:
64
  first_item = await generator.__anext__()
65
+ if isinstance(first_item, (bytes, bytearray)):
66
+ first_item = first_item.decode("utf-8")
67
+ if isinstance(first_item, str):
68
+ first_item = json.loads(first_item)
69
  if isinstance(first_item, dict) and "error" in first_item:
70
  # 如果第一个 yield 的项是错误信息,抛出 HTTPException
71
  raise HTTPException(status_code=status_code, detail=first_item)
 
116
 
117
  if request.stream:
118
  model = provider['model'][request.model]
119
+ # try:
120
+ generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
121
+ wrapped_generator = await error_handling_wrapper(generator, status_code=500)
122
+ return StreamingResponse(wrapped_generator, media_type="text/event-stream")
123
+ # except HTTPException as e:
124
+ # return JSONResponse(status_code=e.status_code, content={"error": str(e.detail)})
125
+ # except Exception as e:
126
+ # # 处理其他异常
127
+ # return JSONResponse(status_code=500, content={"error": str(e)})
128
  else:
129
  return await fetch_response(app.state.client, url, headers, payload)
130
 
 
174
  async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool):
175
  num_providers = len(providers)
176
 
177
+ for i in range(num_providers + 1):
178
+ self.last_provider_index = i % num_providers
 
 
 
 
 
 
179
  provider = providers[self.last_provider_index]
180
  try:
181
  response = await process_request(request, provider)
 
183
  except Exception as e:
184
  print('\033[31m')
185
  print(f"Error with provider {provider['provider']}: {str(e)}")
 
186
  print('\033[0m')
187
+ if use_round_robin:
188
+ continue
189
+ else:
190
+ raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
191
 
192
  raise HTTPException(status_code=500, detail="All providers failed")
193