Fix polling algorithm error
Browse files- .dockerignore +3 -1
- main.py +19 -19
.dockerignore
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
api.yaml
|
|
|
|
|
|
1 |
+
api.yaml
|
2 |
+
test
|
3 |
+
json_str
|
main.py
CHANGED
@@ -62,6 +62,10 @@ config, api_keys_db, api_list = load_config()
|
|
62 |
async def error_handling_wrapper(generator, status_code=200):
|
63 |
try:
|
64 |
first_item = await generator.__anext__()
|
|
|
|
|
|
|
|
|
65 |
if isinstance(first_item, dict) and "error" in first_item:
|
66 |
# 如果第一个 yield 的项是错误信息,抛出 HTTPException
|
67 |
raise HTTPException(status_code=status_code, detail=first_item)
|
@@ -112,15 +116,15 @@ async def process_request(request: RequestModel, provider: Dict):
|
|
112 |
|
113 |
if request.stream:
|
114 |
model = provider['model'][request.model]
|
115 |
-
try:
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
except HTTPException as e:
|
120 |
-
|
121 |
-
except Exception as e:
|
122 |
-
|
123 |
-
|
124 |
else:
|
125 |
return await fetch_response(app.state.client, url, headers, payload)
|
126 |
|
@@ -170,14 +174,8 @@ class ModelRequestHandler:
|
|
170 |
async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool):
|
171 |
num_providers = len(providers)
|
172 |
|
173 |
-
for i in range(num_providers):
|
174 |
-
|
175 |
-
# 始终从第一个提供者开始轮询
|
176 |
-
self.last_provider_index = i % num_providers
|
177 |
-
else:
|
178 |
-
# 非轮询模式,按顺序尝试
|
179 |
-
self.last_provider_index = i
|
180 |
-
|
181 |
provider = providers[self.last_provider_index]
|
182 |
try:
|
183 |
response = await process_request(request, provider)
|
@@ -185,9 +183,11 @@ class ModelRequestHandler:
|
|
185 |
except Exception as e:
|
186 |
print('\033[31m')
|
187 |
print(f"Error with provider {provider['provider']}: {str(e)}")
|
188 |
-
traceback.print_exc()
|
189 |
print('\033[0m')
|
190 |
-
|
|
|
|
|
|
|
191 |
|
192 |
raise HTTPException(status_code=500, detail="All providers failed")
|
193 |
|
|
|
62 |
async def error_handling_wrapper(generator, status_code=200):
|
63 |
try:
|
64 |
first_item = await generator.__anext__()
|
65 |
+
if isinstance(first_item, (bytes, bytearray)):
|
66 |
+
first_item = first_item.decode("utf-8")
|
67 |
+
if isinstance(first_item, str):
|
68 |
+
first_item = json.loads(first_item)
|
69 |
if isinstance(first_item, dict) and "error" in first_item:
|
70 |
# 如果第一个 yield 的项是错误信息,抛出 HTTPException
|
71 |
raise HTTPException(status_code=status_code, detail=first_item)
|
|
|
116 |
|
117 |
if request.stream:
|
118 |
model = provider['model'][request.model]
|
119 |
+
# try:
|
120 |
+
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
121 |
+
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
|
122 |
+
return StreamingResponse(wrapped_generator, media_type="text/event-stream")
|
123 |
+
# except HTTPException as e:
|
124 |
+
# return JSONResponse(status_code=e.status_code, content={"error": str(e.detail)})
|
125 |
+
# except Exception as e:
|
126 |
+
# # 处理其他异常
|
127 |
+
# return JSONResponse(status_code=500, content={"error": str(e)})
|
128 |
else:
|
129 |
return await fetch_response(app.state.client, url, headers, payload)
|
130 |
|
|
|
174 |
async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool):
|
175 |
num_providers = len(providers)
|
176 |
|
177 |
+
for i in range(num_providers + 1):
|
178 |
+
self.last_provider_index = i % num_providers
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
provider = providers[self.last_provider_index]
|
180 |
try:
|
181 |
response = await process_request(request, provider)
|
|
|
183 |
except Exception as e:
|
184 |
print('\033[31m')
|
185 |
print(f"Error with provider {provider['provider']}: {str(e)}")
|
|
|
186 |
print('\033[0m')
|
187 |
+
if use_round_robin:
|
188 |
+
continue
|
189 |
+
else:
|
190 |
+
raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
|
191 |
|
192 |
raise HTTPException(status_code=500, detail="All providers failed")
|
193 |
|