yym68686 commited on
Commit
2c0a348
·
1 Parent(s): 6719e81

🐛 Bug: Fix the bug where vercel cannot set app.state.config.

Browse files
Files changed (3) hide show
  1. README.md +2 -0
  2. README_CN.md +1 -1
  3. main.py +58 -67
README.md CHANGED
@@ -192,6 +192,8 @@ There are other statistical data that you can query yourself by writing SQL in t
192
 
193
  [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
194
 
 
 
195
  ## Docker local deployment
196
 
197
  Start the container
 
192
 
193
  [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
194
 
195
+ After clicking the one-click deployment button, set the environment variable `CONFIG_URL` to the direct link of the configuration file, and set `DISABLE_DATABASE` to true, then click Create to create the project.
196
+
197
  ## Docker local deployment
198
 
199
  Start the container
README_CN.md CHANGED
@@ -192,7 +192,7 @@ yym68686/uni-api:latest
192
 
193
  [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
194
 
195
- 点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链,然后点击 Create 创建项目。
196
 
197
  ## Docker 本地部署
198
 
 
192
 
193
  [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
194
 
195
+ 点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链, `DISABLE_DATABASE` 为 true,然后点击 Create 创建项目。
196
 
197
  ## Docker 本地部署
198
 
main.py CHANGED
@@ -106,17 +106,6 @@ async def lifespan(app: FastAPI):
106
  verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
107
  follow_redirects=True, # 自动跟随重定向
108
  )
109
- # app.state.client = httpx.AsyncClient(timeout=timeout)
110
- app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
111
-
112
- for item in app.state.api_keys_db:
113
- if item.get("role") == "admin":
114
- app.state.admin_api_key = item.get("api")
115
- if not hasattr(app.state, "admin_api_key"):
116
- if len(app.state.api_keys_db) >= 1:
117
- app.state.admin_api_key = app.state.api_keys_db[0].get("api")
118
- else:
119
- raise Exception("No admin API key found")
120
 
121
  yield
122
  # 关闭时的代码
@@ -224,6 +213,41 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
224
  # 返回精确到15位小数的结果
225
  return total_cost.quantize(Decimal('0.000000000000001'))
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  class LoggingStreamingResponse(Response):
228
  def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
229
  super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
@@ -263,31 +287,14 @@ class LoggingStreamingResponse(Response):
263
 
264
  process_time = time() - self.current_info["start_time"]
265
  self.current_info["process_time"] = process_time
266
- await self.update_stats()
267
-
268
- async def update_stats(self):
269
- # 这里添加更新数据库的逻辑
270
- # print("current_info2")
271
- if DISABLE_DATABASE:
272
- return
273
- async with async_session() as session:
274
- async with session.begin():
275
- try:
276
- columns = [column.key for column in RequestStat.__table__.columns]
277
- filtered_info = {k: v for k, v in self.current_info.items() if k in columns}
278
- new_request_stat = RequestStat(**filtered_info)
279
- session.add(new_request_stat)
280
- await session.commit()
281
- except Exception as e:
282
- await session.rollback()
283
- logger.error(f"Error updating stats: {str(e)}")
284
 
285
  async def _logging_iterator(self):
286
  try:
287
  async for chunk in self.body_iterator:
288
  if isinstance(chunk, str):
289
  chunk = chunk.encode('utf-8')
290
- line = chunk.decode()
291
  if is_debug:
292
  logger.info(f"{line}")
293
  if line.startswith("data:"):
@@ -435,41 +442,6 @@ class StatsMiddleware(BaseHTTPMiddleware):
435
  # print("current_request_info", current_request_info)
436
  request_info.reset(current_request_info)
437
 
438
- async def update_stats(self, current_info):
439
- if DISABLE_DATABASE:
440
- return
441
- # 这里添加更新数据库的逻辑
442
- async with async_session() as session:
443
- async with session.begin():
444
- try:
445
- columns = [column.key for column in RequestStat.__table__.columns]
446
- filtered_info = {k: v for k, v in current_info.items() if k in columns}
447
- new_request_stat = RequestStat(**filtered_info)
448
- session.add(new_request_stat)
449
- await session.commit()
450
- except Exception as e:
451
- await session.rollback()
452
- logger.error(f"Error updating stats: {str(e)}")
453
-
454
- async def update_channel_stats(self, request_id, provider, model, api_key, success):
455
- if DISABLE_DATABASE:
456
- return
457
- async with async_session() as session:
458
- async with session.begin():
459
- try:
460
- channel_stat = ChannelStat(
461
- request_id=request_id,
462
- provider=provider,
463
- model=model,
464
- api_key=api_key,
465
- success=success,
466
- )
467
- session.add(channel_stat)
468
- await session.commit()
469
- except Exception as e:
470
- await session.rollback()
471
- logger.error(f"Error updating channel stats: {str(e)}")
472
-
473
  async def moderate_content(self, content, token):
474
  moderation_request = ModerationRequest(input=content)
475
 
@@ -500,6 +472,23 @@ app.add_middleware(
500
 
501
  app.add_middleware(StatsMiddleware)
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  # 在 process_request 函数中更新成功和失败计数
504
  async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
505
  url = provider['base_url']
@@ -581,14 +570,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
581
  # response = JSONResponse(first_element)
582
 
583
  # 更新成功计数和首次响应时间
584
- await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
 
585
  current_info["first_response_time"] = first_response_time
586
  current_info["success"] = True
587
  current_info["provider"] = provider['provider']
588
 
589
  return response
590
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
591
- await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
 
592
 
593
  raise e
594
 
 
106
  verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
107
  follow_redirects=True, # 自动跟随重定向
108
  )
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  yield
111
  # 关闭时的代码
 
213
  # 返回精确到15位小数的结果
214
  return total_cost.quantize(Decimal('0.000000000000001'))
215
 
216
+ async def update_stats(current_info):
217
+ if DISABLE_DATABASE:
218
+ return
219
+ # 这里添加更新数据库的逻辑
220
+ async with async_session() as session:
221
+ async with session.begin():
222
+ try:
223
+ columns = [column.key for column in RequestStat.__table__.columns]
224
+ filtered_info = {k: v for k, v in current_info.items() if k in columns}
225
+ new_request_stat = RequestStat(**filtered_info)
226
+ session.add(new_request_stat)
227
+ await session.commit()
228
+ except Exception as e:
229
+ await session.rollback()
230
+ logger.error(f"Error updating stats: {str(e)}")
231
+
232
+ async def update_channel_stats(request_id, provider, model, api_key, success):
233
+ if DISABLE_DATABASE:
234
+ return
235
+ async with async_session() as session:
236
+ async with session.begin():
237
+ try:
238
+ channel_stat = ChannelStat(
239
+ request_id=request_id,
240
+ provider=provider,
241
+ model=model,
242
+ api_key=api_key,
243
+ success=success,
244
+ )
245
+ session.add(channel_stat)
246
+ await session.commit()
247
+ except Exception as e:
248
+ await session.rollback()
249
+ logger.error(f"Error updating channel stats: {str(e)}")
250
+
251
  class LoggingStreamingResponse(Response):
252
  def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
253
  super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
 
287
 
288
  process_time = time() - self.current_info["start_time"]
289
  self.current_info["process_time"] = process_time
290
+ await update_stats(self.current_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  async def _logging_iterator(self):
293
  try:
294
  async for chunk in self.body_iterator:
295
  if isinstance(chunk, str):
296
  chunk = chunk.encode('utf-8')
297
+ line = chunk.decode('utf-8')
298
  if is_debug:
299
  logger.info(f"{line}")
300
  if line.startswith("data:"):
 
442
  # print("current_request_info", current_request_info)
443
  request_info.reset(current_request_info)
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  async def moderate_content(self, content, token):
446
  moderation_request = ModerationRequest(input=content)
447
 
 
472
 
473
  app.add_middleware(StatsMiddleware)
474
 
475
+ @app.middleware("http")
476
+ async def ensure_config(request: Request, call_next):
477
+ if not hasattr(app.state, 'config'):
478
+ logger.warning("Config not found, attempting to reload")
479
+ app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
480
+
481
+ for item in app.state.api_keys_db:
482
+ if item.get("role") == "admin":
483
+ app.state.admin_api_key = item.get("api")
484
+ if not hasattr(app.state, "admin_api_key"):
485
+ if len(app.state.api_keys_db) >= 1:
486
+ app.state.admin_api_key = app.state.api_keys_db[0].get("api")
487
+ else:
488
+ raise Exception("No admin API key found")
489
+
490
+ return await call_next(request)
491
+
492
  # 在 process_request 函数中更新成功和失败计数
493
  async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
494
  url = provider['base_url']
 
570
  # response = JSONResponse(first_element)
571
 
572
  # 更新成功计数和首次响应时间
573
+ await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
574
+ # await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
575
  current_info["first_response_time"] = first_response_time
576
  current_info["success"] = True
577
  current_info["provider"] = provider['provider']
578
 
579
  return response
580
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
581
+ await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
582
+ # await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
583
 
584
  raise e
585