yym68686 commited on
Commit
b812da1
·
1 Parent(s): 7477ff7

✨ Feature: Add feature: Add support for rate limiting.

Browse files
Files changed (3) hide show
  1. README.md +7 -1
  2. main.py +75 -7
  3. test/test_rate_limit.py +41 -0
README.md CHANGED
@@ -21,9 +21,14 @@
21
  - 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
22
  - 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
23
  - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
24
- - 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
 
 
 
 
25
  - 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
26
  - 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
 
27
 
28
  ## Configuration
29
 
@@ -93,6 +98,7 @@ api_keys:
93
  preferences:
94
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
95
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
 
96
 
97
  # 渠道级加权负载均衡配置示例
98
  - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
 
21
  - 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
22
  - 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
23
  - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
24
+ - 支持四种负载均衡。
25
+ 1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。
26
+ 2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。
27
+ 3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。
28
+ 4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
29
  - 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
30
  - 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
31
+ - 支持限流,可以设置每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min。
32
 
33
  ## Configuration
34
 
 
98
  preferences:
99
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
100
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
101
+ RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
102
 
103
  # 渠道级加权负载均衡配置示例
104
  - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
main.py CHANGED
@@ -1,7 +1,9 @@
1
  from log_config import logger
2
 
 
3
  import httpx
4
  import secrets
 
5
  from contextlib import asynccontextmanager
6
 
7
  from fastapi.middleware.cors import CORSMiddleware
@@ -14,6 +16,7 @@ from request import get_payload
14
  from response import fetch_response, fetch_response_stream
15
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
16
 
 
17
  from typing import List, Dict, Union
18
  from urllib.parse import urlparse
19
 
@@ -374,8 +377,73 @@ class ModelRequestHandler:
374
 
375
  model_handler = ModelRequestHandler()
376
 
377
- # 安全性依赖
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  security = HTTPBearer()
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
381
  api_list = app.state.api_list
@@ -395,15 +463,15 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
395
  raise HTTPException(status_code=403, detail="Permission denied")
396
  return token
397
 
398
- @app.post("/v1/chat/completions")
399
  async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
400
  return await model_handler.request_model(request, token)
401
 
402
- @app.options("/v1/chat/completions")
403
  async def options_handler():
404
  return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
405
 
406
- @app.get("/v1/models")
407
  async def list_models(token: str = Depends(verify_api_key)):
408
  models = post_all_models(token, app.state.config, app.state.api_list)
409
  return JSONResponse(content={
@@ -411,20 +479,20 @@ async def list_models(token: str = Depends(verify_api_key)):
411
  "data": models
412
  })
413
 
414
- @app.post("/v1/images/generations")
415
  async def images_generations(
416
  request: ImageGenerationRequest,
417
  token: str = Depends(verify_api_key)
418
  ):
419
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
420
 
421
- @app.get("/generate-api-key")
422
  def generate_api_key():
423
  api_key = "sk-" + secrets.token_urlsafe(36)
424
  return JSONResponse(content={"api_key": api_key})
425
 
426
  # 在 /stats 路由中返回成功和失败百分比
427
- @app.get("/stats")
428
  async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
429
  middleware = app.middleware_stack.app
430
  if isinstance(middleware, StatsMiddleware):
 
1
  from log_config import logger
2
 
3
+ import re
4
  import httpx
5
  import secrets
6
+ import time as time_module
7
  from contextlib import asynccontextmanager
8
 
9
  from fastapi.middleware.cors import CORSMiddleware
 
16
  from response import fetch_response, fetch_response_stream
17
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
18
 
19
+ from collections import defaultdict
20
  from typing import List, Dict, Union
21
  from urllib.parse import urlparse
22
 
 
377
 
378
  model_handler = ModelRequestHandler()
379
 
380
+ def parse_rate_limit(limit_string):
381
+ # 定义时间单位到秒的映射
382
+ time_units = {
383
+ 's': 1, 'sec': 1, 'second': 1,
384
+ 'm': 60, 'min': 60, 'minute': 60,
385
+ 'h': 3600, 'hr': 3600, 'hour': 3600,
386
+ 'd': 86400, 'day': 86400,
387
+ 'mo': 2592000, 'month': 2592000,
388
+ 'y': 31536000, 'year': 31536000
389
+ }
390
+
391
+ # 使用正则表达式匹配数字和单位
392
+ match = re.match(r'^(\d+)/(\w+)$', limit_string)
393
+ if not match:
394
+ raise ValueError(f"Invalid rate limit format: {limit_string}")
395
+
396
+ count, unit = match.groups()
397
+ count = int(count)
398
+
399
+ # 转换单位到秒
400
+ if unit not in time_units:
401
+ raise ValueError(f"Unknown time unit: {unit}")
402
+
403
+ seconds = time_units[unit]
404
+
405
+ return (count, seconds)
406
+
407
+ class InMemoryRateLimiter:
408
+ def __init__(self):
409
+ self.requests = defaultdict(list)
410
+
411
+ async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
412
+ now = time_module.time()
413
+ self.requests[key] = [req for req in self.requests[key] if req > now - period]
414
+ if len(self.requests[key]) >= limit:
415
+ return True
416
+ self.requests[key].append(now)
417
+ return False
418
+
419
+ rate_limiter = InMemoryRateLimiter()
420
+
421
+ async def get_user_rate_limit(token: str = None):
422
+ # 这里应该实现根据 token 获取用户速率限制的逻辑
423
+ # 示例: 返回 (次数, 秒数)
424
+ config = app.state.config
425
+ api_list = app.state.api_list
426
+ api_index = api_list.index(token)
427
+ raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
428
+
429
+ if not token or not raw_rate_limit:
430
+ return (60, 60)
431
+
432
+ rate_limit = parse_rate_limit(raw_rate_limit)
433
+ return rate_limit
434
+
435
  security = HTTPBearer()
436
+ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
437
+ token = credentials.credentials if credentials else None
438
+ # print("token", token)
439
+ limit, period = await get_user_rate_limit(token)
440
+
441
+ # 使用 IP 地址和 token(如果有)作为限制键
442
+ client_ip = request.client.host
443
+ rate_limit_key = f"{client_ip}:{token}" if token else client_ip
444
+
445
+ if await rate_limiter.is_rate_limited(rate_limit_key, limit, period):
446
+ raise HTTPException(status_code=429, detail="Too many requests")
447
 
448
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
449
  api_list = app.state.api_list
 
463
  raise HTTPException(status_code=403, detail="Permission denied")
464
  return token
465
 
466
+ @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
467
  async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
468
  return await model_handler.request_model(request, token)
469
 
470
+ @app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
471
  async def options_handler():
472
  return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
473
 
474
+ @app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
475
  async def list_models(token: str = Depends(verify_api_key)):
476
  models = post_all_models(token, app.state.config, app.state.api_list)
477
  return JSONResponse(content={
 
479
  "data": models
480
  })
481
 
482
+ @app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
483
  async def images_generations(
484
  request: ImageGenerationRequest,
485
  token: str = Depends(verify_api_key)
486
  ):
487
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
488
 
489
+ @app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
490
  def generate_api_key():
491
  api_key = "sk-" + secrets.token_urlsafe(36)
492
  return JSONResponse(content={"api_key": api_key})
493
 
494
  # 在 /stats 路由中返回成功和失败百分比
495
+ @app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
496
  async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
497
  middleware = app.middleware_stack.app
498
  if isinstance(middleware, StatsMiddleware):
test/test_rate_limit.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ def parse_rate_limit(limit_string):
4
+ # 定义时间单位到秒的映射
5
+ time_units = {
6
+ 's': 1, 'sec': 1, 'second': 1,
7
+ 'm': 60, 'min': 60, 'minute': 60,
8
+ 'h': 3600, 'hr': 3600, 'hour': 3600,
9
+ 'd': 86400, 'day': 86400,
10
+ 'mo': 2592000, 'month': 2592000,
11
+ 'y': 31536000, 'year': 31536000
12
+ }
13
+
14
+ # 使用正则表达式匹配数字和单位
15
+ match = re.match(r'^(\d+)/(\w+)$', limit_string)
16
+ if not match:
17
+ raise ValueError(f"Invalid rate limit format: {limit_string}")
18
+
19
+ count, unit = match.groups()
20
+ count = int(count)
21
+
22
+ # 转换单位到秒
23
+ if unit not in time_units:
24
+ raise ValueError(f"Unknown time unit: {unit}")
25
+
26
+ seconds = time_units[unit]
27
+
28
+ return (count, seconds)
29
+
30
+ # 测试函数
31
+ test_cases = [
32
+ "2/min", "5/hour", "10/day", "1/second", "3/mo", "1/year",
33
+ "20/s", "15/m", "8/h", "100/d", "50/mo", "2/y"
34
+ ]
35
+
36
+ for case in test_cases:
37
+ try:
38
+ result = parse_rate_limit(case)
39
+ print(f"{case} => {result}")
40
+ except ValueError as e:
41
+ print(f"Error parsing {case}: {str(e)}")