yym68686 commited on
Commit
cfd3f47
·
1 Parent(s): 95ca783

🐛 Bug: 1. Fix the bug where the API key is not found when rate limiting.

Browse files

2. Fix the bug where the characters before the slash in the model name with a slash are parsed as the channel name.

Files changed (1) hide show
  1. main.py +20 -10
main.py CHANGED
@@ -275,20 +275,26 @@ class ModelRequestHandler:
275
  for model in config['api_keys'][api_index]['model']:
276
  if "/" in model:
277
  provider_name = model.split("/")[0]
278
- model = model.split("/")[1]
279
  models_list = []
280
  for provider in config['providers']:
281
  if provider['provider'] == provider_name:
282
  models_list.extend(list(provider['model'].keys()))
283
  # print("models_list", models_list)
284
  # print("model_name", model_name)
 
 
 
 
 
 
285
  # print("model", model)
286
- if (model and model_name in models_list) or (model == "*" and model_name in models_list):
287
  provider_rules.append(provider_name)
288
  else:
289
  for provider in config['providers']:
290
  if model in provider['model'].keys():
291
- provider_rules.append(provider['provider'] + "/" + model)
292
 
293
  provider_list = []
294
  # print("provider_rules", provider_rules)
@@ -297,7 +303,7 @@ class ModelRequestHandler:
297
  # print("provider", provider, provider['provider'] == item, item)
298
  if "/" in item:
299
  if provider['provider'] == item.split("/")[0]:
300
- if model_name in provider['model'].keys() and item.split("/")[1] == model_name:
301
  provider_list.append(provider)
302
  elif provider['provider'] == item:
303
  if model_name in provider['model'].keys():
@@ -422,15 +428,13 @@ class InMemoryRateLimiter:
422
 
423
  rate_limiter = InMemoryRateLimiter()
424
 
425
- async def get_user_rate_limit(token: str = None):
426
  # 这里应该实现根据 token 获取用户速率限制的逻辑
427
  # 示例: 返回 (次数, 秒数)
428
  config = app.state.config
429
- api_list = app.state.api_list
430
- api_index = api_list.index(token)
431
  raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
432
 
433
- if not token or not raw_rate_limit:
434
  return (60, 60)
435
 
436
  rate_limit = parse_rate_limit(raw_rate_limit)
@@ -439,8 +443,14 @@ async def get_user_rate_limit(token: str = None):
439
  security = HTTPBearer()
440
  async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
441
  token = credentials.credentials if credentials else None
442
- # print("token", token)
443
- limit, period = await get_user_rate_limit(token)
 
 
 
 
 
 
444
 
445
  # 使用 IP 地址和 token(如果有)作为限制键
446
  client_ip = request.client.host
 
275
  for model in config['api_keys'][api_index]['model']:
276
  if "/" in model:
277
  provider_name = model.split("/")[0]
278
+ model_name_split = "/".join(model.split("/")[1:])
279
  models_list = []
280
  for provider in config['providers']:
281
  if provider['provider'] == provider_name:
282
  models_list.extend(list(provider['model'].keys()))
283
  # print("models_list", models_list)
284
  # print("model_name", model_name)
285
+
286
+ # 处理带斜杠的模型名
287
+ for provider in config['providers']:
288
+ if model in provider['model'].keys():
289
+ provider_rules.append(provider['provider'] + "/" + model)
290
+
291
  # print("model", model)
292
+ if (model_name_split and model_name in models_list) or (model_name_split == "*" and model_name in models_list):
293
  provider_rules.append(provider_name)
294
  else:
295
  for provider in config['providers']:
296
  if model in provider['model'].keys():
297
+ provider_rules.append(provider['provider'] + "/" + model_name_split)
298
 
299
  provider_list = []
300
  # print("provider_rules", provider_rules)
 
303
  # print("provider", provider, provider['provider'] == item, item)
304
  if "/" in item:
305
  if provider['provider'] == item.split("/")[0]:
306
+ if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name:
307
  provider_list.append(provider)
308
  elif provider['provider'] == item:
309
  if model_name in provider['model'].keys():
 
428
 
429
  rate_limiter = InMemoryRateLimiter()
430
 
431
+ async def get_user_rate_limit(api_index: str = None):
432
  # 这里应该实现根据 token 获取用户速率限制的逻辑
433
  # 示例: 返回 (次数, 秒数)
434
  config = app.state.config
 
 
435
  raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
436
 
437
+ if not api_index or not raw_rate_limit:
438
  return (60, 60)
439
 
440
  rate_limit = parse_rate_limit(raw_rate_limit)
 
443
  security = HTTPBearer()
444
  async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
445
  token = credentials.credentials if credentials else None
446
+ api_list = app.state.api_list
447
+ try:
448
+ api_index = api_list.index(token)
449
+ except ValueError:
450
+ print("error: Invalid or missing API Key:", token)
451
+ api_index = None
452
+ token = None
453
+ limit, period = await get_user_rate_limit(api_index)
454
 
455
  # 使用 IP 地址和 token(如果有)作为限制键
456
  client_ip = request.client.host