🐛 Bug: 1. Fix the bug where the API key is not found when rate limiting.
Browse files2. Fix the bug where the characters before the slash in the model name with a slash are parsed as the channel name.
main.py
CHANGED
@@ -275,20 +275,26 @@ class ModelRequestHandler:
|
|
275 |
for model in config['api_keys'][api_index]['model']:
|
276 |
if "/" in model:
|
277 |
provider_name = model.split("/")[0]
|
278 |
-
|
279 |
models_list = []
|
280 |
for provider in config['providers']:
|
281 |
if provider['provider'] == provider_name:
|
282 |
models_list.extend(list(provider['model'].keys()))
|
283 |
# print("models_list", models_list)
|
284 |
# print("model_name", model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
# print("model", model)
|
286 |
-
if (
|
287 |
provider_rules.append(provider_name)
|
288 |
else:
|
289 |
for provider in config['providers']:
|
290 |
if model in provider['model'].keys():
|
291 |
-
provider_rules.append(provider['provider'] + "/" +
|
292 |
|
293 |
provider_list = []
|
294 |
# print("provider_rules", provider_rules)
|
@@ -297,7 +303,7 @@ class ModelRequestHandler:
|
|
297 |
# print("provider", provider, provider['provider'] == item, item)
|
298 |
if "/" in item:
|
299 |
if provider['provider'] == item.split("/")[0]:
|
300 |
-
if model_name in provider['model'].keys() and item.split("/")[1] == model_name:
|
301 |
provider_list.append(provider)
|
302 |
elif provider['provider'] == item:
|
303 |
if model_name in provider['model'].keys():
|
@@ -422,15 +428,13 @@ class InMemoryRateLimiter:
|
|
422 |
|
423 |
rate_limiter = InMemoryRateLimiter()
|
424 |
|
425 |
-
async def get_user_rate_limit(
|
426 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
427 |
# 示例: 返回 (次数, 秒数)
|
428 |
config = app.state.config
|
429 |
-
api_list = app.state.api_list
|
430 |
-
api_index = api_list.index(token)
|
431 |
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
|
432 |
|
433 |
-
if not
|
434 |
return (60, 60)
|
435 |
|
436 |
rate_limit = parse_rate_limit(raw_rate_limit)
|
@@ -439,8 +443,14 @@ async def get_user_rate_limit(token: str = None):
|
|
439 |
security = HTTPBearer()
|
440 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
441 |
token = credentials.credentials if credentials else None
|
442 |
-
|
443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
|
445 |
# 使用 IP 地址和 token(如果有)作为限制键
|
446 |
client_ip = request.client.host
|
|
|
275 |
for model in config['api_keys'][api_index]['model']:
|
276 |
if "/" in model:
|
277 |
provider_name = model.split("/")[0]
|
278 |
+
model_name_split = "/".join(model.split("/")[1:])
|
279 |
models_list = []
|
280 |
for provider in config['providers']:
|
281 |
if provider['provider'] == provider_name:
|
282 |
models_list.extend(list(provider['model'].keys()))
|
283 |
# print("models_list", models_list)
|
284 |
# print("model_name", model_name)
|
285 |
+
|
286 |
+
# 处理带斜杠的模型名
|
287 |
+
for provider in config['providers']:
|
288 |
+
if model in provider['model'].keys():
|
289 |
+
provider_rules.append(provider['provider'] + "/" + model)
|
290 |
+
|
291 |
# print("model", model)
|
292 |
+
if (model_name_split and model_name in models_list) or (model_name_split == "*" and model_name in models_list):
|
293 |
provider_rules.append(provider_name)
|
294 |
else:
|
295 |
for provider in config['providers']:
|
296 |
if model in provider['model'].keys():
|
297 |
+
provider_rules.append(provider['provider'] + "/" + model_name_split)
|
298 |
|
299 |
provider_list = []
|
300 |
# print("provider_rules", provider_rules)
|
|
|
303 |
# print("provider", provider, provider['provider'] == item, item)
|
304 |
if "/" in item:
|
305 |
if provider['provider'] == item.split("/")[0]:
|
306 |
+
if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name:
|
307 |
provider_list.append(provider)
|
308 |
elif provider['provider'] == item:
|
309 |
if model_name in provider['model'].keys():
|
|
|
428 |
|
429 |
rate_limiter = InMemoryRateLimiter()
|
430 |
|
431 |
+
async def get_user_rate_limit(api_index: str = None):
|
432 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
433 |
# 示例: 返回 (次数, 秒数)
|
434 |
config = app.state.config
|
|
|
|
|
435 |
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
|
436 |
|
437 |
+
if not api_index or not raw_rate_limit:
|
438 |
return (60, 60)
|
439 |
|
440 |
rate_limit = parse_rate_limit(raw_rate_limit)
|
|
|
443 |
security = HTTPBearer()
|
444 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
445 |
token = credentials.credentials if credentials else None
|
446 |
+
api_list = app.state.api_list
|
447 |
+
try:
|
448 |
+
api_index = api_list.index(token)
|
449 |
+
except ValueError:
|
450 |
+
print("error: Invalid or missing API Key:", token)
|
451 |
+
api_index = None
|
452 |
+
token = None
|
453 |
+
limit, period = await get_user_rate_limit(api_index)
|
454 |
|
455 |
# 使用 IP 地址和 token(如果有)作为限制键
|
456 |
client_ip = request.client.host
|