yym68686 commited on
Commit
3e22716
·
1 Parent(s): b21d1ea

Support downloading configuration files from external URLs.

Browse files
Files changed (2) hide show
  1. main.py +25 -20
  2. utils.py +35 -21
main.py CHANGED
@@ -4,38 +4,25 @@ import httpx
4
  import secrets
5
  from contextlib import asynccontextmanager
6
 
 
7
  from fastapi import FastAPI, HTTPException, Depends
 
 
8
 
9
  from models import RequestModel
10
- from utils import config, api_keys_db, api_list, error_handling_wrapper, get_all_models, verify_api_key, post_all_models, update_config
11
  from request import get_payload
12
  from response import fetch_response, fetch_response_stream
13
 
14
  from typing import List, Dict
15
  from urllib.parse import urlparse
16
- from fastapi.responses import StreamingResponse, JSONResponse
17
- from fastapi.middleware.cors import CORSMiddleware
18
 
19
  @asynccontextmanager
20
  async def lifespan(app: FastAPI):
21
  # 启动时的代码
22
  timeout = httpx.Timeout(connect=15.0, read=10.0, write=30.0, pool=30.0)
23
  app.state.client = httpx.AsyncClient(timeout=timeout)
24
- import os
25
- import yaml
26
- # 新增: 从环境变量获取配置URL并拉取配置
27
- config_url = os.environ.get('CONFIG_URL')
28
- if config_url:
29
- try:
30
- response = await app.state.client.get(config_url)
31
- response.raise_for_status()
32
- config_data = yaml.safe_load(response.text)
33
- # 更新配置
34
- global config, api_keys_db, api_list
35
- config, api_keys_db, api_list = update_config(config_data)
36
- except Exception as e:
37
- logger.error(f"Error fetching or parsing config from {config_url}: {str(e)}")
38
-
39
  yield
40
  # 关闭时的代码
41
  await app.state.client.aclose()
@@ -96,6 +83,10 @@ class ModelRequestHandler:
96
  self.last_provider_index = -1
97
 
98
  def get_matching_providers(self, model_name, token):
 
 
 
 
99
  api_index = api_list.index(token)
100
  provider_rules = []
101
 
@@ -127,6 +118,10 @@ class ModelRequestHandler:
127
  return provider_list
128
 
129
  async def request_model(self, request: RequestModel, token: str):
 
 
 
 
130
  model_name = request.model
131
  matching_providers = self.get_matching_providers(model_name, token)
132
  # print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
@@ -164,6 +159,16 @@ class ModelRequestHandler:
164
 
165
  model_handler = ModelRequestHandler()
166
 
 
 
 
 
 
 
 
 
 
 
167
  @app.post("/v1/chat/completions")
168
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
169
  return await model_handler.request_model(request, token)
@@ -174,7 +179,7 @@ async def options_handler():
174
 
175
  @app.post("/v1/models")
176
  async def list_models(token: str = Depends(verify_api_key)):
177
- models = post_all_models(token)
178
  return JSONResponse(content={
179
  "object": "list",
180
  "data": models
@@ -182,7 +187,7 @@ async def list_models(token: str = Depends(verify_api_key)):
182
 
183
  @app.get("/v1/models")
184
  async def list_models():
185
- models = get_all_models()
186
  return JSONResponse(content={
187
  "object": "list",
188
  "data": models
 
4
  import secrets
5
  from contextlib import asynccontextmanager
6
 
7
+ from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi import FastAPI, HTTPException, Depends
9
+ from fastapi.responses import StreamingResponse, JSONResponse
10
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
12
  from models import RequestModel
13
+ from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
14
  from request import get_payload
15
  from response import fetch_response, fetch_response_stream
16
 
17
  from typing import List, Dict
18
  from urllib.parse import urlparse
 
 
19
 
20
  @asynccontextmanager
21
  async def lifespan(app: FastAPI):
22
  # 启动时的代码
23
  timeout = httpx.Timeout(connect=15.0, read=10.0, write=30.0, pool=30.0)
24
  app.state.client = httpx.AsyncClient(timeout=timeout)
25
+ app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  yield
27
  # 关闭时的代码
28
  await app.state.client.aclose()
 
83
  self.last_provider_index = -1
84
 
85
  def get_matching_providers(self, model_name, token):
86
+ config = app.state.config
87
+ # api_keys_db = app.state.api_keys_db
88
+ api_list = app.state.api_list
89
+
90
  api_index = api_list.index(token)
91
  provider_rules = []
92
 
 
118
  return provider_list
119
 
120
  async def request_model(self, request: RequestModel, token: str):
121
+ config = app.state.config
122
+ # api_keys_db = app.state.api_keys_db
123
+ api_list = app.state.api_list
124
+
125
  model_name = request.model
126
  matching_providers = self.get_matching_providers(model_name, token)
127
  # print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
 
159
 
160
  model_handler = ModelRequestHandler()
161
 
162
+ # 安全性依赖
163
+ security = HTTPBearer()
164
+
165
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
166
+ api_list = app.state.api_list
167
+ token = credentials.credentials
168
+ if token not in api_list:
169
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
170
+ return token
171
+
172
  @app.post("/v1/chat/completions")
173
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
174
  return await model_handler.request_model(request, token)
 
179
 
180
  @app.post("/v1/models")
181
  async def list_models(token: str = Depends(verify_api_key)):
182
+ models = post_all_models(token, app.state.config, app.state.api_list)
183
  return JSONResponse(content={
184
  "object": "list",
185
  "data": models
 
187
 
188
  @app.get("/v1/models")
189
  async def list_models():
190
+ models = get_all_models(config=app.state.config)
191
  return JSONResponse(content={
192
  "object": "list",
193
  "data": models
utils.py CHANGED
@@ -1,4 +1,3 @@
1
- import yaml
2
  import json
3
  from fastapi import HTTPException, Depends
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -22,24 +21,48 @@ def update_config(config_data):
22
  return config_data, api_keys_db, api_list
23
 
24
  # 读取YAML配置文件
25
- def load_config():
 
26
  try:
27
  with open('./api.yaml', 'r') as f:
28
  # 判断是否为空文件
29
  conf = yaml.safe_load(f)
 
30
  if conf:
31
- return update_config(conf)
32
  else:
33
- logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
34
- return [], [], []
35
  except FileNotFoundError:
36
  logger.error("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
37
- return [], [], []
38
  except yaml.YAMLError:
39
  logger.error("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
40
- return [], [], []
41
-
42
- config, api_keys_db, api_list = load_config()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def ensure_string(item):
45
  if isinstance(item, (bytes, bytearray)):
@@ -86,7 +109,7 @@ async def error_handling_wrapper(generator, status_code=200):
86
  # 处理生成器为空的情况
87
  return async_generator(["data: {'error': 'No data returned'}\n\n"])
88
 
89
- def post_all_models(token):
90
  all_models = []
91
  unique_models = set()
92
 
@@ -141,7 +164,7 @@ def post_all_models(token):
141
 
142
  return all_models
143
 
144
- def get_all_models():
145
  all_models = []
146
  unique_models = set()
147
 
@@ -157,13 +180,4 @@ def get_all_models():
157
  }
158
  all_models.append(model_info)
159
 
160
- return all_models
161
-
162
- # 安全性依赖
163
- security = HTTPBearer()
164
-
165
- def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
166
- token = credentials.credentials
167
- if token not in api_list:
168
- raise HTTPException(status_code=403, detail="Invalid or missing API Key")
169
- return token
 
 
1
  import json
2
  from fastapi import HTTPException, Depends
3
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
21
  return config_data, api_keys_db, api_list
22
 
23
  # 读取YAML配置文件
24
+ async def load_config(app):
25
+ import yaml
26
  try:
27
  with open('./api.yaml', 'r') as f:
28
  # 判断是否为空文件
29
  conf = yaml.safe_load(f)
30
+ # conf = None
31
  if conf:
32
+ config, api_keys_db, api_list = update_config(conf)
33
  else:
34
+ # logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
35
+ config, api_keys_db, api_list = [], [], []
36
  except FileNotFoundError:
37
  logger.error("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
38
+ config, api_keys_db, api_list = [], [], []
39
  except yaml.YAMLError:
40
  logger.error("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
41
+ config, api_keys_db, api_list = [], [], []
42
+
43
+ if config != []:
44
+ return config, api_keys_db, api_list
45
+
46
+ import os
47
+ # 新增: 从环境变量获取配置URL并拉取配置
48
+ config_url = os.environ.get('CONFIG_URL')
49
+ if config_url:
50
+ try:
51
+ response = await app.state.client.get(config_url)
52
+ # logger.info(f"Fetching config from {response.text}")
53
+ response.raise_for_status()
54
+ config_data = yaml.safe_load(response.text)
55
+ # 更新配置
56
+ # logger.info(config_data)
57
+ if config_data:
58
+ config, api_keys_db, api_list = update_config(config_data)
59
+ else:
60
+ logger.error(f"Error fetching or parsing config from {config_url}")
61
+ config, api_keys_db, api_list = [], [], []
62
+ except Exception as e:
63
+ logger.error(f"Error fetching or parsing config from {config_url}: {str(e)}")
64
+ config, api_keys_db, api_list = [], [], []
65
+ return config, api_keys_db, api_list
66
 
67
  def ensure_string(item):
68
  if isinstance(item, (bytes, bytearray)):
 
109
  # 处理生成器为空的情况
110
  return async_generator(["data: {'error': 'No data returned'}\n\n"])
111
 
112
+ def post_all_models(token, config, api_list):
113
  all_models = []
114
  unique_models = set()
115
 
 
164
 
165
  return all_models
166
 
167
+ def get_all_models(config):
168
  all_models = []
169
  unique_models = set()
170
 
 
180
  }
181
  all_models.append(model_info)
182
 
183
+ return all_models