✨ Feature: Add feature: Add support for rate limiting.
Browse files- README.md +7 -1
- main.py +75 -7
- test/test_rate_limit.py +41 -0
README.md
CHANGED
@@ -21,9 +21,14 @@
|
|
21 |
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
|
22 |
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
|
23 |
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
|
24 |
-
- 支持四种负载均衡。
|
|
|
|
|
|
|
|
|
25 |
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
|
26 |
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
|
|
|
27 |
|
28 |
## Configuration
|
29 |
|
@@ -93,6 +98,7 @@ api_keys:
|
|
93 |
preferences:
|
94 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
95 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
|
|
96 |
|
97 |
# 渠道级加权负载均衡配置示例
|
98 |
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
|
|
|
21 |
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
|
22 |
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
|
23 |
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
|
24 |
+
- 支持四种负载均衡。
|
25 |
+
1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。
|
26 |
+
2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。
|
27 |
+
3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。
|
28 |
+
4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
|
29 |
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
|
30 |
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
|
31 |
+
- 支持限流,可以设置每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min。
|
32 |
|
33 |
## Configuration
|
34 |
|
|
|
98 |
preferences:
|
99 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
100 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
101 |
+
RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
|
102 |
|
103 |
# 渠道级加权负载均衡配置示例
|
104 |
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
|
main.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from log_config import logger
|
2 |
|
|
|
3 |
import httpx
|
4 |
import secrets
|
|
|
5 |
from contextlib import asynccontextmanager
|
6 |
|
7 |
from fastapi.middleware.cors import CORSMiddleware
|
@@ -14,6 +16,7 @@ from request import get_payload
|
|
14 |
from response import fetch_response, fetch_response_stream
|
15 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
16 |
|
|
|
17 |
from typing import List, Dict, Union
|
18 |
from urllib.parse import urlparse
|
19 |
|
@@ -374,8 +377,73 @@ class ModelRequestHandler:
|
|
374 |
|
375 |
model_handler = ModelRequestHandler()
|
376 |
|
377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
security = HTTPBearer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
|
380 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
381 |
api_list = app.state.api_list
|
@@ -395,15 +463,15 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
|
|
395 |
raise HTTPException(status_code=403, detail="Permission denied")
|
396 |
return token
|
397 |
|
398 |
-
@app.post("/v1/chat/completions")
|
399 |
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
400 |
return await model_handler.request_model(request, token)
|
401 |
|
402 |
-
@app.options("/v1/chat/completions")
|
403 |
async def options_handler():
|
404 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
405 |
|
406 |
-
@app.get("/v1/models")
|
407 |
async def list_models(token: str = Depends(verify_api_key)):
|
408 |
models = post_all_models(token, app.state.config, app.state.api_list)
|
409 |
return JSONResponse(content={
|
@@ -411,20 +479,20 @@ async def list_models(token: str = Depends(verify_api_key)):
|
|
411 |
"data": models
|
412 |
})
|
413 |
|
414 |
-
@app.post("/v1/images/generations")
|
415 |
async def images_generations(
|
416 |
request: ImageGenerationRequest,
|
417 |
token: str = Depends(verify_api_key)
|
418 |
):
|
419 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
420 |
|
421 |
-
@app.get("/generate-api-key")
|
422 |
def generate_api_key():
|
423 |
api_key = "sk-" + secrets.token_urlsafe(36)
|
424 |
return JSONResponse(content={"api_key": api_key})
|
425 |
|
426 |
# 在 /stats 路由中返回成功和失败百分比
|
427 |
-
@app.get("/stats")
|
428 |
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
|
429 |
middleware = app.middleware_stack.app
|
430 |
if isinstance(middleware, StatsMiddleware):
|
|
|
1 |
from log_config import logger
|
2 |
|
3 |
+
import re
|
4 |
import httpx
|
5 |
import secrets
|
6 |
+
import time as time_module
|
7 |
from contextlib import asynccontextmanager
|
8 |
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
16 |
from response import fetch_response, fetch_response_stream
|
17 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
18 |
|
19 |
+
from collections import defaultdict
|
20 |
from typing import List, Dict, Union
|
21 |
from urllib.parse import urlparse
|
22 |
|
|
|
377 |
|
378 |
model_handler = ModelRequestHandler()
|
379 |
|
380 |
+
def parse_rate_limit(limit_string):
|
381 |
+
# 定义时间单位到秒的映射
|
382 |
+
time_units = {
|
383 |
+
's': 1, 'sec': 1, 'second': 1,
|
384 |
+
'm': 60, 'min': 60, 'minute': 60,
|
385 |
+
'h': 3600, 'hr': 3600, 'hour': 3600,
|
386 |
+
'd': 86400, 'day': 86400,
|
387 |
+
'mo': 2592000, 'month': 2592000,
|
388 |
+
'y': 31536000, 'year': 31536000
|
389 |
+
}
|
390 |
+
|
391 |
+
# 使用正则表达式匹配数字和单位
|
392 |
+
match = re.match(r'^(\d+)/(\w+)$', limit_string)
|
393 |
+
if not match:
|
394 |
+
raise ValueError(f"Invalid rate limit format: {limit_string}")
|
395 |
+
|
396 |
+
count, unit = match.groups()
|
397 |
+
count = int(count)
|
398 |
+
|
399 |
+
# 转换单位到秒
|
400 |
+
if unit not in time_units:
|
401 |
+
raise ValueError(f"Unknown time unit: {unit}")
|
402 |
+
|
403 |
+
seconds = time_units[unit]
|
404 |
+
|
405 |
+
return (count, seconds)
|
406 |
+
|
407 |
+
class InMemoryRateLimiter:
|
408 |
+
def __init__(self):
|
409 |
+
self.requests = defaultdict(list)
|
410 |
+
|
411 |
+
async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
|
412 |
+
now = time_module.time()
|
413 |
+
self.requests[key] = [req for req in self.requests[key] if req > now - period]
|
414 |
+
if len(self.requests[key]) >= limit:
|
415 |
+
return True
|
416 |
+
self.requests[key].append(now)
|
417 |
+
return False
|
418 |
+
|
419 |
+
rate_limiter = InMemoryRateLimiter()
|
420 |
+
|
421 |
+
async def get_user_rate_limit(token: str = None):
|
422 |
+
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
423 |
+
# 示例: 返回 (次数, 秒数)
|
424 |
+
config = app.state.config
|
425 |
+
api_list = app.state.api_list
|
426 |
+
api_index = api_list.index(token)
|
427 |
+
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
|
428 |
+
|
429 |
+
if not token or not raw_rate_limit:
|
430 |
+
return (60, 60)
|
431 |
+
|
432 |
+
rate_limit = parse_rate_limit(raw_rate_limit)
|
433 |
+
return rate_limit
|
434 |
+
|
435 |
security = HTTPBearer()
|
436 |
+
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
437 |
+
token = credentials.credentials if credentials else None
|
438 |
+
# print("token", token)
|
439 |
+
limit, period = await get_user_rate_limit(token)
|
440 |
+
|
441 |
+
# 使用 IP 地址和 token(如果有)作为限制键
|
442 |
+
client_ip = request.client.host
|
443 |
+
rate_limit_key = f"{client_ip}:{token}" if token else client_ip
|
444 |
+
|
445 |
+
if await rate_limiter.is_rate_limited(rate_limit_key, limit, period):
|
446 |
+
raise HTTPException(status_code=429, detail="Too many requests")
|
447 |
|
448 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
449 |
api_list = app.state.api_list
|
|
|
463 |
raise HTTPException(status_code=403, detail="Permission denied")
|
464 |
return token
|
465 |
|
466 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
467 |
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
468 |
return await model_handler.request_model(request, token)
|
469 |
|
470 |
+
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
471 |
async def options_handler():
|
472 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
473 |
|
474 |
+
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
|
475 |
async def list_models(token: str = Depends(verify_api_key)):
|
476 |
models = post_all_models(token, app.state.config, app.state.api_list)
|
477 |
return JSONResponse(content={
|
|
|
479 |
"data": models
|
480 |
})
|
481 |
|
482 |
+
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
|
483 |
async def images_generations(
|
484 |
request: ImageGenerationRequest,
|
485 |
token: str = Depends(verify_api_key)
|
486 |
):
|
487 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
488 |
|
489 |
+
@app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
|
490 |
def generate_api_key():
|
491 |
api_key = "sk-" + secrets.token_urlsafe(36)
|
492 |
return JSONResponse(content={"api_key": api_key})
|
493 |
|
494 |
# 在 /stats 路由中返回成功和失败百分比
|
495 |
+
@app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
|
496 |
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
|
497 |
middleware = app.middleware_stack.app
|
498 |
if isinstance(middleware, StatsMiddleware):
|
test/test_rate_limit.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
def parse_rate_limit(limit_string):
|
4 |
+
# 定义时间单位到秒的映射
|
5 |
+
time_units = {
|
6 |
+
's': 1, 'sec': 1, 'second': 1,
|
7 |
+
'm': 60, 'min': 60, 'minute': 60,
|
8 |
+
'h': 3600, 'hr': 3600, 'hour': 3600,
|
9 |
+
'd': 86400, 'day': 86400,
|
10 |
+
'mo': 2592000, 'month': 2592000,
|
11 |
+
'y': 31536000, 'year': 31536000
|
12 |
+
}
|
13 |
+
|
14 |
+
# 使用正则表达式匹配数字和单位
|
15 |
+
match = re.match(r'^(\d+)/(\w+)$', limit_string)
|
16 |
+
if not match:
|
17 |
+
raise ValueError(f"Invalid rate limit format: {limit_string}")
|
18 |
+
|
19 |
+
count, unit = match.groups()
|
20 |
+
count = int(count)
|
21 |
+
|
22 |
+
# 转换单位到秒
|
23 |
+
if unit not in time_units:
|
24 |
+
raise ValueError(f"Unknown time unit: {unit}")
|
25 |
+
|
26 |
+
seconds = time_units[unit]
|
27 |
+
|
28 |
+
return (count, seconds)
|
29 |
+
|
30 |
+
# 测试函数
|
31 |
+
test_cases = [
|
32 |
+
"2/min", "5/hour", "10/day", "1/second", "3/mo", "1/year",
|
33 |
+
"20/s", "15/m", "8/h", "100/d", "50/mo", "2/y"
|
34 |
+
]
|
35 |
+
|
36 |
+
for case in test_cases:
|
37 |
+
try:
|
38 |
+
result = parse_rate_limit(case)
|
39 |
+
print(f"{case} => {result}")
|
40 |
+
except ValueError as e:
|
41 |
+
print(f"Error parsing {case}: {str(e)}")
|