✨ Feature: Add a database and record the time to first response.
Browse files- .dockerignore +5 -1
- .gitignore +2 -1
- README.md +8 -4
- README_CN.md +8 -4
- docker-compose.yml +2 -1
- main.py +173 -132
- requirements.txt +3 -0
- test/provider_test.py +2 -1
- utils.py +4 -1
.dockerignore
CHANGED
@@ -1,3 +1,7 @@
|
|
1 |
api.yaml
|
2 |
test
|
3 |
-
json_str
|
|
|
|
|
|
|
|
|
|
1 |
api.yaml
|
2 |
test
|
3 |
+
json_str
|
4 |
+
*.jpg
|
5 |
+
*.json
|
6 |
+
*.png
|
7 |
+
*.db
|
.gitignore
CHANGED
@@ -8,4 +8,5 @@ node_modules
|
|
8 |
.pytest_cache
|
9 |
*.jpg
|
10 |
*.json
|
11 |
-
*.png
|
|
|
|
8 |
.pytest_cache
|
9 |
*.jpg
|
10 |
*.json
|
11 |
+
*.png
|
12 |
+
*.db
|
README.md
CHANGED
@@ -133,7 +133,9 @@ Start the container
|
|
133 |
|
134 |
```bash
|
135 |
docker run --user root -p 8001:8000 --name uni-api -dit \
|
136 |
-
-
|
|
|
|
|
137 |
yym68686/uni-api:latest
|
138 |
```
|
139 |
|
@@ -145,14 +147,15 @@ services:
|
|
145 |
container_name: uni-api
|
146 |
image: yym68686/uni-api:latest
|
147 |
environment:
|
148 |
-
- CONFIG_URL=http://file_url/api.yaml
|
149 |
ports:
|
150 |
- 8001:8000
|
151 |
volumes:
|
152 |
-
- ./api.yaml:/home/api.yaml
|
|
|
153 |
```
|
154 |
|
155 |
-
CONFIG_URL is
|
156 |
|
157 |
Run Docker Compose container in the background
|
158 |
|
@@ -178,6 +181,7 @@ docker rm -f uni-api
|
|
178 |
docker run --user root -p 8001:8000 -dit --name uni-api \
|
179 |
-e CONFIG_URL=http://file_url/api.yaml \
|
180 |
-v ./api.yaml:/home/api.yaml \
|
|
|
181 |
yym68686/uni-api:latest
|
182 |
docker logs -f uni-api
|
183 |
```
|
|
|
133 |
|
134 |
```bash
|
135 |
docker run --user root -p 8001:8000 --name uni-api -dit \
|
136 |
+
-e CONFIG_URL=http://file_url/api.yaml \ # If the local configuration file is already mounted, you do not need to set CONFIG_URL
|
137 |
+
-v ./api.yaml:/home/api.yaml \ # If CONFIG_URL is already set, you do not need to mount the configuration file
|
138 |
+
-v ./stats.db:/home/stats.db \ # If you do not want to save statistical data, you do not need to mount the stats.db file
|
139 |
yym68686/uni-api:latest
|
140 |
```
|
141 |
|
|
|
147 |
container_name: uni-api
|
148 |
image: yym68686/uni-api:latest
|
149 |
environment:
|
150 |
+
- CONFIG_URL=http://file_url/api.yaml # If the local configuration file is already mounted, there is no need to set CONFIG_URL
|
151 |
ports:
|
152 |
- 8001:8000
|
153 |
volumes:
|
154 |
+
- ./api.yaml:/home/api.yaml # If CONFIG_URL is already set, there is no need to mount the configuration file
|
155 |
+
- ./stats.db:/home/stats.db # If you do not want to save statistical data, there is no need to mount the stats.db file
|
156 |
```
|
157 |
|
158 |
+
CONFIG_URL lets uni-api download a remote configuration file automatically. For example, if editing the configuration file is inconvenient on a given platform, you can upload the file to a hosting service and give uni-api a direct link to download it; CONFIG_URL is that direct link. If you mount a local configuration file, you do not need to set CONFIG_URL: it is intended for situations where mounting the configuration file is inconvenient.
|
159 |
|
160 |
Run Docker Compose container in the background
|
161 |
|
|
|
181 |
docker run --user root -p 8001:8000 -dit --name uni-api \
|
182 |
-e CONFIG_URL=http://file_url/api.yaml \
|
183 |
-v ./api.yaml:/home/api.yaml \
|
184 |
+
-v ./stats.db:/home/stats.db \
|
185 |
yym68686/uni-api:latest
|
186 |
docker logs -f uni-api
|
187 |
```
|
README_CN.md
CHANGED
@@ -133,7 +133,9 @@ Start the container
|
|
133 |
|
134 |
```bash
|
135 |
docker run --user root -p 8001:8000 --name uni-api -dit \
|
136 |
-
-
|
|
|
|
|
137 |
yym68686/uni-api:latest
|
138 |
```
|
139 |
|
@@ -145,14 +147,15 @@ services:
|
|
145 |
container_name: uni-api
|
146 |
image: yym68686/uni-api:latest
|
147 |
environment:
|
148 |
-
- CONFIG_URL=http://file_url/api.yaml
|
149 |
ports:
|
150 |
- 8001:8000
|
151 |
volumes:
|
152 |
-
- ./api.yaml:/home/api.yaml
|
|
|
153 |
```
|
154 |
|
155 |
-
CONFIG_URL 就是可以自动下载远程的配置文件。比如你在某个平台不方便修改配置文件,可以把配置文件传到某个托管服务,可以提供直链给 uni-api 下载,CONFIG_URL
|
156 |
|
157 |
Run Docker Compose container in the background
|
158 |
|
@@ -178,6 +181,7 @@ docker rm -f uni-api
|
|
178 |
docker run --user root -p 8001:8000 -dit --name uni-api \
|
179 |
-e CONFIG_URL=http://file_url/api.yaml \
|
180 |
-v ./api.yaml:/home/api.yaml \
|
|
|
181 |
yym68686/uni-api:latest
|
182 |
docker logs -f uni-api
|
183 |
```
|
|
|
133 |
|
134 |
```bash
|
135 |
docker run --user root -p 8001:8000 --name uni-api -dit \
|
136 |
+
-e CONFIG_URL=http://file_url/api.yaml \ # 如果已经挂载了本地配置文件,不需要设置 CONFIG_URL
|
137 |
+
-v ./api.yaml:/home/api.yaml \ # 如果已经设置 CONFIG_URL,不需要挂载配置文件
|
138 |
+
-v ./stats.db:/home/stats.db \ # 如果不想保存统计数据,不需要挂载 stats.db 文件
|
139 |
yym68686/uni-api:latest
|
140 |
```
|
141 |
|
|
|
147 |
container_name: uni-api
|
148 |
image: yym68686/uni-api:latest
|
149 |
environment:
|
150 |
+
- CONFIG_URL=http://file_url/api.yaml # 如果已经挂载了本地配置文件,不需要设置 CONFIG_URL
|
151 |
ports:
|
152 |
- 8001:8000
|
153 |
volumes:
|
154 |
+
- ./api.yaml:/home/api.yaml # 如果已经设置 CONFIG_URL,不需要挂载配置文件
|
155 |
+
- ./stats.db:/home/stats.db # 如果不想保存统计数据,不需要挂载 stats.db 文件
|
156 |
```
|
157 |
|
158 |
+
CONFIG_URL 就是可以自动下载远程的配置文件。比如你在某个平台不方便修改配置文件,可以把配置文件传到某个托管服务,可以提供直链给 uni-api 下载,CONFIG_URL 就是这个直链。如果使用本地挂载的配置文件,不需要设置 CONFIG_URL。CONFIG_URL 是在不方便挂载配置文件的情况下使用。
|
159 |
|
160 |
Run Docker Compose container in the background
|
161 |
|
|
|
181 |
docker run --user root -p 8001:8000 -dit --name uni-api \
|
182 |
-e CONFIG_URL=http://file_url/api.yaml \
|
183 |
-v ./api.yaml:/home/api.yaml \
|
184 |
+
-v ./stats.db:/home/stats.db \
|
185 |
yym68686/uni-api:latest
|
186 |
docker logs -f uni-api
|
187 |
```
|
docker-compose.yml
CHANGED
@@ -7,4 +7,5 @@ services:
|
|
7 |
ports:
|
8 |
- 8001:8000
|
9 |
volumes:
|
10 |
-
- ./api.yaml:/home/api.yaml
|
|
|
|
7 |
ports:
|
8 |
- 8001:8000
|
9 |
volumes:
|
10 |
+
- ./api.yaml:/home/api.yaml
|
11 |
+
- ./stats.db:/home/stats.db
|
main.py
CHANGED
@@ -22,15 +22,16 @@ from typing import List, Dict, Union
|
|
22 |
from urllib.parse import urlparse
|
23 |
|
24 |
import os
|
25 |
-
is_debug = os.getenv("DEBUG", False)
|
|
|
|
|
|
|
|
|
26 |
|
27 |
@asynccontextmanager
|
28 |
async def lifespan(app: FastAPI):
|
29 |
# 启动时的代码
|
30 |
-
|
31 |
-
# # 启动事件
|
32 |
-
# routes = [{"path": route.path, "name": route.name} for route in app.routes]
|
33 |
-
# logger.info(f"Registered routes: {routes}")
|
34 |
|
35 |
TIMEOUT = float(os.getenv("TIMEOUT", 100))
|
36 |
timeout = httpx.Timeout(connect=15.0, read=TIMEOUT, write=30.0, pool=30.0)
|
@@ -66,10 +67,7 @@ import asyncio
|
|
66 |
from time import time
|
67 |
from collections import defaultdict
|
68 |
from starlette.middleware.base import BaseHTTPMiddleware
|
69 |
-
from datetime import datetime
|
70 |
-
from datetime import timedelta
|
71 |
import json
|
72 |
-
import aiofiles
|
73 |
|
74 |
async def parse_request_body(request: Request):
|
75 |
if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
|
@@ -79,30 +77,53 @@ async def parse_request_body(request: Request):
|
|
79 |
return None
|
80 |
return None
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
class StatsMiddleware(BaseHTTPMiddleware):
|
83 |
-
def __init__(self, app
|
84 |
super().__init__(app)
|
85 |
-
self.
|
86 |
-
self.request_times = defaultdict(float)
|
87 |
-
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
88 |
-
self.request_arrivals = defaultdict(list)
|
89 |
-
self.channel_success_counts = defaultdict(int)
|
90 |
-
self.model_counts = defaultdict(int)
|
91 |
-
self.channel_failure_counts = defaultdict(int)
|
92 |
-
self.lock = asyncio.Lock()
|
93 |
-
self.exclude_paths = set(exclude_paths or [])
|
94 |
-
self.save_interval = save_interval
|
95 |
-
self.filename = filename
|
96 |
-
self.last_save_time = time()
|
97 |
-
|
98 |
-
# 启动定期保存和清理任务
|
99 |
-
asyncio.create_task(self.periodic_save_and_cleanup())
|
100 |
|
101 |
async def dispatch(self, request: Request, call_next):
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
start_time = time()
|
104 |
|
105 |
-
# 使用依赖注入获取预解析的请求体
|
106 |
request.state.parsed_body = await parse_request_body(request)
|
107 |
|
108 |
model = "unknown"
|
@@ -121,86 +142,35 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
121 |
endpoint = f"{request.method} {request.url.path}"
|
122 |
client_ip = request.client.host
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
self.request_counts[endpoint] += 1
|
127 |
-
self.request_times[endpoint] += process_time
|
128 |
-
self.ip_counts[endpoint][client_ip] += 1
|
129 |
-
self.request_arrivals[endpoint].append(arrival_time)
|
130 |
-
if model != "unknown":
|
131 |
-
self.model_counts[model] += 1
|
132 |
|
133 |
return response
|
134 |
|
135 |
-
async def
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
filename = self.filename
|
160 |
-
async with aiofiles.open(filename, mode='w') as f:
|
161 |
-
await f.write(json.dumps(stats, indent=2))
|
162 |
-
|
163 |
-
self.last_save_time = current_time
|
164 |
-
|
165 |
-
def calculate_success_percentages(self):
|
166 |
-
percentages = {}
|
167 |
-
for channel, success_count in self.channel_success_counts.items():
|
168 |
-
total_count = success_count + self.channel_failure_counts[channel]
|
169 |
-
if total_count > 0:
|
170 |
-
percentages[channel] = success_count / total_count * 100
|
171 |
-
else:
|
172 |
-
percentages[channel] = 0
|
173 |
-
|
174 |
-
sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))
|
175 |
-
return sorted_percentages
|
176 |
-
|
177 |
-
def calculate_failure_percentages(self):
|
178 |
-
percentages = {}
|
179 |
-
for channel, failure_count in self.channel_failure_counts.items():
|
180 |
-
total_count = failure_count + self.channel_success_counts[channel]
|
181 |
-
if total_count > 0:
|
182 |
-
percentages[channel] = failure_count / total_count * 100
|
183 |
-
else:
|
184 |
-
percentages[channel] = 0
|
185 |
-
|
186 |
-
sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))
|
187 |
-
return sorted_percentages
|
188 |
-
|
189 |
-
async def cleanup_old_data(self):
|
190 |
-
cutoff_time = datetime.now() - timedelta(hours=24)
|
191 |
-
async with self.lock:
|
192 |
-
for endpoint in list(self.request_arrivals.keys()):
|
193 |
-
self.request_arrivals[endpoint] = [
|
194 |
-
t for t in self.request_arrivals[endpoint] if t > cutoff_time
|
195 |
-
]
|
196 |
-
if not self.request_arrivals[endpoint]:
|
197 |
-
del self.request_arrivals[endpoint]
|
198 |
-
self.request_counts.pop(endpoint, None)
|
199 |
-
self.request_times.pop(endpoint, None)
|
200 |
-
self.ip_counts.pop(endpoint, None)
|
201 |
-
|
202 |
-
async def cleanup(self):
|
203 |
-
await self.save_stats()
|
204 |
|
205 |
# 配置 CORS 中间件
|
206 |
app.add_middleware(
|
@@ -211,10 +181,10 @@ app.add_middleware(
|
|
211 |
allow_headers=["*"], # 允许所有头部字段
|
212 |
)
|
213 |
|
214 |
-
app.add_middleware(StatsMiddleware
|
215 |
|
216 |
# 在 process_request 函数中更新成功和失败计数
|
217 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
|
218 |
url = provider['base_url']
|
219 |
parsed_url = urlparse(url)
|
220 |
# print("parsed_url", parsed_url)
|
@@ -269,25 +239,23 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
|
|
269 |
if request.stream:
|
270 |
model = provider['model'][request.model]
|
271 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
272 |
-
wrapped_generator = await error_handling_wrapper(generator)
|
273 |
response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
|
274 |
else:
|
275 |
generator = fetch_response(app.state.client, url, headers, payload)
|
276 |
-
wrapped_generator = await error_handling_wrapper(generator)
|
277 |
first_element = await anext(wrapped_generator)
|
278 |
first_element = first_element.lstrip("data: ")
|
279 |
first_element = json.loads(first_element)
|
280 |
response = JSONResponse(first_element)
|
281 |
|
282 |
-
#
|
283 |
-
|
284 |
-
app.middleware_stack.app.channel_success_counts[provider['provider']] += 1
|
285 |
|
286 |
return response
|
287 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
288 |
-
#
|
289 |
-
|
290 |
-
app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1
|
291 |
|
292 |
raise e
|
293 |
|
@@ -421,10 +389,10 @@ class ModelRequestHandler:
|
|
421 |
if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
|
422 |
auto_retry = False
|
423 |
|
424 |
-
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
|
425 |
|
426 |
# 在 try_all_providers 函数中处理失败的情况
|
427 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
|
428 |
status_code = 500
|
429 |
error_message = None
|
430 |
num_providers = len(providers)
|
@@ -433,7 +401,7 @@ class ModelRequestHandler:
|
|
433 |
self.last_provider_index = (start_index + i) % num_providers
|
434 |
provider = providers[self.last_provider_index]
|
435 |
try:
|
436 |
-
response = await process_request(request, provider, endpoint)
|
437 |
return response
|
438 |
except HTTPException as e:
|
439 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
@@ -510,6 +478,7 @@ async def get_user_rate_limit(api_index: str = None):
|
|
510 |
return rate_limit
|
511 |
|
512 |
security = HTTPBearer()
|
|
|
513 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
514 |
token = credentials.credentials if credentials else None
|
515 |
api_list = app.state.api_list
|
@@ -576,24 +545,96 @@ def generate_api_key():
|
|
576 |
return JSONResponse(content={"api_key": api_key})
|
577 |
|
578 |
# 在 /stats 路由中返回成功和失败百分比
|
|
|
|
|
|
|
|
|
|
|
|
|
579 |
@app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
|
580 |
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
|
598 |
# async def on_fetch(request, env):
|
599 |
# import asgi
|
|
|
22 |
from urllib.parse import urlparse
|
23 |
|
24 |
import os
|
25 |
+
# Debug flag read from the DEBUG environment variable.
# bool() on the raw string was True for every non-empty value, so DEBUG=false
# or DEBUG=0 switched debugging on. Explicit "off" spellings are now treated
# as False; any other non-empty value still enables debugging.
is_debug = os.getenv("DEBUG", "").strip().lower() not in {"", "0", "false", "no", "off"}
|
26 |
+
|
27 |
+
async def create_tables():
    """Create every table registered on ``Base.metadata`` via the module-level engine."""
    async with engine.begin() as connection:
        # create_all is a synchronous call, so it is dispatched through run_sync.
        await connection.run_sync(Base.metadata.create_all)
|
30 |
|
31 |
@asynccontextmanager
|
32 |
async def lifespan(app: FastAPI):
|
33 |
# 启动时的代码
|
34 |
+
await create_tables()
|
|
|
|
|
|
|
35 |
|
36 |
TIMEOUT = float(os.getenv("TIMEOUT", 100))
|
37 |
timeout = httpx.Timeout(connect=15.0, read=TIMEOUT, write=30.0, pool=30.0)
|
|
|
67 |
from time import time
|
68 |
from collections import defaultdict
|
69 |
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
70 |
import json
|
|
|
71 |
|
72 |
async def parse_request_body(request: Request):
|
73 |
if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
|
|
|
77 |
return None
|
78 |
return None
|
79 |
|
80 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
81 |
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
82 |
+
from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean
|
83 |
+
from sqlalchemy.sql import func
|
84 |
+
|
85 |
+
# 定义数据库模型
|
86 |
+
Base = declarative_base()
|
87 |
+
|
88 |
+
class RequestStat(Base):
    """ORM model for the ``request_stats`` table; one row is written per request."""

    __tablename__ = 'request_stats'

    # Surrogate primary key.
    id = Column(Integer, primary_key=True)

    # Endpoint label supplied by the stats middleware.
    endpoint = Column(String)

    # Client address of the caller.
    ip = Column(String)

    # Credential sent with the request, if any.
    token = Column(String)

    # Processing time reported by the middleware for this request.
    total_time = Column(Float)

    # Model name associated with the request.
    model = Column(String)

    # Row creation time, filled in by the database.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
97 |
+
|
98 |
+
class ChannelStat(Base):
    """ORM model for the ``channel_stats`` table; one row per provider call outcome."""

    __tablename__ = 'channel_stats'

    # Surrogate primary key.
    id = Column(Integer, primary_key=True)

    # Name of the upstream provider (channel) that was called.
    provider = Column(String)

    # Model name the call was made for.
    model = Column(String)

    # Key value recorded with the call.
    api_key = Column(String)

    # Whether the provider call succeeded.
    success = Column(Boolean)

    # Newly added: time until the first response item arrived.
    first_response_time = Column(Float)

    # Row creation time, filled in by the database.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
107 |
+
|
108 |
+
# 创建异步引擎和会话
|
109 |
+
# Async engine backed by the SQLite file stats.db; SQL echo follows the debug flag.
engine = create_async_engine(
    'sqlite+aiosqlite:///stats.db',
    echo=is_debug,
)

# Factory producing AsyncSession objects bound to the engine above.
async_session = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
111 |
+
|
112 |
class StatsMiddleware(BaseHTTPMiddleware):
|
113 |
+
def __init__(self, app):
    """Wrap *app* and keep a database session on ``self.db``."""
    super().__init__(app)
    # NOTE(review): one session object is created here and kept for the
    # middleware's lifetime - confirm that sharing it across requests is intended.
    self.db = async_session()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
async def dispatch(self, request: Request, call_next):
|
118 |
+
if request.headers.get("x-api-key"):
|
119 |
+
token = request.headers.get("x-api-key")
|
120 |
+
elif request.headers.get("Authorization"):
|
121 |
+
token = request.headers.get("Authorization").split(" ")[1]
|
122 |
+
else:
|
123 |
+
token = None
|
124 |
+
|
125 |
start_time = time()
|
126 |
|
|
|
127 |
request.state.parsed_body = await parse_request_body(request)
|
128 |
|
129 |
model = "unknown"
|
|
|
142 |
endpoint = f"{request.method} {request.url.path}"
|
143 |
client_ip = request.client.host
|
144 |
|
145 |
+
# 异步更新数据库
|
146 |
+
await self.update_stats(endpoint, process_time, client_ip, model, token)
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
return response
|
149 |
|
150 |
+
async def update_stats(self, endpoint, process_time, client_ip, model, token):
    """Insert one ``RequestStat`` row describing a handled request.

    Args:
        endpoint: Endpoint label for the request.
        process_time: Processing time to store in ``total_time``.
        client_ip: Address of the caller.
        model: Model name associated with the request.
        token: Credential sent with the request, or None.
    """
    # Open a fresh session per call instead of reusing self.db: a single
    # AsyncSession must not be used by several concurrently running requests.
    async with async_session() as session:
        # Every request gets its own new record.
        session.add(
            RequestStat(
                endpoint=endpoint,
                ip=client_ip,
                token=token,
                total_time=process_time,
                model=model,
            )
        )
        await session.commit()
|
162 |
+
|
163 |
+
async def update_channel_stats(self, provider, model, api_key, success, first_response_time):
    """Insert one ``ChannelStat`` row describing the outcome of a provider call.

    Args:
        provider: Name of the upstream provider (channel).
        model: Model name the call was made for.
        api_key: Key value to record with the call.
        success: True when the provider call succeeded.
        first_response_time: Time until the first response item arrived.
    """
    # Open a fresh session per call instead of reusing self.db: a single
    # AsyncSession must not be used by several concurrently running requests.
    async with async_session() as session:
        session.add(
            ChannelStat(
                provider=provider,
                model=model,
                api_key=api_key,
                success=success,
                first_response_time=first_response_time,
            )
        )
        await session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
# 配置 CORS 中间件
|
176 |
app.add_middleware(
|
|
|
181 |
allow_headers=["*"], # 允许所有头部字段
|
182 |
)
|
183 |
|
184 |
+
app.add_middleware(StatsMiddleware)
|
185 |
|
186 |
# 在 process_request 函数中更新成功和失败计数
|
187 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None, token=None):
|
188 |
url = provider['base_url']
|
189 |
parsed_url = urlparse(url)
|
190 |
# print("parsed_url", parsed_url)
|
|
|
239 |
if request.stream:
|
240 |
model = provider['model'][request.model]
|
241 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
242 |
+
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
243 |
response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
|
244 |
else:
|
245 |
generator = fetch_response(app.state.client, url, headers, payload)
|
246 |
+
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
247 |
first_element = await anext(wrapped_generator)
|
248 |
first_element = first_element.lstrip("data: ")
|
249 |
first_element = json.loads(first_element)
|
250 |
response = JSONResponse(first_element)
|
251 |
|
252 |
+
# 更新成功计数和首次响应时间
|
253 |
+
await app.middleware_stack.app.update_channel_stats(provider['provider'], request.model, token, success=True, first_response_time=first_response_time)
|
|
|
254 |
|
255 |
return response
|
256 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
257 |
+
# 更新失败计数,首次响应时间为-1表示失败
|
258 |
+
await app.middleware_stack.app.update_channel_stats(provider['provider'], request.model, token, success=False, first_response_time=-1)
|
|
|
259 |
|
260 |
raise e
|
261 |
|
|
|
389 |
if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
|
390 |
auto_retry = False
|
391 |
|
392 |
+
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
393 |
|
394 |
# 在 try_all_providers 函数中处理失败的情况
|
395 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
396 |
status_code = 500
|
397 |
error_message = None
|
398 |
num_providers = len(providers)
|
|
|
401 |
self.last_provider_index = (start_index + i) % num_providers
|
402 |
provider = providers[self.last_provider_index]
|
403 |
try:
|
404 |
+
response = await process_request(request, provider, endpoint, token)
|
405 |
return response
|
406 |
except HTTPException as e:
|
407 |
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
|
|
|
478 |
return rate_limit
|
479 |
|
480 |
security = HTTPBearer()
|
481 |
+
|
482 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
483 |
token = credentials.credentials if credentials else None
|
484 |
api_list = app.state.api_list
|
|
|
545 |
return JSONResponse(content={"api_key": api_key})
|
546 |
|
547 |
# 在 /stats 路由中返回成功和失败百分比
|
548 |
+
from collections import defaultdict
|
549 |
+
from sqlalchemy import func
|
550 |
+
|
551 |
+
from collections import defaultdict
|
552 |
+
from sqlalchemy import func, desc, case
|
553 |
+
|
554 |
@app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
    """Return aggregated usage statistics read from the stats database.

    The JSON payload holds five lists: success rates per provider/model pair,
    success rates per provider, request counts per model, request counts per
    endpoint and request counts per IP.
    """

    def success_ratio(row):
        # Share of successful calls in one aggregated row; 0 when the row is empty.
        if row.total > 0:
            return row.success_count / row.total
        return 0

    async with async_session() as session:
        # 1. Success rate of each model under each channel.
        channel_model_query = select(
            ChannelStat.provider,
            ChannelStat.model,
            func.count().label('total'),
            func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
        ).group_by(ChannelStat.provider, ChannelStat.model)
        channel_model_rows = (await session.execute(channel_model_query)).fetchall()

        # 2. Overall success rate of each channel.
        channel_query = select(
            ChannelStat.provider,
            func.count().label('total'),
            func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
        ).group_by(ChannelStat.provider)
        channel_rows = (await session.execute(channel_query)).fetchall()

        # 3. Total number of requests per model across all channels.
        model_query = (
            select(ChannelStat.model, func.count().label('count'))
            .group_by(ChannelStat.model)
            .order_by(desc('count'))
        )
        model_rows = (await session.execute(model_query)).fetchall()

        # 4. Number of requests per endpoint.
        endpoint_query = (
            select(RequestStat.endpoint, func.count().label('count'))
            .group_by(RequestStat.endpoint)
            .order_by(desc('count'))
        )
        endpoint_rows = (await session.execute(endpoint_query)).fetchall()

        # 5. Number of requests per IP.
        ip_query = (
            select(RequestStat.ip, func.count().label('count'))
            .group_by(RequestStat.ip)
            .order_by(desc('count'))
        )
        ip_rows = (await session.execute(ip_query)).fetchall()

    # Shape the fetched rows into the response payload; rate lists are ordered best first.
    channel_model_success_rates = [
        {
            "provider": row.provider,
            "model": row.model,
            "success_rate": success_ratio(row),
        }
        for row in sorted(channel_model_rows, key=success_ratio, reverse=True)
    ]
    channel_success_rates = [
        {
            "provider": row.provider,
            "success_rate": success_ratio(row),
        }
        for row in sorted(channel_rows, key=success_ratio, reverse=True)
    ]
    model_request_counts = [
        {"model": row.model, "count": row.count} for row in model_rows
    ]
    endpoint_request_counts = [
        {"endpoint": row.endpoint, "count": row.count} for row in endpoint_rows
    ]
    ip_request_counts = [
        {"ip": row.ip, "count": row.count} for row in ip_rows
    ]

    stats = {
        "channel_model_success_rates": channel_model_success_rates,
        "channel_success_rates": channel_success_rates,
        "model_request_counts": model_request_counts,
        "endpoint_request_counts": endpoint_request_counts,
        "ip_request_counts": ip_request_counts,
    }

    return JSONResponse(content=stats)
|
638 |
|
639 |
# async def on_fetch(request, env):
|
640 |
# import asgi
|
requirements.txt
CHANGED
@@ -3,6 +3,9 @@ pytest
|
|
3 |
uvicorn
|
4 |
fastapi
|
5 |
aiofiles
|
|
|
|
|
|
|
6 |
watchfiles
|
7 |
httpx[http2]
|
8 |
cryptography
|
|
|
3 |
uvicorn
|
4 |
fastapi
|
5 |
aiofiles
|
6 |
+
greenlet
|
7 |
+
aiosqlite
|
8 |
+
sqlalchemy
|
9 |
watchfiles
|
10 |
httpx[http2]
|
11 |
cryptography
|
test/provider_test.py
CHANGED
@@ -70,7 +70,8 @@ def test_request_model(test_client, api_key, get_model):
|
|
70 |
}
|
71 |
}
|
72 |
}
|
73 |
-
]
|
|
|
74 |
}
|
75 |
|
76 |
headers = {
|
|
|
70 |
}
|
71 |
}
|
72 |
}
|
73 |
+
],
|
74 |
+
"tool_choice": "auto"
|
75 |
}
|
76 |
|
77 |
headers = {
|
utils.py
CHANGED
@@ -116,9 +116,12 @@ def ensure_string(item):
|
|
116 |
return str(item)
|
117 |
|
118 |
import asyncio
|
|
|
119 |
async def error_handling_wrapper(generator):
|
|
|
120 |
try:
|
121 |
first_item = await generator.__anext__()
|
|
|
122 |
first_item_str = first_item
|
123 |
# logger.info("first_item_str: %s", first_item_str)
|
124 |
if isinstance(first_item_str, (bytes, bytearray)):
|
@@ -153,7 +156,7 @@ async def error_handling_wrapper(generator):
|
|
153 |
logger.error(f"Network error in new_generator: {e}")
|
154 |
raise
|
155 |
|
156 |
-
return new_generator()
|
157 |
|
158 |
except StopAsyncIteration:
|
159 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|
|
|
116 |
return str(item)
|
117 |
|
118 |
import asyncio
|
119 |
+
import time as time_module
|
120 |
async def error_handling_wrapper(generator):
|
121 |
+
start_time = time_module.time()
|
122 |
try:
|
123 |
first_item = await generator.__anext__()
|
124 |
+
first_response_time = time_module.time() - start_time
|
125 |
first_item_str = first_item
|
126 |
# logger.info("first_item_str: %s", first_item_str)
|
127 |
if isinstance(first_item_str, (bytes, bytearray)):
|
|
|
156 |
logger.error(f"Network error in new_generator: {e}")
|
157 |
raise
|
158 |
|
159 |
+
return new_generator(), first_response_time
|
160 |
|
161 |
except StopAsyncIteration:
|
162 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|