🐛 Bug: Fix the bug where vercel cannot set app.state.config.
Browse files- README.md +2 -0
- README_CN.md +1 -1
- main.py +58 -67
README.md
CHANGED
@@ -192,6 +192,8 @@ There are other statistical data that you can query yourself by writing SQL in t
|
|
192 |
|
193 |
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
194 |
|
|
|
|
|
195 |
## Docker local deployment
|
196 |
|
197 |
Start the container
|
|
|
192 |
|
193 |
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
194 |
|
195 |
+
After clicking the one-click deployment button, set the environment variable `CONFIG_URL` to the direct link of the configuration file, and set `DISABLE_DATABASE` to true, then click Create to create the project.
|
196 |
+
|
197 |
## Docker local deployment
|
198 |
|
199 |
Start the container
|
README_CN.md
CHANGED
@@ -192,7 +192,7 @@ yym68686/uni-api:latest
|
|
192 |
|
193 |
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
194 |
|
195 |
-
点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL`
|
196 |
|
197 |
## Docker 本地部署
|
198 |
|
|
|
192 |
|
193 |
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
194 |
|
195 |
+
点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链, `DISABLE_DATABASE` 为 true,然后点击 Create 创建项目。
|
196 |
|
197 |
## Docker 本地部署
|
198 |
|
main.py
CHANGED
@@ -106,17 +106,6 @@ async def lifespan(app: FastAPI):
|
|
106 |
verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
|
107 |
follow_redirects=True, # 自动跟随重定向
|
108 |
)
|
109 |
-
# app.state.client = httpx.AsyncClient(timeout=timeout)
|
110 |
-
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
|
111 |
-
|
112 |
-
for item in app.state.api_keys_db:
|
113 |
-
if item.get("role") == "admin":
|
114 |
-
app.state.admin_api_key = item.get("api")
|
115 |
-
if not hasattr(app.state, "admin_api_key"):
|
116 |
-
if len(app.state.api_keys_db) >= 1:
|
117 |
-
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
118 |
-
else:
|
119 |
-
raise Exception("No admin API key found")
|
120 |
|
121 |
yield
|
122 |
# 关闭时的代码
|
@@ -224,6 +213,41 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
|
|
224 |
# 返回精确到15位小数的结果
|
225 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
class LoggingStreamingResponse(Response):
|
228 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
229 |
super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
|
@@ -263,31 +287,14 @@ class LoggingStreamingResponse(Response):
|
|
263 |
|
264 |
process_time = time() - self.current_info["start_time"]
|
265 |
self.current_info["process_time"] = process_time
|
266 |
-
await self.
|
267 |
-
|
268 |
-
async def update_stats(self):
|
269 |
-
# 这里添加更新数据库的逻辑
|
270 |
-
# print("current_info2")
|
271 |
-
if DISABLE_DATABASE:
|
272 |
-
return
|
273 |
-
async with async_session() as session:
|
274 |
-
async with session.begin():
|
275 |
-
try:
|
276 |
-
columns = [column.key for column in RequestStat.__table__.columns]
|
277 |
-
filtered_info = {k: v for k, v in self.current_info.items() if k in columns}
|
278 |
-
new_request_stat = RequestStat(**filtered_info)
|
279 |
-
session.add(new_request_stat)
|
280 |
-
await session.commit()
|
281 |
-
except Exception as e:
|
282 |
-
await session.rollback()
|
283 |
-
logger.error(f"Error updating stats: {str(e)}")
|
284 |
|
285 |
async def _logging_iterator(self):
|
286 |
try:
|
287 |
async for chunk in self.body_iterator:
|
288 |
if isinstance(chunk, str):
|
289 |
chunk = chunk.encode('utf-8')
|
290 |
-
line = chunk.decode()
|
291 |
if is_debug:
|
292 |
logger.info(f"{line}")
|
293 |
if line.startswith("data:"):
|
@@ -435,41 +442,6 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
435 |
# print("current_request_info", current_request_info)
|
436 |
request_info.reset(current_request_info)
|
437 |
|
438 |
-
async def update_stats(self, current_info):
|
439 |
-
if DISABLE_DATABASE:
|
440 |
-
return
|
441 |
-
# 这里添加更新数据库的逻辑
|
442 |
-
async with async_session() as session:
|
443 |
-
async with session.begin():
|
444 |
-
try:
|
445 |
-
columns = [column.key for column in RequestStat.__table__.columns]
|
446 |
-
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
447 |
-
new_request_stat = RequestStat(**filtered_info)
|
448 |
-
session.add(new_request_stat)
|
449 |
-
await session.commit()
|
450 |
-
except Exception as e:
|
451 |
-
await session.rollback()
|
452 |
-
logger.error(f"Error updating stats: {str(e)}")
|
453 |
-
|
454 |
-
async def update_channel_stats(self, request_id, provider, model, api_key, success):
|
455 |
-
if DISABLE_DATABASE:
|
456 |
-
return
|
457 |
-
async with async_session() as session:
|
458 |
-
async with session.begin():
|
459 |
-
try:
|
460 |
-
channel_stat = ChannelStat(
|
461 |
-
request_id=request_id,
|
462 |
-
provider=provider,
|
463 |
-
model=model,
|
464 |
-
api_key=api_key,
|
465 |
-
success=success,
|
466 |
-
)
|
467 |
-
session.add(channel_stat)
|
468 |
-
await session.commit()
|
469 |
-
except Exception as e:
|
470 |
-
await session.rollback()
|
471 |
-
logger.error(f"Error updating channel stats: {str(e)}")
|
472 |
-
|
473 |
async def moderate_content(self, content, token):
|
474 |
moderation_request = ModerationRequest(input=content)
|
475 |
|
@@ -500,6 +472,23 @@ app.add_middleware(
|
|
500 |
|
501 |
app.add_middleware(StatsMiddleware)
|
502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
# 在 process_request 函数中更新成功和失败计数
|
504 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
505 |
url = provider['base_url']
|
@@ -581,14 +570,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
581 |
# response = JSONResponse(first_element)
|
582 |
|
583 |
# 更新成功计数和首次响应时间
|
584 |
-
await
|
|
|
585 |
current_info["first_response_time"] = first_response_time
|
586 |
current_info["success"] = True
|
587 |
current_info["provider"] = provider['provider']
|
588 |
|
589 |
return response
|
590 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
591 |
-
await
|
|
|
592 |
|
593 |
raise e
|
594 |
|
|
|
106 |
verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
|
107 |
follow_redirects=True, # 自动跟随重定向
|
108 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
yield
|
111 |
# 关闭时的代码
|
|
|
213 |
# 返回精确到15位小数的结果
|
214 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
215 |
|
216 |
+
async def update_stats(current_info):
|
217 |
+
if DISABLE_DATABASE:
|
218 |
+
return
|
219 |
+
# 这里添加更新数据库的逻辑
|
220 |
+
async with async_session() as session:
|
221 |
+
async with session.begin():
|
222 |
+
try:
|
223 |
+
columns = [column.key for column in RequestStat.__table__.columns]
|
224 |
+
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
225 |
+
new_request_stat = RequestStat(**filtered_info)
|
226 |
+
session.add(new_request_stat)
|
227 |
+
await session.commit()
|
228 |
+
except Exception as e:
|
229 |
+
await session.rollback()
|
230 |
+
logger.error(f"Error updating stats: {str(e)}")
|
231 |
+
|
232 |
+
async def update_channel_stats(request_id, provider, model, api_key, success):
|
233 |
+
if DISABLE_DATABASE:
|
234 |
+
return
|
235 |
+
async with async_session() as session:
|
236 |
+
async with session.begin():
|
237 |
+
try:
|
238 |
+
channel_stat = ChannelStat(
|
239 |
+
request_id=request_id,
|
240 |
+
provider=provider,
|
241 |
+
model=model,
|
242 |
+
api_key=api_key,
|
243 |
+
success=success,
|
244 |
+
)
|
245 |
+
session.add(channel_stat)
|
246 |
+
await session.commit()
|
247 |
+
except Exception as e:
|
248 |
+
await session.rollback()
|
249 |
+
logger.error(f"Error updating channel stats: {str(e)}")
|
250 |
+
|
251 |
class LoggingStreamingResponse(Response):
|
252 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
253 |
super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
|
|
|
287 |
|
288 |
process_time = time() - self.current_info["start_time"]
|
289 |
self.current_info["process_time"] = process_time
|
290 |
+
await update_stats(self.current_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
async def _logging_iterator(self):
|
293 |
try:
|
294 |
async for chunk in self.body_iterator:
|
295 |
if isinstance(chunk, str):
|
296 |
chunk = chunk.encode('utf-8')
|
297 |
+
line = chunk.decode('utf-8')
|
298 |
if is_debug:
|
299 |
logger.info(f"{line}")
|
300 |
if line.startswith("data:"):
|
|
|
442 |
# print("current_request_info", current_request_info)
|
443 |
request_info.reset(current_request_info)
|
444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
async def moderate_content(self, content, token):
|
446 |
moderation_request = ModerationRequest(input=content)
|
447 |
|
|
|
472 |
|
473 |
app.add_middleware(StatsMiddleware)
|
474 |
|
475 |
+
@app.middleware("http")
|
476 |
+
async def ensure_config(request: Request, call_next):
|
477 |
+
if not hasattr(app.state, 'config'):
|
478 |
+
logger.warning("Config not found, attempting to reload")
|
479 |
+
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
|
480 |
+
|
481 |
+
for item in app.state.api_keys_db:
|
482 |
+
if item.get("role") == "admin":
|
483 |
+
app.state.admin_api_key = item.get("api")
|
484 |
+
if not hasattr(app.state, "admin_api_key"):
|
485 |
+
if len(app.state.api_keys_db) >= 1:
|
486 |
+
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
487 |
+
else:
|
488 |
+
raise Exception("No admin API key found")
|
489 |
+
|
490 |
+
return await call_next(request)
|
491 |
+
|
492 |
# 在 process_request 函数中更新成功和失败计数
|
493 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
494 |
url = provider['base_url']
|
|
|
570 |
# response = JSONResponse(first_element)
|
571 |
|
572 |
# 更新成功计数和首次响应时间
|
573 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
574 |
+
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
575 |
current_info["first_response_time"] = first_response_time
|
576 |
current_info["success"] = True
|
577 |
current_info["provider"] = provider['provider']
|
578 |
|
579 |
return response
|
580 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
581 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
582 |
+
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
583 |
|
584 |
raise e
|
585 |
|