yym68686 committed
Commit 73a667f · 1 Parent(s): 44caf41

✨ Feature: Add API channel success rate statistics and channel status records.

Files changed (4)
  1. main.py +63 -22
  2. request.py +6 -6
  3. response.py +2 -2
  4. test/test_nostream.py +1 -1
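
In short: StatsMiddleware now keeps per-channel success and failure counters, process_request updates them on every upstream call, and the /stats endpoint reports the raw counts alongside derived percentages. As a hedged sketch of how the new fields might be consumed, the snippet below reads them back from /stats; the endpoint path and JSON keys come from the diff, while the base URL, the admin key value, and the Bearer auth scheme are assumptions (the diff only shows the Depends(verify_admin_api_key) hook, not the token format).

import httpx  # already a project dependency (see the httpx.ReadError handler below)

def print_channel_success_rates(base_url: str, admin_key: str) -> None:
    # Bearer auth is an assumption; match whatever verify_admin_api_key expects.
    resp = httpx.get(f"{base_url}/stats",
                     headers={"Authorization": f"Bearer {admin_key}"})
    resp.raise_for_status()
    stats = resp.json()
    # Keys introduced by this commit:
    for channel, pct in stats.get("channel_success_percentages", {}).items():
        ok = stats["channel_success_counts"].get(channel, 0)
        bad = stats["channel_failure_counts"].get(channel, 0)
        print(f"{channel}: {pct:.1f}% success ({ok} ok, {bad} failed)")

print_channel_success_rates("http://localhost:8000", "sk-your-admin-key")  # placeholder values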
main.py CHANGED
@@ -58,6 +58,8 @@ class StatsMiddleware(BaseHTTPMiddleware):
         self.request_times = defaultdict(float)
         self.ip_counts = defaultdict(lambda: defaultdict(int))
         self.request_arrivals = defaultdict(list)
+        self.channel_success_counts = defaultdict(int)
+        self.channel_failure_counts = defaultdict(int)
         self.lock = asyncio.Lock()
         self.exclude_paths = set(exclude_paths or [])
         self.save_interval = save_interval
@@ -101,7 +103,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
             "request_counts": dict(self.request_counts),
             "request_times": dict(self.request_times),
             "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
+            "channel_success_counts": dict(self.channel_success_counts),
+            "channel_failure_counts": dict(self.channel_failure_counts),
+            "channel_success_percentages": self.calculate_success_percentages(),
+            "channel_failure_percentages": self.calculate_failure_percentages()
         }
 
         filename = self.filename
@@ -109,10 +115,28 @@ class StatsMiddleware(BaseHTTPMiddleware):
             await f.write(json.dumps(stats, indent=2))
 
         self.last_save_time = current_time
-        # print(f"Stats saved to {filename}")
+
+    def calculate_success_percentages(self):
+        percentages = {}
+        for channel, success_count in self.channel_success_counts.items():
+            total_count = success_count + self.channel_failure_counts[channel]
+            if total_count > 0:
+                percentages[channel] = success_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages
+
+    def calculate_failure_percentages(self):
+        percentages = {}
+        for channel, failure_count in self.channel_failure_counts.items():
+            total_count = failure_count + self.channel_success_counts[channel]
+            if total_count > 0:
+                percentages[channel] = failure_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages
 
     async def cleanup_old_data(self):
-        # cutoff_time = datetime.now() - timedelta(seconds=30)
         cutoff_time = datetime.now() - timedelta(hours=24)
         async with self.lock:
             for endpoint in list(self.request_arrivals.keys()):
@@ -139,10 +163,10 @@ app.add_middleware(
 
 app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])
 
+# Update the success and failure counts inside process_request
 async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
-    # print(parsed_url)
     engine = None
     if parsed_url.netloc == 'generativelanguage.googleapis.com':
         engine = "gemini"
@@ -160,6 +184,12 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
         and "gemini" not in provider['model'][request.model]:
         engine = "openrouter"
 
+    if "claude" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-claude"
+
+    if "gemini" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-gemini"
+
     if endpoint == "/v1/images/generations":
         engine = "dalle"
         request.stream = False
@@ -171,21 +201,28 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
 
     url, headers, payload = await get_payload(request, engine, provider)
 
-    # request_info = {
-    #     "url": url,
-    #     "headers": headers,
-    #     "payload": payload
-    # }
-    # import json
-    # logger.info(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
-
-    if request.stream:
-        model = provider['model'][request.model]
-        generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
-        wrapped_generator = await error_handling_wrapper(generator, status_code=500)
-        return StreamingResponse(wrapped_generator, media_type="text/event-stream")
-    else:
-        return await anext(fetch_response(app.state.client, url, headers, payload))
+    try:
+        if request.stream:
+            model = provider['model'][request.model]
+            generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
+            wrapped_generator = await error_handling_wrapper(generator, status_code=500)
+            response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
+        else:
+            response = await anext(fetch_response(app.state.client, url, headers, payload))
+
+        # Update the success count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_success_counts[provider['provider']] += 1
+
+        return response
+    except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
+        logger.error(f"Error with provider {provider['provider']}: {str(e)}")
+
+        # Update the failure count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1
+
+        raise e
 
 import asyncio
 class ModelRequestHandler:
@@ -270,10 +307,10 @@ class ModelRequestHandler:
 
         return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
 
+    # Handle provider failures inside try_all_providers
    async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
         num_providers = len(providers)
         start_index = self.last_provider_index + 1 if use_round_robin else 0
-
         for i in range(num_providers + 1):
             self.last_provider_index = (start_index + i) % num_providers
             provider = providers[self.last_provider_index]
@@ -287,7 +324,6 @@ class ModelRequestHandler:
             else:
                 raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
 
-
         raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
 
 model_handler = ModelRequestHandler()
@@ -341,6 +377,7 @@ def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(36)
     return JSONResponse(content={"api_key": api_key})
 
+# Return the success and failure percentages from the /stats route
 @app.get("/stats")
 async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
     middleware = app.middleware_stack.app
@@ -350,7 +387,11 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
             "request_counts": dict(middleware.request_counts),
             "request_times": dict(middleware.request_times),
             "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()},
+            "channel_success_counts": dict(middleware.channel_success_counts),
+            "channel_failure_counts": dict(middleware.channel_failure_counts),
+            "channel_success_percentages": middleware.calculate_success_percentages(),
+            "channel_failure_percentages": middleware.calculate_failure_percentages()
         }
         return JSONResponse(content=stats)
     return {"error": "StatsMiddleware not found"}
request.py CHANGED
@@ -10,7 +10,7 @@ async def get_image_message(base64_image, engine = None):
             "url": base64_image,
         }
     }
-    if "claude" == engine:
+    if "claude" == engine or "vertex-claude" == engine:
         return {
             "type": "image",
             "source": {
@@ -19,7 +19,7 @@ async def get_image_message(base64_image, engine = None):
                 "data": base64_image.split(",")[1],
             }
         }
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {
             "inlineData": {
                 "mimeType": "image/jpeg",
@@ -29,9 +29,9 @@ async def get_image_message(base64_image, engine = None):
     raise ValueError("Unknown engine")
 
 async def get_text_message(role, message, engine = None):
-    if "gpt" == engine or "claude" == engine or "openrouter" == engine:
+    if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine:
         return {"type": "text", "text": message}
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {"text": message}
     raise ValueError("Unknown engine")
 
@@ -794,9 +794,9 @@ async def get_dalle_payload(request, engine, provider):
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
-    elif engine == "vertex" and "gemini" in provider['model'][request.model]:
+    elif engine == "vertex-gemini":
         return await get_vertex_gemini_payload(request, engine, provider)
-    elif engine == "vertex" and "claude" in provider['model'][request.model]:
+    elif engine == "vertex-claude":
         return await get_vertex_claude_payload(request, engine, provider)
     elif engine == "claude":
         return await get_claude_payload(request, engine, provider)
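
After the rename, the payload helpers branch on a single engine string ("vertex-claude" / "vertex-gemini") set once in process_request, instead of re-inspecting the model name at each call site. A quick demo of the message shapes get_text_message now produces; this is a condensed copy of the function above, with asyncio.run used only for the example:

import asyncio

# Condensed from request.py's get_text_message after this commit.
async def get_text_message(role, message, engine=None):
    if engine in ("gpt", "claude", "openrouter", "vertex-claude"):
        return {"type": "text", "text": message}
    if engine in ("gemini", "vertex-gemini"):
        return {"text": message}
    raise ValueError("Unknown engine")

print(asyncio.run(get_text_message("user", "hello", engine="vertex-claude")))
# {'type': 'text', 'text': 'hello'}
print(asyncio.run(get_text_message("user", "hello", engine="vertex-gemini")))
# {'text': 'hello'}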
response.py CHANGED
@@ -248,10 +248,10 @@ async def fetch_response(client, url, headers, payload):
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini" or (engine == "vertex" and "gemini" in model):
+        if engine == "gemini" or engine == "vertex-gemini":
             async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                 yield chunk
-        elif engine == "claude" or (engine == "vertex" and "claude" in model):
+        elif engine == "claude" or engine == "vertex-claude":
             async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                 yield chunk
         elif engine == "gpt":
test/test_nostream.py CHANGED
@@ -66,7 +66,7 @@ def get_model_response(image_base64):
         # "stream": True,
         "tools": tools,
         "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
-        "max_tokens": 300
+        "max_tokens": 1000
     }
 
     try: