yym68686 commited on
Commit
44f121d
·
1 Parent(s): ba79e57

💻 Code: When there is only one optional provider, cooling will not take effect.

Browse files
Files changed (1) hide show
  1. main.py +3 -3
main.py CHANGED
@@ -868,14 +868,14 @@ async def get_right_order_providers(request_model, config, api_index, scheduling
868
  if not matching_providers:
869
  raise HTTPException(status_code=404, detail="No matching model found")
870
 
871
- if app.state.channel_manager.cooldown_period > 0:
 
872
  matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
873
  if not matching_providers:
874
  raise HTTPException(status_code=503, detail="No available providers at the moment")
875
 
876
  # 检查是否启用轮询
877
  if scheduling_algorithm == "random":
878
- num_matching_providers = len(matching_providers)
879
  matching_providers = random.sample(matching_providers, num_matching_providers)
880
 
881
  weights = safe_get(config, 'api_keys', api_index, "weights")
@@ -987,7 +987,7 @@ class ModelRequestHandler:
987
  error_message = str(e) or f"Unknown error: {e.__class__.__name__}"
988
 
989
  channel_id = f"{provider['provider']}"
990
- if app.state.channel_manager.cooldown_period > 0:
991
  # 获取源模型名称(实际配置的模型名)
992
  # source_model = list(provider['model'][0].keys())[0]
993
  await app.state.channel_manager.exclude_model(channel_id, request_model)
 
868
  if not matching_providers:
869
  raise HTTPException(status_code=404, detail="No matching model found")
870
 
871
+ num_matching_providers = len(matching_providers)
872
+ if app.state.channel_manager.cooldown_period > 0 and num_matching_providers > 1:
873
  matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
874
  if not matching_providers:
875
  raise HTTPException(status_code=503, detail="No available providers at the moment")
876
 
877
  # 检查是否启用轮询
878
  if scheduling_algorithm == "random":
 
879
  matching_providers = random.sample(matching_providers, num_matching_providers)
880
 
881
  weights = safe_get(config, 'api_keys', api_index, "weights")
 
987
  error_message = str(e) or f"Unknown error: {e.__class__.__name__}"
988
 
989
  channel_id = f"{provider['provider']}"
990
+ if app.state.channel_manager.cooldown_period > 0 and num_matching_providers > 1:
991
  # 获取源模型名称(实际配置的模型名)
992
  # source_model = list(provider['model'][0].keys())[0]
993
  await app.state.channel_manager.exclude_model(channel_id, request_model)