✨ Feature: Add API channel success rate statistics and channel status records.
- main.py +63 -22
- request.py +6 -6
- response.py +2 -2
- test/test_nostream.py +1 -1
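
With this change, the /stats endpoint (guarded by verify_admin_api_key) additionally reports per-channel success/failure counts and percentages. A minimal sketch of reading the new fields, assuming the gateway runs locally on port 8000 and accepts the admin key as a Bearer token (both are assumptions, not shown in the diff):

    import httpx

    # Assumptions: server on localhost:8000 and a Bearer admin key accepted by
    # verify_admin_api_key; adjust both to match the actual deployment.
    resp = httpx.get(
        "http://localhost:8000/stats",
        headers={"Authorization": "Bearer sk-your-admin-key"},
    )
    stats = resp.json()

    # Fields added by this commit:
    print(stats["channel_success_counts"])       # e.g. {"my-provider": 42}
    print(stats["channel_failure_counts"])       # e.g. {"my-provider": 3}
    print(stats["channel_success_percentages"])  # e.g. {"my-provider": 93.33}
    print(stats["channel_failure_percentages"])  # e.g. {"my-provider": 6.67}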
main.py
CHANGED
@@ -58,6 +58,8 @@ class StatsMiddleware(BaseHTTPMiddleware):
         self.request_times = defaultdict(float)
         self.ip_counts = defaultdict(lambda: defaultdict(int))
         self.request_arrivals = defaultdict(list)
+        self.channel_success_counts = defaultdict(int)
+        self.channel_failure_counts = defaultdict(int)
         self.lock = asyncio.Lock()
         self.exclude_paths = set(exclude_paths or [])
         self.save_interval = save_interval
@@ -101,7 +103,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
             "request_counts": dict(self.request_counts),
             "request_times": dict(self.request_times),
             "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
+            "channel_success_counts": dict(self.channel_success_counts),
+            "channel_failure_counts": dict(self.channel_failure_counts),
+            "channel_success_percentages": self.calculate_success_percentages(),
+            "channel_failure_percentages": self.calculate_failure_percentages()
         }
 
         filename = self.filename
@@ -109,10 +115,28 @@ class StatsMiddleware(BaseHTTPMiddleware):
             await f.write(json.dumps(stats, indent=2))
 
         self.last_save_time = current_time
-
+
+    def calculate_success_percentages(self):
+        percentages = {}
+        for channel, success_count in self.channel_success_counts.items():
+            total_count = success_count + self.channel_failure_counts[channel]
+            if total_count > 0:
+                percentages[channel] = success_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages
+
+    def calculate_failure_percentages(self):
+        percentages = {}
+        for channel, failure_count in self.channel_failure_counts.items():
+            total_count = failure_count + self.channel_success_counts[channel]
+            if total_count > 0:
+                percentages[channel] = failure_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages
 
     async def cleanup_old_data(self):
-        # cutoff_time = datetime.now() - timedelta(seconds=30)
         cutoff_time = datetime.now() - timedelta(hours=24)
         async with self.lock:
             for endpoint in list(self.request_arrivals.keys()):
@@ -139,10 +163,10 @@ app.add_middleware(
 
 app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])
 
+# Update the success and failure counts in process_request
 async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
-    # print(parsed_url)
     engine = None
     if parsed_url.netloc == 'generativelanguage.googleapis.com':
         engine = "gemini"
@@ -160,6 +184,12 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
         and "gemini" not in provider['model'][request.model]:
         engine = "openrouter"
 
+    if "claude" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-claude"
+
+    if "gemini" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-gemini"
+
     if endpoint == "/v1/images/generations":
         engine = "dalle"
         request.stream = False
@@ -171,21 +201,28 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
 
     url, headers, payload = await get_payload(request, engine, provider)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        if request.stream:
+            model = provider['model'][request.model]
+            generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
+            wrapped_generator = await error_handling_wrapper(generator, status_code=500)
+            response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
+        else:
+            response = await anext(fetch_response(app.state.client, url, headers, payload))
+
+        # Update the success count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_success_counts[provider['provider']] += 1
+
+        return response
+    except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
+        logger.error(f"Error with provider {provider['provider']}: {str(e)}")
+
+        # Update the failure count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1
+
+        raise e
 
 import asyncio
 class ModelRequestHandler:
@@ -270,10 +307,10 @@ class ModelRequestHandler:
 
         return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
 
+    # Handle provider failures in try_all_providers
     async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
         num_providers = len(providers)
         start_index = self.last_provider_index + 1 if use_round_robin else 0
-
         for i in range(num_providers + 1):
             self.last_provider_index = (start_index + i) % num_providers
             provider = providers[self.last_provider_index]
@@ -287,7 +324,6 @@ class ModelRequestHandler:
             else:
                 raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
 
-
         raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
 
 model_handler = ModelRequestHandler()
@@ -341,6 +377,7 @@ def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(36)
     return JSONResponse(content={"api_key": api_key})
 
+# Return the success and failure percentages from the /stats route
 @app.get("/stats")
 async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
     middleware = app.middleware_stack.app
@@ -350,7 +387,11 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
             "request_counts": dict(middleware.request_counts),
             "request_times": dict(middleware.request_times),
             "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()},
+            "channel_success_counts": dict(middleware.channel_success_counts),
+            "channel_failure_counts": dict(middleware.channel_failure_counts),
+            "channel_success_percentages": middleware.calculate_success_percentages(),
+            "channel_failure_percentages": middleware.calculate_failure_percentages()
         }
         return JSONResponse(content=stats)
     return {"error": "StatsMiddleware not found"}
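
The success-rate bookkeeping above boils down to two per-channel counters plus a percentage helper. A standalone sketch of the same arithmetic (channel names here are made up for illustration; the real counters live on StatsMiddleware and are updated under its asyncio lock):

    from collections import defaultdict

    channel_success_counts = defaultdict(int)
    channel_failure_counts = defaultdict(int)

    def record(channel: str, ok: bool) -> None:
        # Mirrors the counting done in process_request's try/except block.
        if ok:
            channel_success_counts[channel] += 1
        else:
            channel_failure_counts[channel] += 1

    def success_percentages() -> dict:
        # Same arithmetic as calculate_success_percentages:
        # successes / (successes + failures) * 100, with 0 reported
        # for a channel that has no recorded requests.
        percentages = {}
        for channel, ok_count in channel_success_counts.items():
            total = ok_count + channel_failure_counts[channel]
            percentages[channel] = ok_count / total * 100 if total > 0 else 0
        return percentages

    record("provider-a", True)
    record("provider-a", False)
    record("provider-b", True)
    print(success_percentages())  # {'provider-a': 50.0, 'provider-b': 100.0}

One quirk worth noting: calculate_success_percentages iterates only over channels that have at least one recorded success, so a channel that has only ever failed shows up in the failure percentages but not the success ones.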
request.py
CHANGED
@@ -10,7 +10,7 @@ async def get_image_message(base64_image, engine = None):
                 "url": base64_image,
             }
         }
-    if "claude" == engine:
+    if "claude" == engine or "vertex-claude" == engine:
         return {
             "type": "image",
             "source": {
@@ -19,7 +19,7 @@ async def get_image_message(base64_image, engine = None):
                 "data": base64_image.split(",")[1],
             }
         }
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {
             "inlineData": {
                 "mimeType": "image/jpeg",
@@ -29,9 +29,9 @@ async def get_image_message(base64_image, engine = None):
     raise ValueError("Unknown engine")
 
 async def get_text_message(role, message, engine = None):
-    if "gpt" == engine or "claude" == engine or "openrouter" == engine:
+    if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine:
         return {"type": "text", "text": message}
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {"text": message}
     raise ValueError("Unknown engine")
 
@@ -794,9 +794,9 @@ async def get_dalle_payload(request, engine, provider):
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
-    elif engine == "vertex
+    elif engine == "vertex-gemini":
         return await get_vertex_gemini_payload(request, engine, provider)
-    elif engine == "vertex
+    elif engine == "vertex-claude":
         return await get_vertex_claude_payload(request, engine, provider)
     elif engine == "claude":
         return await get_claude_payload(request, engine, provider)
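
The new vertex-claude and vertex-gemini engines reuse the existing Claude and Gemini message shapes. A quick check of what get_text_message now returns for each engine family, assuming request.py is importable from the repo root:

    import asyncio
    from request import get_text_message

    async def demo():
        # Claude-family engines (now including "vertex-claude") share one shape...
        print(await get_text_message("user", "hi", engine="vertex-claude"))
        # -> {'type': 'text', 'text': 'hi'}
        # ...while Gemini-family engines (now including "vertex-gemini") use another.
        print(await get_text_message("user", "hi", engine="vertex-gemini"))
        # -> {'text': 'hi'}

    asyncio.run(demo())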
response.py
CHANGED
@@ -248,10 +248,10 @@ async def fetch_response(client, url, headers, payload):
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini" or
+        if engine == "gemini" or engine == "vertex-gemini":
             async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                 yield chunk
-        elif engine == "claude" or
+        elif engine == "claude" or engine == "vertex-claude":
             async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                 yield chunk
         elif engine == "gpt":
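
fetch_response_stream is an async generator, so the new engine values only pick which upstream stream gets forwarded. A hedged consumption sketch (the URL, headers, payload, and model below are placeholders; in the app they come from get_payload):

    import asyncio
    import httpx
    from response import fetch_response_stream

    async def demo():
        async with httpx.AsyncClient() as client:
            # "vertex-gemini" now follows the Gemini streaming path and
            # "vertex-claude" the Claude one.
            async for chunk in fetch_response_stream(
                client,
                "https://example.invalid/v1/chat/completions",  # placeholder URL
                headers={},
                payload={},
                engine="vertex-gemini",
                model="gemini-model-placeholder",
            ):
                print(chunk)

    asyncio.run(demo())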
test/test_nostream.py
CHANGED
@@ -66,7 +66,7 @@ def get_model_response(image_base64):
         # "stream": True,
         "tools": tools,
         "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
-        "max_tokens":
+        "max_tokens": 1000
     }
 
     try: