yym68686 commited on
Commit
b8a7df8
·
1 Parent(s): 9410047

Support API key to filter models through wildcards.

Browse files
Files changed (1) hide show
  1. main.py +44 -4
main.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import httpx
4
  import logging
5
  import yaml
 
6
  import traceback
7
  from contextlib import asynccontextmanager
8
 
@@ -95,17 +96,17 @@ class ModelRequestHandler:
95
  # if model_name in provider['model'].keys():
96
  # print("provider", provider)
97
  api_index = api_list.index(token)
98
- provider_rules = {}
99
 
100
  for model in config['api_keys'][api_index]['model']:
101
  if "/" in model:
102
  provider_name = model.split("/")[0]
103
  model = model.split("/")[1]
104
- if model_name == model:
105
- provider_rules[provider_name] = model
106
  provider_list = []
107
  for provider in config['providers']:
108
- if model_name in provider['model'].keys() and ((provider_rules != {} and provider['provider'] in provider_rules.keys()) or provider_rules == {}):
109
  provider_list.append(provider)
110
  return provider_list
111
 
@@ -179,6 +180,39 @@ def get_all_models(token):
179
  api_index = api_list.index(token)
180
  if config['api_keys'][api_index]['model']:
181
  for model in config['api_keys'][api_index]['model']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  if model not in unique_models:
183
  unique_models.add(model)
184
  model_info = {
@@ -210,6 +244,12 @@ async def list_models(token: str = Depends(verify_api_key)):
210
  "object": "list",
211
  "data": models
212
  }
 
 
 
 
 
 
213
  if __name__ == '__main__':
214
  import uvicorn
215
  uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
 
3
  import httpx
4
  import logging
5
  import yaml
6
+ import secrets
7
  import traceback
8
  from contextlib import asynccontextmanager
9
 
 
96
  # if model_name in provider['model'].keys():
97
  # print("provider", provider)
98
  api_index = api_list.index(token)
99
+ provider_rules = []
100
 
101
  for model in config['api_keys'][api_index]['model']:
102
  if "/" in model:
103
  provider_name = model.split("/")[0]
104
  model = model.split("/")[1]
105
+ if (model and model_name == model) or (model == "*"):
106
+ provider_rules.append(provider_name)
107
  provider_list = []
108
  for provider in config['providers']:
109
+ if model_name in provider['model'].keys() and ((provider_rules and provider['provider'] in provider_rules) or provider_rules == []):
110
  provider_list.append(provider)
111
  return provider_list
112
 
 
180
  api_index = api_list.index(token)
181
  if config['api_keys'][api_index]['model']:
182
  for model in config['api_keys'][api_index]['model']:
183
+ if "/" in model:
184
+ provider = model.split("/")[0]
185
+ model = model.split("/")[1]
186
+ if model == "*":
187
+ for provider_item in config["providers"]:
188
+ if provider_item['provider'] != provider:
189
+ continue
190
+ for model_item in provider_item['model'].keys():
191
+ if model_item not in unique_models:
192
+ unique_models.add(model_item)
193
+ model_info = {
194
+ "id": model_item,
195
+ "object": "model",
196
+ "created": 1720524448858,
197
+ "owned_by": provider_item['provider']
198
+ }
199
+ all_models.append(model_info)
200
+ else:
201
+ for provider_item in config["providers"]:
202
+ if provider_item['provider'] != provider:
203
+ continue
204
+ if model_item in provider_item['model'].keys() :
205
+ if model_item not in unique_models and model_item != model:
206
+ unique_models.add(model_item)
207
+ model_info = {
208
+ "id": model_item,
209
+ "object": "model",
210
+ "created": 1720524448858,
211
+ "owned_by": provider_item['provider']
212
+ }
213
+ all_models.append(model_info)
214
+ continue
215
+
216
  if model not in unique_models:
217
  unique_models.add(model)
218
  model_info = {
 
244
  "object": "list",
245
  "data": models
246
  }
247
+
248
+ @app.get("/generate-api-key")
249
+ def generate_api_key():
250
+ api_key = "sk-" + secrets.token_urlsafe(32)
251
+ return {"api_key": api_key}
252
+
253
  if __name__ == '__main__':
254
  import uvicorn
255
  uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)