yym68686 committed · commit 1d1b0f1 · parent: e5b8220

✨ Feature: Add database, count the first character time.

Files changed (9):
  1. .dockerignore +5 -1
  2. .gitignore +2 -1
  3. README.md +8 -4
  4. README_CN.md +8 -4
  5. docker-compose.yml +2 -1
  6. main.py +173 -132
  7. requirements.txt +3 -0
  8. test/provider_test.py +2 -1
  9. utils.py +4 -1
.dockerignore CHANGED
@@ -1,3 +1,7 @@
 api.yaml
 test
-json_str
+json_str
+*.jpg
+*.json
+*.png
+*.db
.gitignore CHANGED
@@ -8,4 +8,5 @@ node_modules
 .pytest_cache
 *.jpg
 *.json
-*.png
+*.png
+*.db
README.md CHANGED
@@ -133,7 +133,9 @@ Start the container
 
 ```bash
 docker run --user root -p 8001:8000 --name uni-api -dit \
--v ./api.yaml:/home/api.yaml \
+-e CONFIG_URL=http://file_url/api.yaml \ # If the local configuration file is already mounted, you do not need to set CONFIG_URL
+-v ./api.yaml:/home/api.yaml \ # If CONFIG_URL is already set, you do not need to mount the configuration file
+-v ./stats.db:/home/stats.db \ # If you do not want to save statistical data, you do not need to mount the stats.db file
 yym68686/uni-api:latest
 ```
 
@@ -145,14 +147,15 @@ services:
     container_name: uni-api
     image: yym68686/uni-api:latest
     environment:
-      - CONFIG_URL=http://file_url/api.yaml
+      - CONFIG_URL=http://file_url/api.yaml # If the local configuration file is already mounted, there is no need to set CONFIG_URL
     ports:
       - 8001:8000
     volumes:
-      - ./api.yaml:/home/api.yaml
+      - ./api.yaml:/home/api.yaml # If CONFIG_URL is already set, there is no need to mount the configuration file
+      - ./stats.db:/home/stats.db # If you do not want to save statistical data, there is no need to mount the stats.db file
 ```
 
-CONFIG_URL is a link that can automatically download a remote configuration file. For example, if you find it inconvenient to modify the configuration file on a certain platform, you can upload the configuration file to a hosting service that provides a direct link for uni-api to download. CONFIG_URL is this direct link.
+CONFIG_URL is used to automatically download remote configuration files. For example, if it is inconvenient to modify the configuration file on a certain platform, you can upload the configuration file to a hosting service and provide a direct link for uni-api to download. CONFIG_URL is this direct link. If you are using a locally mounted configuration file, you do not need to set CONFIG_URL. CONFIG_URL is used in situations where it is inconvenient to mount the configuration file.
 
 Run Docker Compose container in the background
 
@@ -178,6 +181,7 @@ docker rm -f uni-api
 docker run --user root -p 8001:8000 -dit --name uni-api \
 -e CONFIG_URL=http://file_url/api.yaml \
 -v ./api.yaml:/home/api.yaml \
+-v ./stats.db:/home/stats.db \
 yym68686/uni-api:latest
 docker logs -f uni-api
 ```
README_CN.md CHANGED
@@ -133,7 +133,9 @@ Start the container
 
 ```bash
 docker run --user root -p 8001:8000 --name uni-api -dit \
--v ./api.yaml:/home/api.yaml \
+-e CONFIG_URL=http://file_url/api.yaml \ # 如果已经挂载了本地配置文件,不需要设置 CONFIG_URL
+-v ./api.yaml:/home/api.yaml \ # 如果已经设置 CONFIG_URL,不需要挂载配置文件
+-v ./stats.db:/home/stats.db \ # 如果不想保存统计数据,不需要挂载 stats.db 文件
 yym68686/uni-api:latest
 ```
 
@@ -145,14 +147,15 @@ services:
     container_name: uni-api
     image: yym68686/uni-api:latest
    environment:
-      - CONFIG_URL=http://file_url/api.yaml
+      - CONFIG_URL=http://file_url/api.yaml # 如果已经挂载了本地配置文件,不需要设置 CONFIG_URL
     ports:
       - 8001:8000
     volumes:
-      - ./api.yaml:/home/api.yaml
+      - ./api.yaml:/home/api.yaml # 如果已经设置 CONFIG_URL,不需要挂载配置文件
+      - ./stats.db:/home/stats.db # 如果不想保存统计数据,不需要挂载 stats.db 文件
 ```
 
-CONFIG_URL 就是可以自动下载远程的配置文件。比如你在某个平台不方便修改配置文件,可以把配置文件传到某个托管服务,可以提供直链给 uni-api 下载,CONFIG_URL 就是这个直链。
+CONFIG_URL 就是可以自动下载远程的配置文件。比如你在某个平台不方便修改配置文件,可以把配置文件传到某个托管服务,可以提供直链给 uni-api 下载,CONFIG_URL 就是这个直链。如果使用本地挂载的配置文件,不需要设置 CONFIG_URL。CONFIG_URL 是在不方便挂载配置文件的情况下使用。
 
 Run Docker Compose container in the background
 
@@ -178,6 +181,7 @@ docker rm -f uni-api
 docker run --user root -p 8001:8000 -dit --name uni-api \
 -e CONFIG_URL=http://file_url/api.yaml \
 -v ./api.yaml:/home/api.yaml \
+-v ./stats.db:/home/stats.db \
 yym68686/uni-api:latest
 docker logs -f uni-api
 ```
docker-compose.yml CHANGED
@@ -7,4 +7,5 @@ services:
     ports:
       - 8001:8000
     volumes:
-      - ./api.yaml:/home/api.yaml
+      - ./api.yaml:/home/api.yaml
+      - ./stats.db:/home/stats.db
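
Note on the new bind mount: Docker creates a directory when the source path of a `-v` mount does not exist on the host, so it helps to create an empty `stats.db` file before the first `docker compose up`. A minimal pre-start helper in Python (a hypothetical convenience script, not part of this commit):

```python
# Hypothetical pre-start helper (not part of this commit): ensure ./stats.db
# exists as a regular file before starting the container, otherwise the bind
# mount declared in docker-compose.yml would appear as an empty directory
# inside the container instead of a SQLite database file.
from pathlib import Path

stats_db = Path("stats.db")
if not stats_db.exists():
    stats_db.touch()  # create an empty placeholder file for the bind mount
    print("created empty ./stats.db for the bind mount")
else:
    print("./stats.db already exists")
```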
main.py CHANGED
@@ -22,15 +22,16 @@ from typing import List, Dict, Union
 from urllib.parse import urlparse
 
 import os
-is_debug = os.getenv("DEBUG", False)
+is_debug = bool(os.getenv("DEBUG", False))
+
+async def create_tables():
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # 启动时的代码
-
-    # # 启动事件
-    # routes = [{"path": route.path, "name": route.name} for route in app.routes]
-    # logger.info(f"Registered routes: {routes}")
+    await create_tables()
 
     TIMEOUT = float(os.getenv("TIMEOUT", 100))
     timeout = httpx.Timeout(connect=15.0, read=TIMEOUT, write=30.0, pool=30.0)
@@ -66,10 +67,7 @@ import asyncio
 from time import time
 from collections import defaultdict
 from starlette.middleware.base import BaseHTTPMiddleware
-from datetime import datetime
-from datetime import timedelta
 import json
-import aiofiles
 
 async def parse_request_body(request: Request):
     if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
@@ -79,30 +77,53 @@ async def parse_request_body(request: Request):
             return None
     return None
 
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import declarative_base, sessionmaker
+from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean
+from sqlalchemy.sql import func
+
+# 定义数据库模型
+Base = declarative_base()
+
+class RequestStat(Base):
+    __tablename__ = 'request_stats'
+    id = Column(Integer, primary_key=True)
+    endpoint = Column(String)
+    ip = Column(String)
+    token = Column(String)
+    total_time = Column(Float)
+    model = Column(String)
+    timestamp = Column(DateTime(timezone=True), server_default=func.now())
+
+class ChannelStat(Base):
+    __tablename__ = 'channel_stats'
+    id = Column(Integer, primary_key=True)
+    provider = Column(String)
+    model = Column(String)
+    api_key = Column(String)
+    success = Column(Boolean)
+    first_response_time = Column(Float)  # 新增: 记录首次响应时间
+    timestamp = Column(DateTime(timezone=True), server_default=func.now())
+
+# 创建异步引擎和会话
+engine = create_async_engine('sqlite+aiosqlite:///stats.db', echo=is_debug)
+async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+
 class StatsMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
+    def __init__(self, app):
         super().__init__(app)
-        self.request_counts = defaultdict(int)
-        self.request_times = defaultdict(float)
-        self.ip_counts = defaultdict(lambda: defaultdict(int))
-        self.request_arrivals = defaultdict(list)
-        self.channel_success_counts = defaultdict(int)
-        self.model_counts = defaultdict(int)
-        self.channel_failure_counts = defaultdict(int)
-        self.lock = asyncio.Lock()
-        self.exclude_paths = set(exclude_paths or [])
-        self.save_interval = save_interval
-        self.filename = filename
-        self.last_save_time = time()
-
-        # 启动定期保存和清理任务
-        asyncio.create_task(self.periodic_save_and_cleanup())
+        self.db = async_session()
 
     async def dispatch(self, request: Request, call_next):
-        arrival_time = datetime.now()
+        if request.headers.get("x-api-key"):
+            token = request.headers.get("x-api-key")
+        elif request.headers.get("Authorization"):
+            token = request.headers.get("Authorization").split(" ")[1]
+        else:
+            token = None
+
         start_time = time()
 
-        # 使用依赖注入获取预解析的请求体
         request.state.parsed_body = await parse_request_body(request)
 
         model = "unknown"
@@ -121,86 +142,35 @@ class StatsMiddleware(BaseHTTPMiddleware):
         endpoint = f"{request.method} {request.url.path}"
         client_ip = request.client.host
 
-        if request.url.path not in self.exclude_paths:
-            async with self.lock:
-                self.request_counts[endpoint] += 1
-                self.request_times[endpoint] += process_time
-                self.ip_counts[endpoint][client_ip] += 1
-                self.request_arrivals[endpoint].append(arrival_time)
-                if model != "unknown":
-                    self.model_counts[model] += 1
+        # 异步更新数据库
+        await self.update_stats(endpoint, process_time, client_ip, model, token)
 
         return response
 
-    async def periodic_save_and_cleanup(self):
-        while True:
-            await asyncio.sleep(self.save_interval)
-            await self.save_stats()
-            await self.cleanup_old_data()
-
-    async def save_stats(self):
-        current_time = time()
-        if current_time - self.last_save_time < self.save_interval:
-            return
-
-        async with self.lock:
-            stats = {
-                "request_counts": dict(self.request_counts),
-                "request_times": dict(self.request_times),
-                "model_counts": dict(self.model_counts),
-                "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
-                "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
-                "channel_success_counts": dict(self.channel_success_counts),
-                "channel_failure_counts": dict(self.channel_failure_counts),
-                "channel_success_percentages": self.calculate_success_percentages(),
-                "channel_failure_percentages": self.calculate_failure_percentages()
-            }
-
-        filename = self.filename
-        async with aiofiles.open(filename, mode='w') as f:
-            await f.write(json.dumps(stats, indent=2))
-
-        self.last_save_time = current_time
-
-    def calculate_success_percentages(self):
-        percentages = {}
-        for channel, success_count in self.channel_success_counts.items():
-            total_count = success_count + self.channel_failure_counts[channel]
-            if total_count > 0:
-                percentages[channel] = success_count / total_count * 100
-            else:
-                percentages[channel] = 0
-
-        sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))
-        return sorted_percentages
-
-    def calculate_failure_percentages(self):
-        percentages = {}
-        for channel, failure_count in self.channel_failure_counts.items():
-            total_count = failure_count + self.channel_success_counts[channel]
-            if total_count > 0:
-                percentages[channel] = failure_count / total_count * 100
-            else:
-                percentages[channel] = 0
-
-        sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))
-        return sorted_percentages
-
-    async def cleanup_old_data(self):
-        cutoff_time = datetime.now() - timedelta(hours=24)
-        async with self.lock:
-            for endpoint in list(self.request_arrivals.keys()):
-                self.request_arrivals[endpoint] = [
-                    t for t in self.request_arrivals[endpoint] if t > cutoff_time
-                ]
-                if not self.request_arrivals[endpoint]:
-                    del self.request_arrivals[endpoint]
-                    self.request_counts.pop(endpoint, None)
-                    self.request_times.pop(endpoint, None)
-                    self.ip_counts.pop(endpoint, None)
-
-    async def cleanup(self):
-        await self.save_stats()
+    async def update_stats(self, endpoint, process_time, client_ip, model, token):
+        async with self.db as session:
+            # 为每个请求创建一条新的记录
+            new_request_stat = RequestStat(
+                endpoint=endpoint,
+                ip=client_ip,
+                token=token,
+                total_time=process_time,
+                model=model
+            )
+            session.add(new_request_stat)
+            await session.commit()
+
+    async def update_channel_stats(self, provider, model, api_key, success, first_response_time):
+        async with self.db as session:
+            channel_stat = ChannelStat(
+                provider=provider,
+                model=model,
+                api_key=api_key,
+                success=success,
+                first_response_time=first_response_time
+            )
+            session.add(channel_stat)
+            await session.commit()
 
 # 配置 CORS 中间件
 app.add_middleware(
@@ -211,10 +181,10 @@ app.add_middleware(
     allow_headers=["*"],  # 允许所有头部字段
 )
 
-app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])
+app.add_middleware(StatsMiddleware)
 
 # 在 process_request 函数中更新成功和失败计数
-async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
+async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None, token=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
     # print("parsed_url", parsed_url)
@@ -269,25 +239,23 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
         if request.stream:
             model = provider['model'][request.model]
             generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
-            wrapped_generator = await error_handling_wrapper(generator)
+            wrapped_generator, first_response_time = await error_handling_wrapper(generator)
            response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
         else:
             generator = fetch_response(app.state.client, url, headers, payload)
-            wrapped_generator = await error_handling_wrapper(generator)
+            wrapped_generator, first_response_time = await error_handling_wrapper(generator)
             first_element = await anext(wrapped_generator)
             first_element = first_element.lstrip("data: ")
             first_element = json.loads(first_element)
             response = JSONResponse(first_element)
 
-        # 更新成功计数
-        async with app.middleware_stack.app.lock:
-            app.middleware_stack.app.channel_success_counts[provider['provider']] += 1
+        # 更新成功计数和首次响应时间
+        await app.middleware_stack.app.update_channel_stats(provider['provider'], request.model, token, success=True, first_response_time=first_response_time)
 
         return response
     except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
-        # 更新失败计数
-        async with app.middleware_stack.app.lock:
-            app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1
+        # 更新失败计数,首次响应时间为-1表示失败
+        await app.middleware_stack.app.update_channel_stats(provider['provider'], request.model, token, success=False, first_response_time=-1)
 
         raise e
 
@@ -421,10 +389,10 @@ class ModelRequestHandler:
         if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
             auto_retry = False
 
-        return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
+        return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
 
     # 在 try_all_providers 函数中处理失败的情况
-    async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
+    async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
         status_code = 500
         error_message = None
         num_providers = len(providers)
@@ -433,7 +401,7 @@ class ModelRequestHandler:
             self.last_provider_index = (start_index + i) % num_providers
             provider = providers[self.last_provider_index]
             try:
-                response = await process_request(request, provider, endpoint)
+                response = await process_request(request, provider, endpoint, token)
                 return response
             except HTTPException as e:
                 logger.error(f"Error with provider {provider['provider']}: {str(e)}")
@@ -510,6 +478,7 @@ async def get_user_rate_limit(api_index: str = None):
     return rate_limit
 
 security = HTTPBearer()
+
 async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
     token = credentials.credentials if credentials else None
     api_list = app.state.api_list
@@ -576,24 +545,96 @@ def generate_api_key():
     return JSONResponse(content={"api_key": api_key})
 
 # 在 /stats 路由中返回成功和失败百分比
+from collections import defaultdict
+from sqlalchemy import func
+
+from collections import defaultdict
+from sqlalchemy import func, desc, case
+
 @app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
 async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
-    middleware = app.middleware_stack.app
-    if isinstance(middleware, StatsMiddleware):
-        async with middleware.lock:
-            stats = {
-                "channel_success_percentages": middleware.calculate_success_percentages(),
-                "channel_failure_percentages": middleware.calculate_failure_percentages(),
-                "model_counts": dict(middleware.model_counts),
-                "request_counts": dict(middleware.request_counts),
-                "request_times": dict(middleware.request_times),
-                "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
-                "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()},
-                "channel_success_counts": dict(middleware.channel_success_counts),
-                "channel_failure_counts": dict(middleware.channel_failure_counts),
-            }
-        return JSONResponse(content=stats)
-    return {"error": "StatsMiddleware not found"}
+    async with async_session() as session:
+        # 1. 每个渠道下面每个模型的成功率
+        channel_model_stats = await session.execute(
+            select(
+                ChannelStat.provider,
+                ChannelStat.model,
+                func.count().label('total'),
+                func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
+            ).group_by(ChannelStat.provider, ChannelStat.model)
+        )
+        channel_model_stats = channel_model_stats.fetchall()
+
+        # 2. 每个渠道总的成功率
+        channel_stats = await session.execute(
+            select(
+                ChannelStat.provider,
+                func.count().label('total'),
+                func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
+            ).group_by(ChannelStat.provider)
+        )
+        channel_stats = channel_stats.fetchall()
+
+        # 3. 每个模型在所有渠道总的请求次数
+        model_stats = await session.execute(
+            select(ChannelStat.model, func.count().label('count'))
+            .group_by(ChannelStat.model)
+            .order_by(desc('count'))
+        )
+        model_stats = model_stats.fetchall()
+
+        # 4. 每个端点的请求次数
+        endpoint_stats = await session.execute(
+            select(RequestStat.endpoint, func.count().label('count'))
+            .group_by(RequestStat.endpoint)
+            .order_by(desc('count'))
+        )
+        endpoint_stats = endpoint_stats.fetchall()
+
+        # 5. 每个ip请求的次数
+        ip_stats = await session.execute(
+            select(RequestStat.ip, func.count().label('count'))
+            .group_by(RequestStat.ip)
+            .order_by(desc('count'))
+        )
+        ip_stats = ip_stats.fetchall()
+
+    # 处理统计数据并返回
+    stats = {
+        "channel_model_success_rates": [
+            {
+                "provider": stat.provider,
+                "model": stat.model,
+                "success_rate": stat.success_count / stat.total if stat.total > 0 else 0
+            } for stat in sorted(channel_model_stats, key=lambda x: x.success_count / x.total if x.total > 0 else 0, reverse=True)
+        ],
+        "channel_success_rates": [
+            {
+                "provider": stat.provider,
+                "success_rate": stat.success_count / stat.total if stat.total > 0 else 0
+            } for stat in sorted(channel_stats, key=lambda x: x.success_count / x.total if x.total > 0 else 0, reverse=True)
+        ],
+        "model_request_counts": [
+            {
+                "model": stat.model,
+                "count": stat.count
+            } for stat in model_stats
+        ],
+        "endpoint_request_counts": [
+            {
+                "endpoint": stat.endpoint,
+                "count": stat.count
+            } for stat in endpoint_stats
+        ],
+        "ip_request_counts": [
+            {
+                "ip": stat.ip,
+                "count": stat.count
+            } for stat in ip_stats
+        ]
+    }
+
+    return JSONResponse(content=stats)
 
 # async def on_fetch(request, env):
 #     import asgi
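
Note: the commit replaces the in-memory counters and stats.json snapshots with a SQLite database (stats.db, tables request_stats and channel_stats as defined above). For reference, a sketch of how that database could be inspected offline with only the standard library; the table and column names follow the models in this diff, and the aggregation mirrors the /stats route, but the script itself is illustrative and not part of the commit:

```python
# Illustrative offline inspection of stats.db (not part of this commit).
# Table and column names follow the RequestStat / ChannelStat models above.
import sqlite3

conn = sqlite3.connect("stats.db")
conn.row_factory = sqlite3.Row

# Per-provider success rate and average time to first response, similar to /stats.
rows = conn.execute(
    """
    SELECT provider,
           COUNT(*)                                            AS total,
           SUM(CASE WHEN success THEN 1 ELSE 0 END)            AS success_count,
           AVG(CASE WHEN success THEN first_response_time END) AS avg_first_response_time
    FROM channel_stats
    GROUP BY provider
    ORDER BY success_count * 1.0 / total DESC
    """
).fetchall()

for row in rows:
    rate = row["success_count"] / row["total"] if row["total"] else 0
    print(f"{row['provider']}: success_rate={rate:.2%}, "
          f"avg_first_response_time={row['avg_first_response_time']}")

conn.close()
```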
requirements.txt CHANGED
@@ -3,6 +3,9 @@ pytest
 uvicorn
 fastapi
 aiofiles
+greenlet
+aiosqlite
+sqlalchemy
 watchfiles
 httpx[http2]
 cryptography
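
Note: the three new dependencies back the async SQLite layer in main.py: sqlalchemy provides the ORM, aiosqlite is the async driver behind the sqlite+aiosqlite:// URL, and greenlet is required by SQLAlchemy's asyncio integration. A quick smoke test of the stack (illustrative only, not part of the commit):

```python
# Illustrative smoke test for the new dependencies (not part of this commit):
# opens the same sqlite+aiosqlite URL that main.py uses and runs a trivial query.
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def main():
    engine = create_async_engine("sqlite+aiosqlite:///stats.db")
    async with engine.connect() as conn:
        version = (await conn.execute(text("SELECT sqlite_version()"))).scalar()
        print("sqlite version:", version)
    await engine.dispose()

asyncio.run(main())
```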
test/provider_test.py CHANGED
@@ -70,7 +70,8 @@ def test_request_model(test_client, api_key, get_model):
                 }
             }
         }
-    ]
+    ],
+    "tool_choice": "auto"
 }
 
 headers = {
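
Note: the test payload now sends "tool_choice": "auto" alongside the tools array, matching the OpenAI-style chat request format. A minimal sketch of such a payload (the tool definition and endpoint path are illustrative assumptions, not copied from the test):

```python
# Illustrative OpenAI-style payload with tools and tool_choice (the tool
# definition and target endpoint are assumptions, not copied from the test).
import json

payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    "tool_choice": "auto",  # the field added by this commit's test change
}

print(json.dumps(payload, indent=2))  # e.g. POST to /v1/chat/completions
```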
utils.py CHANGED
@@ -116,9 +116,12 @@ def ensure_string(item):
     return str(item)
 
 import asyncio
+import time as time_module
 async def error_handling_wrapper(generator):
+    start_time = time_module.time()
     try:
         first_item = await generator.__anext__()
+        first_response_time = time_module.time() - start_time
         first_item_str = first_item
         # logger.info("first_item_str: %s", first_item_str)
         if isinstance(first_item_str, (bytes, bytearray)):
@@ -153,7 +156,7 @@ async def error_handling_wrapper(generator):
                 logger.error(f"Network error in new_generator: {e}")
                 raise
 
-        return new_generator()
+        return new_generator(), first_response_time
 
     except StopAsyncIteration:
         raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
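
Note: error_handling_wrapper now returns a (generator, first_response_time) tuple; this is how the "first character time" in the commit title is measured, from wrapper entry until the upstream generator yields its first item. A standalone sketch of the same timing pattern with a dummy generator (simplified, illustrative only, without the error handling of the real wrapper):

```python
# Simplified, standalone illustration of the timing pattern used in
# error_handling_wrapper: measure time to the first yielded chunk, then
# re-emit it so the consumer still sees the full stream.
import asyncio
import time as time_module

async def slow_stream():
    await asyncio.sleep(0.2)  # simulated upstream latency before the first chunk
    yield "data: first chunk"
    yield "data: second chunk"

async def timing_wrapper(generator):
    start_time = time_module.time()
    first_item = await generator.__anext__()           # wait for the first chunk
    first_response_time = time_module.time() - start_time

    async def new_generator():
        yield first_item                               # re-emit the consumed chunk
        async for item in generator:
            yield item

    return new_generator(), first_response_time

async def main():
    wrapped, first_response_time = await timing_wrapper(slow_stream())
    print(f"time to first chunk: {first_response_time:.3f}s")
    async for chunk in wrapped:
        print(chunk)

asyncio.run(main())
```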