yym68686 committed
Commit 76819d6 · 1 Parent(s): 5f2aeb0

🐛 Bug: 1. Fix the bug where the client to be closed can no longer be found when closing the request client.


2. Fix the bug where, when there is only one provider but multiple API keys, an error prevents switching to the next API key.

✨ Feature: Add support for per-API-key request rate limiting, and automatically cool down an API key after it receives a 429 status code.

📖 Docs: Update documentation

Files changed (5)
  1. .github/workflows/main.yml +1 -0
  2. README.md +42 -1
  3. README_CN.md +42 -1
  4. main.py +28 -59
  5. utils.py +124 -5
.github/workflows/main.yml CHANGED
@@ -21,6 +21,7 @@ on:
 jobs:
   build-and-push:
     runs-on: ubuntu-latest
+    if: ${{ secrets.DOCKER_HUB_USERNAME != '' && secrets.DOCKER_HUB_ACCESS_TOKEN != '' }}
 
     steps:
     - name: Checkout repository
README.md CHANGED
@@ -80,12 +80,18 @@ providers:
 
   - provider: gemini
     base_url: https://generativelanguage.googleapis.com/v1beta # base_url supports v1beta/v1, only for Gemini model use, required
-    api: AIzaSyAN2k6IRdgw
+    api: # Supports multiple API keys; multiple keys automatically enable round-robin load balancing. At least one key is required.
+      - AIzaSyAN2k6IRdgw123
+      - AIzaSyAN2k6IRdgw456
+      - AIzaSyAN2k6IRdgw789
     model:
       - gemini-1.5-pro
       - gemini-1.5-flash-exp-0827: gemini-1.5-flash # After renaming, the original model name gemini-1.5-flash-exp-0827 cannot be used; to keep using the original name, add it under model, as in the line below
      - gemini-1.5-flash-exp-0827 # With this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
+    preferences:
+      API_KEY_RATE_LIMIT: 15/min # Each API key may be requested at most 15 times per minute. Optional; defaults to 999999/min.
+      API_KEY_COOLDOWN_PERIOD: 60 # Each API key cools down for 60 seconds after hitting a 429 error. Optional; defaults to 60 seconds.
 
   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
@@ -338,6 +344,41 @@ Thank you for your support!
 
   Setting ENABLE_MODERATION to false will fix this issue. When ENABLE_MODERATION is true, the API must be able to use the text-moderation-latest model, and if you have not provided text-moderation-latest in the provider model settings, an error will occur indicating that the model cannot be found.
 
+- How do I prioritize requests to a specific channel, i.e. set a channel's priority?
+
+  Simply set the channel order under api_keys; no other settings are required. Sample configuration file:
+
+  ```yaml
+  providers:
+    - provider: ai1
+      base_url: https://xxx/v1/chat/completions
+      api: sk-xxx
+
+    - provider: ai2
+      base_url: https://xxx/v1/chat/completions
+      api: sk-xxx
+
+  api_keys:
+    - api: sk-1234
+      model:
+        - ai2/*
+        - ai1/*
+  ```
+
+  With this configuration, ai2 is requested first, and ai1 is requested if that fails.
+
+- What is the behavior behind the scheduling algorithms fixed_priority, weighted_round_robin, lottery, random, and round_robin?
+
+  All scheduling algorithms are enabled by setting api_keys.(api).preferences.SCHEDULING_ALGORITHM in the configuration file to any of: fixed_priority, weighted_round_robin, lottery, random, round_robin.
+
+  1. fixed_priority: Fixed-priority scheduling. Requests are always served by the first channel that offers the requested model; on error, the next channel is tried. This is the default scheduling algorithm.
+
+  2. weighted_round_robin: Weighted round-robin load balancing. Channels offering the requested model are requested in the weight order set in api_keys.(api).model.
+
+  3. lottery: Lottery load balancing. A channel offering the requested model is picked at random, weighted by the weights set in api_keys.(api).model.
+
+  4. round_robin: Round-robin load balancing. Channels offering the requested model are requested in the order configured in api_keys.(api).model. See the previous question for setting channel priority.
+
 ## ⭐ Star History
 
 <a href="https://github.com/yym68686/uni-api/stargazers">
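The fixed-priority fallback described in the FAQ above (request ai2 first, fall back to ai1 on failure) can be sketched as a simple ordered loop. This is a minimal illustration, not the uni-api implementation; `request_with_priority` and `call` are hypothetical names:

```python
# Sketch of fixed-priority channel fallback: try each channel in the
# order listed under api_keys.(api).model, moving on when one errors.
def request_with_priority(channels, call):
    """Try channels in listed order; return the first successful result."""
    last_error = None
    for channel in channels:
        try:
            return call(channel)
        except Exception as exc:  # on error, fall through to the next channel
            last_error = exc
    raise RuntimeError(f"all channels failed: {last_error}")

# Usage: ai2 is listed first, so it is tried before ai1.
def call(channel):
    if channel == "ai2/*":
        raise ConnectionError("ai2 unavailable")
    return f"response from {channel}"

print(request_with_priority(["ai2/*", "ai1/*"], call))  # falls back to ai1
```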
README_CN.md CHANGED
@@ -80,12 +80,18 @@ providers:
 
   - provider: gemini
     base_url: https://generativelanguage.googleapis.com/v1beta # base_url supports v1beta/v1, only for Gemini model use, required
-    api: AIzaSyAN2k6IRdgw
+    api: # Supports multiple API keys; multiple keys automatically enable round-robin load balancing. At least one key is required.
+      - AIzaSyAN2k6IRdgw123
+      - AIzaSyAN2k6IRdgw456
+      - AIzaSyAN2k6IRdgw789
     model:
       - gemini-1.5-pro
       - gemini-1.5-flash-exp-0827: gemini-1.5-flash # After renaming, the original model name gemini-1.5-flash-exp-0827 cannot be used; to keep using the original name, add it under model, as in the line below
      - gemini-1.5-flash-exp-0827 # With this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
+    preferences:
+      API_KEY_RATE_LIMIT: 15/min # Each API key may be requested at most 15 times per minute. Optional; defaults to 999999/min.
+      API_KEY_COOLDOWN_PERIOD: 60 # Cooldown time in seconds for each API key after a 429 error. Optional; defaults to 60 seconds.
 
   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
@@ -338,6 +344,41 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
 
   Setting ENABLE_MODERATION to false will fix this issue. When ENABLE_MODERATION is true, the API must be able to use the text-moderation-latest model, and if you have not provided text-moderation-latest in the provider model settings, an error will occur indicating that the model cannot be found.
 
+- How do I prioritize requests to a specific channel, i.e. set a channel's priority?
+
+  Simply set the channel order under api_keys; no other settings are required. Sample configuration file:
+
+  ```yaml
+  providers:
+    - provider: ai1
+      base_url: https://xxx/v1/chat/completions
+      api: sk-xxx
+
+    - provider: ai2
+      base_url: https://xxx/v1/chat/completions
+      api: sk-xxx
+
+  api_keys:
+    - api: sk-1234
+      model:
+        - ai2/*
+        - ai1/*
+  ```
+
+  With this configuration, ai2 is requested first, and ai1 is requested if that fails.
+
+- What is the behavior behind the scheduling algorithms fixed_priority, weighted_round_robin, lottery, random, and round_robin?
+
+  All scheduling algorithms are enabled by setting api_keys.(api).preferences.SCHEDULING_ALGORITHM in the configuration file to any of: fixed_priority, weighted_round_robin, lottery, random, round_robin.
+
+  1. fixed_priority: Fixed-priority scheduling. Requests are always served by the first channel that offers the requested model; on error, the next channel is tried. This is the default scheduling algorithm.
+
+  2. weighted_round_robin: Weighted round-robin load balancing. Channels offering the requested model are requested in the weight order set in api_keys.(api).model.
+
+  3. lottery: Lottery load balancing. A channel offering the requested model is picked at random, weighted by the weights set in api_keys.(api).model.
+
+  4. round_robin: Round-robin load balancing. Channels offering the requested model are requested in the order configured in api_keys.(api).model. See the previous question for setting channel priority.
+
 ## ⭐ Star History
 
 <a href="https://github.com/yym68686/uni-api/stargazers">
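The `15/min`-style values used by API_KEY_RATE_LIMIT above encode a count per time unit. A minimal sketch of parsing such a string into a `(count, period_seconds)` pair, assuming only the unit names shown in the config comments (`parse_limit` and `UNIT_SECONDS` are illustrative names, not the project's API):

```python
import re

# Map a few rate-limit unit names to their length in seconds.
UNIT_SECONDS = {"s": 1, "sec": 1, "m": 60, "min": 60, "h": 3600, "d": 86400}

def parse_limit(limit: str) -> tuple[int, int]:
    """Parse a 'count/unit' string such as '15/min' into (count, seconds)."""
    match = re.fullmatch(r"(\d+)/(\w+)", limit)
    if not match or match.group(2) not in UNIT_SECONDS:
        raise ValueError(f"bad rate limit: {limit}")
    count, unit = match.groups()
    return int(count), UNIT_SECONDS[unit]

print(parse_limit("15/min"))  # (15, 60)
```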
main.py CHANGED
@@ -1,6 +1,5 @@
 from log_config import logger
 
-import re
 import copy
 import httpx
 import secrets
@@ -19,7 +18,18 @@ from fastapi.exceptions import RequestValidationError
 from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
 from request import get_payload
 from response import fetch_response, fetch_response_stream
-from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
+from utils import (
+    safe_get,
+    load_config,
+    save_api_yaml,
+    get_model_dict,
+    post_all_models,
+    get_user_rate_limit,
+    circular_list_encoder,
+    error_handling_wrapper,
+    rate_limiter,
+    provider_api_circular_list,
+)
 
 from collections import defaultdict
 from typing import List, Dict, Union
@@ -542,6 +552,7 @@ class ClientManager:
     @asynccontextmanager
     async def get_client(self, timeout_value):
         # Get or create the client directly, without taking a lock
+        timeout_value = int(timeout_value)
         if timeout_value not in self.clients:
             timeout = httpx.Timeout(
                 connect=15.0,
@@ -558,8 +569,10 @@
         try:
             yield self.clients[timeout_value]
         except Exception as e:
-            await self.clients[timeout_value].aclose()
-            del self.clients[timeout_value]
+            if timeout_value in self.clients:
+                tmp_client = self.clients[timeout_value]
+                del self.clients[timeout_value]  # drop the reference first
+                await tmp_client.aclose()  # then close the client
             raise e
 
     async def close(self):
@@ -955,8 +968,13 @@ class ModelRequestHandler:
         auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
 
         index = 0
+        if num_matching_providers == 1 and (count := provider_api_circular_list[matching_providers[0]['provider']].get_items_count()) > 1:
+            retry_count = count
+        else:
+            retry_count = 0
+
         while True:
-            if index >= num_matching_providers:
+            if index >= num_matching_providers + retry_count:
                 break
             current_index = (start_index + index) % num_matching_providers
             index += 1
@@ -995,6 +1013,10 @@ class ModelRequestHandler:
             num_matching_providers = len(matching_providers)
             index = 0
 
+            if status_code == 429:
+                current_api = await provider_api_circular_list[channel_id].after_next_current()
+                await provider_api_circular_list[channel_id].set_cooling(current_api, cooldown_period=safe_get(provider, "preferences", "API_KEY_COOLDOWN_PERIOD", default=60))
+
             logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
             if is_debug:
                 import traceback
@@ -1012,59 +1034,6 @@
 
 model_handler = ModelRequestHandler()
 
-def parse_rate_limit(limit_string):
-    # Map time units to seconds
-    time_units = {
-        's': 1, 'sec': 1, 'second': 1,
-        'm': 60, 'min': 60, 'minute': 60,
-        'h': 3600, 'hr': 3600, 'hour': 3600,
-        'd': 86400, 'day': 86400,
-        'mo': 2592000, 'month': 2592000,
-        'y': 31536000, 'year': 31536000
-    }
-
-    # Match the count and unit with a regular expression
-    match = re.match(r'^(\d+)/(\w+)$', limit_string)
-    if not match:
-        raise ValueError(f"Invalid rate limit format: {limit_string}")
-
-    count, unit = match.groups()
-    count = int(count)
-
-    # Convert the unit to seconds
-    if unit not in time_units:
-        raise ValueError(f"Unknown time unit: {unit}")
-
-    seconds = time_units[unit]
-
-    return (count, seconds)
-
-class InMemoryRateLimiter:
-    def __init__(self):
-        self.requests = defaultdict(list)
-
-    async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
-        now = time()
-        self.requests[key] = [req for req in self.requests[key] if req > now - period]
-        if len(self.requests[key]) >= limit:
-            return True
-        self.requests[key].append(now)
-        return False
-
-rate_limiter = InMemoryRateLimiter()
-
-async def get_user_rate_limit(api_index: str = None):
-    # The logic for looking up a user's rate limit by token goes here
-    # Example: returns (count, seconds)
-    config = app.state.config
-    raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
-
-    if not api_index or not raw_rate_limit:
-        return (30, 60)
-
-    rate_limit = parse_rate_limit(raw_rate_limit)
-    return rate_limit
-
 security = HTTPBearer()
 
 async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
@@ -1076,7 +1045,7 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
         print("error: Invalid or missing API Key:", token)
         api_index = None
         token = None
-    limit, period = await get_user_rate_limit(api_index)
+    limit, period = await get_user_rate_limit(app, api_index)
 
     # Use the IP address and the token (if any) as the rate-limit key
     client_ip = request.client.host
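The retry logic added above grows the retry budget when only one provider matches the requested model but that provider carries several API keys, so each key gets a try before giving up. A standalone sketch of that budget calculation (`retry_budget` is a hypothetical name; the real code reads the key count from `provider_api_circular_list`):

```python
def retry_budget(num_matching_providers: int, keys_per_provider: int) -> int:
    """Total attempts: one per matching provider, plus extra retries when
    a single matching provider holds multiple API keys."""
    if num_matching_providers == 1 and keys_per_provider > 1:
        extra = keys_per_provider
    else:
        extra = 0
    return num_matching_providers + extra

print(retry_budget(1, 3))  # one provider, three keys: extra retries granted
print(retry_budget(2, 3))  # multiple providers: one attempt each, no extras
```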
utils.py CHANGED
@@ -3,22 +3,135 @@ from fastapi import HTTPException
 import httpx
 
 from log_config import logger
+
+import re
+from time import time
+def parse_rate_limit(limit_string):
+    # Map time units to seconds
+    time_units = {
+        's': 1, 'sec': 1, 'second': 1,
+        'm': 60, 'min': 60, 'minute': 60,
+        'h': 3600, 'hr': 3600, 'hour': 3600,
+        'd': 86400, 'day': 86400,
+        'mo': 2592000, 'month': 2592000,
+        'y': 31536000, 'year': 31536000
+    }
+
+    # Match the count and unit with a regular expression
+    match = re.match(r'^(\d+)/(\w+)$', limit_string)
+    if not match:
+        raise ValueError(f"Invalid rate limit format: {limit_string}")
+
+    count, unit = match.groups()
+    count = int(count)
+
+    # Convert the unit to seconds
+    if unit not in time_units:
+        raise ValueError(f"Unknown time unit: {unit}")
+
+    seconds = time_units[unit]
+
+    return (count, seconds)
+
 from collections import defaultdict
+class InMemoryRateLimiter:
+    def __init__(self):
+        self.requests = defaultdict(list)
+
+    async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
+        now = time()
+        self.requests[key] = [req for req in self.requests[key] if req > now - period]
+        if len(self.requests[key]) >= limit:
+            return True
+        self.requests[key].append(now)
+        return False
+
+rate_limiter = InMemoryRateLimiter()
+
+async def get_user_rate_limit(app, api_index: str = None):
+    # The logic for looking up a user's rate limit by token goes here
+    # Example: returns (count, seconds)
+    config = app.state.config
+    raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
+
+    if api_index == None or not raw_rate_limit:
+        return (30, 60)
+
+    rate_limit = parse_rate_limit(raw_rate_limit)
+    return rate_limit
 
 import asyncio
 
 class ThreadSafeCircularList:
-    def __init__(self, items):
+    def __init__(self, items, rate_limit="99999/min"):
         self.items = items
         self.index = 0
         self.lock = asyncio.Lock()
+        self.requests = defaultdict(list)  # tracks request times per API key
+        self.cooling_until = defaultdict(float)  # cooldown end time per item
+        count, period = parse_rate_limit(rate_limit)
+        self.rate_limit = count
+        self.period = period
+
+    async def set_cooling(self, item: str, cooldown_period: int = 60):
+        """Put an item into the cooldown state.
+
+        Args:
+            item: the item to cool down
+            cooldown_period: cooldown time in seconds, default 60
+        """
+        now = time()
+        async with self.lock:
+            self.cooling_until[item] = now + cooldown_period
+            # Clear this item's request history
+            self.requests[item] = []
+            logger.warning(f"API key {item} entered cooldown for {cooldown_period} seconds")
+
+    async def is_rate_limited(self, item) -> bool:
+        now = time()
+        # Check whether the item is still cooling down
+        if now < self.cooling_until[item]:
+            return True
+
+        self.requests[item] = [req for req in self.requests[item] if req > now - self.period]
+        if len(self.requests[item]) >= self.rate_limit:
+            return True
+        self.requests[item].append(now)
+        return False
 
     async def next(self):
         async with self.lock:
-            item = self.items[self.index]
-            self.index = (self.index + 1) % len(self.items)
+            start_index = self.index
+            while True:
+                item = self.items[self.index]
+                self.index = (self.index + 1) % len(self.items)
+
+                if not await self.is_rate_limited(item):
+                    return item
+
+                logger.warning(f"API key {item} has hit its rate limit ({self.rate_limit}/{self.period}s)")
+
+                # Every API key has been checked and all are rate limited
+                if self.index == start_index:
+                    logger.warning(f"All API keys have hit their rate limit ({self.rate_limit}/{self.period}s)")
+                    return None
+
+    async def after_next_current(self):
+        # Return the API taken by the last call; since next() has already
+        # been called, the current API is the previous item
+        async with self.lock:
+            item = self.items[(self.index - 1) % len(self.items)]
             return item
 
+    def get_items_count(self) -> int:
+        """Return the number of items in the list.
+
+        Returns:
+            int: the length of the items list
+        """
+        return len(self.items)
+
 def circular_list_encoder(obj):
     if isinstance(obj, ThreadSafeCircularList):
         return obj.to_dict()
@@ -84,9 +197,15 @@ def update_config(config_data, use_config_url=False):
         provider_api = provider.get('api', None)
         if provider_api:
             if isinstance(provider_api, str):
-                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList([provider_api])
+                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
+                    [provider_api],
+                    safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+                )
             if isinstance(provider_api, list):
-                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
+                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
+                    provider_api,
+                    safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+                )
 
         if not provider.get("model"):
             model_list = update_initial_model(provider['base_url'], provider['api'])
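The key rotation with cooldown added to ThreadSafeCircularList above can be exercised in isolation. The sketch below is a trimmed re-implementation (`KeyRing` is a hypothetical name, and the per-key rate-limit window is omitted) that shows how a key placed in cooldown after a simulated 429 is skipped by `next()`:

```python
import asyncio
import time
from collections import defaultdict

class KeyRing:
    """Trimmed sketch of uni-api's circular key list: rotate through keys,
    skipping any key whose cooldown has not yet expired."""
    def __init__(self, items):
        self.items = items
        self.index = 0
        self.lock = asyncio.Lock()
        self.cooling_until = defaultdict(float)  # cooldown end time per key

    async def set_cooling(self, item, cooldown_period=60):
        async with self.lock:
            self.cooling_until[item] = time.time() + cooldown_period

    async def next(self):
        async with self.lock:
            start = self.index
            while True:
                item = self.items[self.index]
                self.index = (self.index + 1) % len(self.items)
                if time.time() >= self.cooling_until[item]:
                    return item
                if self.index == start:  # every key is cooling down
                    return None

async def demo():
    ring = KeyRing(["key-a", "key-b"])
    assert await ring.next() == "key-a"
    await ring.set_cooling("key-b", cooldown_period=60)  # simulate a 429
    # key-b is skipped while cooling, so key-a is handed out again
    assert await ring.next() == "key-a"

asyncio.run(demo())
```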